A BERT-Based Chinese Text Question Answering System
# -*- coding: utf-8 -*-
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print('there are %d GPU(s) available.' % torch.cuda.device_count())
    print('we will use the GPU: ', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device('cpu')
there are 1 GPU(s) available.
we will use the GPU:  GeForce GTX 1070
import os
import time
import json
import random
import datetime
import numpy as np
from tqdm import tqdm

from transformers import AdamW
from torch.utils.tensorboard import SummaryWriter
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import WEIGHTS_NAME, CONFIG_NAME
from transformers import (
    DataProcessor,
    BertTokenizer,
    squad_convert_examples_to_features,
    BertForQuestionAnswering,
)

# Set the random seeds for reproducibility.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
Defining the DataProcessor
class MySquadProcessor(DataProcessor):

    def get_train_examples(self, data_dir, filename=None):
        """
        Returns the training examples from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the training file has a different name than the
                original one, which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and
                2.0 respectively.
        """
        if data_dir is None:
            data_dir = ""

        if self.train_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "train")

    def get_dev_examples(self, data_dir, filename=None):
        """
        Returns the evaluation examples from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the evaluation file has a different name than the
                original one, which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and
                2.0 respectively.
        """
        if data_dir is None:
            data_dir = ""

        if self.dev_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "dev")

    def _create_examples(self, input_data, set_type):
        is_training = set_type == "train"
        examples = []
        for entry in tqdm(input_data):
            title = entry["title"]
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
                    qas_id = qa["id"]
                    question_text = qa["question"]
                    start_position_character = None
                    answer_text = None
                    answers = []

                    if "is_impossible" in qa:
                        is_impossible = qa["is_impossible"]
                    else:
                        is_impossible = False

                    if not is_impossible:
                        answer = qa["answers"][0]
                        answer_text = qa["answers"][0]["text"]
                        start_position_character = qa["answers"][0]["answer_start"]

                    example = ChineseSquadExample(
                        qas_id=qas_id,
                        question_text=question_text,
                        context_text=context_text,
                        answer_text=answer_text,
                        start_position_character=start_position_character,
                        title=title,
                        is_impossible=is_impossible,
                        answers=answers,
                    )
                    examples.append(example)
        return examples
Importing the JSON files
class SquadV3Processor(MySquadProcessor):
    train_file = "train-v2.0.json"
    dev_file = "dev-v2.0.json"
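The processor expects train-v2.0.json and dev-v2.0.json under data_dir, laid out in the SQuAD style that _create_examples walks through (data → paragraphs → qas). A minimal, hypothetical file might look roughly like this (the title, id, question and answer offsets below are made up purely for illustration; is_impossible is optional and defaults to False in the code above):

{
  "data": [
    {
      "title": "株洲北站",
      "paragraphs": [
        {
          "context": "株洲北站位于湖南省株洲市区东北部。",
          "qas": [
            {
              "id": "DEMO_QUERY_0",
              "question": "株洲北站位于哪里?",
              "answers": [
                {"text": "湖南省株洲市区东北部", "answer_start": 6}
              ]
            }
          ]
        }
      ]
    }
  ]
}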
Defining a Chinese SquadExample (the Chinese SquadExample differs from the English one, so we have to write our own)
class ChineseSquadExample(object):
    """
    A single training/test example for the Squad dataset, as loaded from disk.

    Args:
        qas_id: The example's unique identifier
        question_text: The question string
        context_text: The context string
        answer_text: The answer string
        start_position_character: The character position of the start of the answer
        title: The title of the example
        answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
        is_impossible: False by default, set to True if the example has no possible answer.
    """

    def __init__(
        self,
        qas_id,
        question_text,
        context_text,
        answer_text,
        start_position_character,
        title,
        answers=[],
        is_impossible=False,
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        # Strip space characters from the context.
        self.context_text = context_text.replace(" ", "").replace(" ", "").replace(" ", "")

        # Store the answer as space-separated characters, then drop the trailing space.
        self.answer_text = ""
        for e in answer_text.replace(" ", "").replace(" ", "").replace(" ", ""):
            self.answer_text += e
            self.answer_text += " "
        self.answer_text = self.answer_text[0:-1]

        self.title = title
        self.is_impossible = is_impossible
        self.answers = answers

        # For Chinese we tokenize the context character by character, so the character
        # offset and the "word" offset are the same.
        self.doc_tokens = [e for e in self.context_text]
        self.char_to_word_offset = [i for i, e in enumerate(self.context_text)]

        # Locate the answer span in the cleaned context by string search.
        self.start_position = self.context_text.find(
            answer_text.replace(" ", "").replace(" ", "").replace(" ", ""))
        self.end_position = self.start_position + len(
            answer_text.replace(" ", "").replace(" ", "").replace(" ", ""))
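As a quick illustration of what the class stores, here is a small, hypothetical example (the strings and positions are made up; the printed values follow from the code above):

# A minimal sketch: the context is split into characters, and the answer span
# is located by simple string search.
example = ChineseSquadExample(
    qas_id="demo-0",
    question_text="株洲北站位于哪里?",
    context_text="株洲北站位于湖南省株洲市区东北部。",
    answer_text="湖南省株洲市区东北部",
    start_position_character=6,
    title="株洲北站",
)
print(example.doc_tokens[:6])                         # ['株', '洲', '北', '站', '位', '于']
print(example.answer_text)                            # 湖 南 省 株 洲 市 区 东 北 部
print(example.start_position, example.end_position)   # 6 16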
Defining a timing function
def format_time(elapsed):
    elapsed_rounded = int(round((elapsed)))
    # Return the time as a string in hh:mm:ss format.
    return str(datetime.timedelta(seconds=elapsed_rounded))
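For example, format_time(3723.6) rounds the elapsed time to 3724 seconds and returns '1:02:04'.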
Defining the training function
def training(train_dataloader, model):
    t0 = time.time()
    total_train_loss = 0
    total_train_accuracy = 0
    model.train()

    for step, batch in enumerate(train_dataloader):
        # Report elapsed time every 40 batches.
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # `batch` contains 5 tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: token_type_ids
        #   [3]: start_positions
        #   [4]: end_positions
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        start_positions = batch[3].to(device)
        end_positions = batch[4].to(device)

        # Clear any previously accumulated gradients.
        model.zero_grad()

        # Forward pass.
        # See https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForQuestionAnswering
        loss, start_scores, end_scores = model(input_ids,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               start_positions=start_positions,
                                               end_positions=end_positions)

        total_train_loss += loss.item()

        # Backward pass to compute the gradients.
        loss.backward()

        # Clip the gradient norm to 1.0 to help prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update the model parameters.
        optimizer.step()

        # Update the learning rate.
        scheduler.step()

    # Average loss over all batches.
    avg_train_loss = total_train_loss / len(train_dataloader)
    print("  平均训练损失 loss: {0:.2f}".format(avg_train_loss))

    return avg_train_loss
Defining the validation function
def train_evalution(test_dataloader, model):
    total_eval_loss = 0
    model.eval()

    for batch in test_dataloader:
        # `batch` contains 5 tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: token_type_ids
        #   [3]: start_positions
        #   [4]: end_positions
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        start_positions = batch[3].to(device)
        end_positions = batch[4].to(device)

        # During validation we do not update weights or build the computation graph.
        with torch.no_grad():
            # See https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForQuestionAnswering
            loss, start_scores, end_scores = model(input_ids,
                                                   attention_mask=attention_mask,
                                                   token_type_ids=token_type_ids,
                                                   start_positions=start_positions,
                                                   end_positions=end_positions)

        # Accumulate the validation loss.
        total_eval_loss += loss.item()

    # Return the summed loss and the number of batches so the caller can compute the average.
    return total_eval_loss, len(test_dataloader)
Reading the data
#if __name__ == '__main__':
data_dir = ".//data//"

processor = SquadV3Processor()
Train_data = processor.get_train_examples(data_dir)
Dev_data = processor.get_dev_examples(data_dir)

tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
model = BertForQuestionAnswering.from_pretrained('hfl/chinese-roberta-wwm-ext')
model.to(device)

max_seq_length = 1280
max_query_length = 128
100%|██████████████████████████████████████████████████████████████████████████████| 848/848 [00:00<00:00, 1462.47it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 848/848 [00:00<00:00, 1532.02it/s]
Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
is_training = True
Train_features, Train_dataset = squad_convert_examples_to_features(
    examples=Train_data[0:10],
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    doc_stride=True,            # note: doc_stride is normally an integer window stride such as 128
    max_query_length=max_query_length,
    is_training=is_training,
    return_dataset='pt',
)
convert squad examples to features:   0%|          | 0/10 [00:00<?, ?it/s]
is_training = False
Dev_features, Dev_dataset = squad_convert_examples_to_features(
    examples=Dev_data[0:10],
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    doc_stride=True,
    max_query_length=max_query_length,
    is_training=is_training,
    return_dataset='pt',
)
Building the dataloaders
train_dataloader = DataLoader(Train_dataset, batch_size=1, shuffle=True)
dev_dataloader = DataLoader(Dev_dataset, batch_size=1, shuffle=True)
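Before training, it can be worth confirming that each batch really contains the tensors that training() indexes by position. A minimal sketch (the exact number and order of tensors in the returned TensorDataset depends on the transformers version, so treat this as a sanity check rather than a guarantee):

# Inspect the first training batch: index, shape and dtype of each tensor.
batch = next(iter(train_dataloader))
for i, t in enumerate(batch):
    print(i, tuple(t.shape), t.dtype)
# Indices 0-4 should correspond to input_ids, attention_mask, token_type_ids,
# start_positions and end_positions, matching what training() expects.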
Setting the model parameters
# AdamW is a class from the huggingface library; the 'W' stands for 'Weight Decay fix'.
optimizer = AdamW(model.parameters(),
                  lr=2e-5,    # args.learning_rate - default is 5e-5
                  eps=1e-8    # args.adam_epsilon  - default is 1e-8; keeps the denominator away from zero
                  )

# The BERT authors recommend between 2 and 4 epochs.
epochs = 2

# Total number of training steps: [number of batches] x [number of epochs].
total_steps = len(train_dataloader) * epochs

# Set up the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,    # Default value in run_glue.py
                                            num_training_steps=total_steps)
Training the model
output_dir = "./this_model/"
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)
os.makedirs(output_dir, exist_ok=True)    # make sure the output directory exists before saving

writer = SummaryWriter("./log_models/")

# Record the total training time.
total_t0 = time.time()

for epoch_i in range(0, epochs):

    print('Epoch {:} / {:}'.format(epoch_i + 1, epochs))

    # ========================================
    #               Training
    # ========================================
    t0 = time.time()
    avg_train_loss = training(train_dataloader, model)

    # Measure how long this epoch's training took.
    training_time = format_time(time.time() - t0)
    print("  训练时间: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================
    t0 = time.time()
    total_eval_loss, valid_dataloader_length = train_evalution(dev_dataloader, model)

    print("")

    # Average validation loss over all batches.
    avg_val_loss = total_eval_loss / valid_dataloader_length

    # Measure how long the validation took.
    validation_time = format_time(time.time() - t0)

    print("  平均测试损失 Loss: {0:.2f}".format(avg_val_loss))
    print("  测试时间: {:}".format(validation_time))

    writer.add_scalars(f'Acc/Loss', {
        'Training Loss': avg_train_loss,
        'Valid Loss': avg_val_loss,
    }, epoch_i + 1)

print("训练一共用了 {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))

writer.close()

torch.save(model.state_dict(), output_model_file)
model.config.to_json_file(output_config_file)
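Because the weights are written to WEIGHTS_NAME (pytorch_model.bin) and the config to CONFIG_NAME (config.json), the fine-tuned model can later be reloaded straight from that directory. A minimal sketch, assuming output_dir still contains both files:

# Reload the fine-tuned model for inference.
model = BertForQuestionAnswering.from_pretrained(output_dir)
model.to(device)
model.eval()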
Since my GPU has only 8 GB of memory, training this model is quite difficult. Below are the results of training it in PyCharm. Because the training time was fairly short, the results are not great.
there are 1 GPU(s) available.
we will use the GPU: GeForce GTX 1070
2020-08-04 09:12:29.425728: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
100%|██████████| 2402/2402 [00:01<00:00, 1612.06it/s]
100%|██████████| 848/848 [00:00<00:00, 2078.73it/s]
there are 1 GPU(s) available.
we will use the GPU: GeForce GTX 1070
2020-08-04 09:12:47.730465: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
convert squad examples to features: 100%|██████████| 10137/10137 [05:29<00:00, 30.75it/s]
add example index and unique id: 100%|██████████| 10137/10137 [00:00<00:00, 781775.82it/s]
there are 1 GPU(s) available.
we will use the GPU: GeForce GTX 1070
2020-08-04 09:18:27.213451: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
convert squad examples to features: 100%|██████████| 3219/3219 [01:46<00:00, 30.22it/s]
add example index and unique id: 100%|██████████| 3219/3219 [00:00<00:00, 807021.19it/s]
Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1 / 2
Batch 400 of 10,137. Elapsed: 0:03:02.
Batch 800 of 10,137. Elapsed: 0:06:04.
Batch 1,200 of 10,137. Elapsed: 0:09:06.
Batch 1,600 of 10,137. Elapsed: 0:12:08.
Batch 2,000 of 10,137. Elapsed: 0:15:11.
Batch 2,400 of 10,137. Elapsed: 0:18:14.
Batch 2,800 of 10,137. Elapsed: 0:21:20.
Batch 3,200 of 10,137. Elapsed: 0:25:02.
Batch 3,600 of 10,137. Elapsed: 0:28:10.
Batch 4,000 of 10,137. Elapsed: 0:31:45.
Batch 4,400 of 10,137. Elapsed: 0:34:57.
Batch 4,800 of 10,137. Elapsed: 0:38:36.
Batch 5,200 of 10,137. Elapsed: 0:42:16.
Batch 5,600 of 10,137. Elapsed: 0:45:37.
Batch 6,000 of 10,137. Elapsed: 0:48:55.
Batch 6,400 of 10,137. Elapsed: 0:52:27.
Batch 6,800 of 10,137. Elapsed: 0:56:11.
Batch 7,200 of 10,137. Elapsed: 1:00:02.
Batch 7,600 of 10,137. Elapsed: 1:03:15.
Batch 8,000 of 10,137. Elapsed: 1:06:45.
Batch 8,400 of 10,137. Elapsed: 1:10:21.
Batch 8,800 of 10,137. Elapsed: 1:14:00.
Batch 9,200 of 10,137. Elapsed: 1:17:34.
Batch 9,600 of 10,137. Elapsed: 1:20:38.
Batch 10,000 of 10,137. Elapsed: 1:24:19.
平均训练损失 loss: 2.20
训练时间: 1:25:24
平均测试损失 Loss: 10.32
测试时间: 0:09:32
Epoch 2 / 2
Batch 400 of 10,137. Elapsed: 0:03:05.
Batch 800 of 10,137. Elapsed: 0:06:38.
Batch 1,200 of 10,137. Elapsed: 0:10:13.
Batch 1,600 of 10,137. Elapsed: 0:13:43.
Batch 2,000 of 10,137. Elapsed: 0:17:17.
Batch 2,400 of 10,137. Elapsed: 0:20:23.
Batch 2,800 of 10,137. Elapsed: 0:23:57.
Batch 3,200 of 10,137. Elapsed: 0:27:31.
Batch 3,600 of 10,137. Elapsed: 0:31:13.
Batch 4,000 of 10,137. Elapsed: 0:34:53.
Batch 4,400 of 10,137. Elapsed: 0:38:35.
Batch 4,800 of 10,137. Elapsed: 0:42:10.
Batch 5,200 of 10,137. Elapsed: 0:45:12.
Batch 5,600 of 10,137. Elapsed: 0:48:46.
Batch 6,000 of 10,137. Elapsed: 0:52:18.
Batch 6,400 of 10,137. Elapsed: 0:55:53.
Batch 6,800 of 10,137. Elapsed: 0:59:29.
Batch 7,200 of 10,137. Elapsed: 1:02:33.
Batch 7,600 of 10,137. Elapsed: 1:06:07.
Batch 8,000 of 10,137. Elapsed: 1:09:41.
Batch 8,400 of 10,137. Elapsed: 1:13:13.
Batch 8,800 of 10,137. Elapsed: 1:16:54.
Batch 9,200 of 10,137. Elapsed: 1:20:38.
Batch 9,600 of 10,137. Elapsed: 1:24:29.
Batch 10,000 of 10,137. Elapsed: 1:28:17.
平均训练损失 loss: 1.36
训练时间: 1:29:22
平均测试损失 Loss: 10.24
测试时间: 0:09:05
训练一共用了 3:13:23 (h:mm:ss)
Testing the model
model.load_state_dict(torch.load('roberta_models/pytorch_model.bin'))
model.to(device)
<All keys matched successfully>
context = "株洲北站全称广州铁路(集团)公司株洲北火车站。除站场主体,另外管辖湘潭站、湘潭东站和三个卫星站,田心站、白马垅站、十里冲站,以及原株洲车站货房。车站办理编组、客运、货运业务。车站机关地址:湖南省株洲市石峰区北站路236号,邮编412001。株洲北站位于湖南省株洲市区东北部,地处中南路网,是京广铁路、沪昆铁路两大铁路干线的交汇处,属双向纵列式三级七场路网性编组站。车站等级为特等站,按技术作业性质为编组站,按业务性质为客货运站,是株洲铁路枢纽的主要组成部分,主要办理京广、沪昆两大干线四个方向货物列车的到发、解编作业以及各方向旅客列车的通过作业。每天办理大量的中转车流作业,并有大量的本地车流产生和集散,在路网车流的组织中占有十分重要的地位,是沟通华东、华南、西南和北方的交通要道,任务艰巨,作业繁忙。此外,株洲北站还有连接石峰区喻家坪工业站的专用线。株洲北站的前身是田心车站。"
question = "株洲北站的机关地址是什么"

inputs = tokenizer(context, question, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

start_scores, end_scores = model(**inputs)

answer_start = torch.argmax(start_scores)
answer_end = torch.argmax(end_scores)
# Note: this slice excludes the token at answer_end, so the last answer token is dropped
# (which is why the printed answers below are truncated by one character).
answer = tokens[answer_start:answer_end]

print("".join(answer))
print("标准答案:湖南省株洲市石峰区北站路236号,邮编412001。")
车站机关地址:湖南省株洲市石峰区北站路236
标准答案:湖南省株洲市石峰区北站路236号,邮编412001。
context = "地方税务局是一个泛称,是中华人民共和国1994年分税制改革的结果。1994年分税制把税种分为中央税、地方税、中央地方共享税;把征税系统由税务局分为国家税务系统与地方税务系统。其中中央税、中央地方共享税由国税系统(包括国家税务总局及各地的国家税务局)征收,地方税由地方税务局征收。地方税务局在省、市、县、区各级地方政府中设置,国务院中没有地方税务局。地税局长由本级人民政府任免,但要征求上级国家税务局的意见。一般情况下,地方税务局与财政厅(局)是分立的,不是一个机构两块牌子。但也有例外,例如,上海市在2008年政府机构改革之前,上海市财政局、上海市地方税务局和上海市国家税务局为合署办公,一个机构、三块牌子,而2008年政府机构改革之后,上海市财政局被独立设置,上海市地方税务局和上海市国家税务局仍为合署办公,一个机构、两块牌子。同时县一级,财政局长常常兼任地税局长。地方税务局主要征收:营业税、企业所得税、个人所得税、土地增值税、城镇土地使用税、城市维护建设税、房产税、城市房地产税、车船使用税、车辆使用牌照税、屠宰税、资源税、固定资产投资方向调节税、印花税、农业税、农业特产税、契税、耕地占用税、筵席税,城市集体服务事业费、文化事业建设费、教育费附加以及地方税的滞补罚收入和外商投资企业土地使用费。"
question = "地方税务局是中华人民共和国哪一年分税制改革的结果"

inputs = tokenizer(context, question, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

start_scores, end_scores = model(**inputs)

answer_start = torch.argmax(start_scores)
answer_end = torch.argmax(end_scores)
answer = tokens[answer_start:answer_end]

print("".join(answer))
print("标准答案:19")
不是一个机构两块牌
标准答案:19
context = "萤火虫工作室是一家总部设在英国伦敦和康涅狄格州坎顿,并在苏格兰阿伯丁设有质量部门的电子游戏开发商。1999年8月,西蒙·布雷德伯里,埃里克·乌列特和大卫·莱斯特成立萤火虫工作室,一起开发了很多游戏,包括非常成功的“凯撒” 和“王国霸主”系列。公司成立后,萤火虫工作室发布了一个未来前景规划:\"“萤火虫工作室要创造一个人们游戏其中的引人瞩目的新世界。我们要提供一个丰富多彩的游戏环境,令玩家在我们的图像和编码技术不断提升的游戏世界中感到愉快。我们的专长是在游戏中开发战略,而我们今后要继续发展,与我们精彩的视觉效果,引人瞩目的人物和易于上手的特点相结合。如果我们能这样完成工作,玩家将会发现一个自己创造的,加进自己个性的世界”\"。该公司将市场定位于PC(Windows)和苹果电脑上的即时战略游戏领域,特别是公司成功的“要塞”系列。目前,他们正在开发PC和Xbox360上的次时代游戏。"
question = "萤火虫工作室的总部设在哪里"

inputs = tokenizer(context, question, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

start_scores, end_scores = model(**inputs)

answer_start = torch.argmax(start_scores)
answer_end = torch.argmax(end_scores)
answer = tokens[answer_start:answer_end]

print("".join(answer))
print("标准答案:英国伦敦和康涅狄格州坎顿。")
英国伦敦和康涅狄格州坎
标准答案:英国伦敦和康涅狄格州坎顿。
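The three test cells above repeat the same steps, so they can be wrapped in a small helper. A minimal sketch, assuming the tuple-style model outputs used throughout this post (newer transformers versions return an object with start_logits and end_logits instead); unlike the cells above, it adds 1 to the end index so the final answer token is kept:

def answer_question(context, question, model, tokenizer):
    # Encode the (context, question) pair the same way as the cells above.
    inputs = tokenizer(context, question, return_tensors="pt").to(device)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    with torch.no_grad():
        start_scores, end_scores = model(**inputs)
    answer_start = torch.argmax(start_scores)
    answer_end = torch.argmax(end_scores)
    # Include the end token in the span (+ 1) so the answer is not cut short.
    return "".join(tokens[answer_start:answer_end + 1])

print(answer_question(context, question, model, tokenizer))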
As we can see, even though the model was only trained for 2 epochs, its answers are already very close to the reference answers.