概述
2018年google推出了bert模型,这个模型的性能要远超于以前所使用的模型,总的来说就是很牛。但是训练bert模型是异常昂贵的,对于一般人来说并不需要自己单独训练bert,只需要加载预训练模型,就可以完成相应的任务。下面我将以情感分类为例,介绍使用bert的方法。这里与我们之前调用API写代码有所区别,已经有大神将bert封装成.py文件,我们只需要简单修改一下,就可以直接调用这些.py文件了。
官方文档
- tensorflow版:点击传送门
- pytorch版(注意这是一个第三方团队实现的):点击传送门
- 论文:点击传送门
一切以官方论文为准,如果有什么疑问,请仔细阅读官方文档
具体实现
我这里使用的是pytorch版本。
前置需要
- 安装pytorch和tensorflow。
- 安装PyTorch pretrained bert。(pip install pytorch-pretrained-bert)
- 将pytorch-pretrained-BERT提供的文件,整个下载。
- 选择并且下载预训练模型。地址:请点击
注意这里的model是tensorflow版本的,需要进行相应的转换才能在pytorch中使用
模型转换
文档里提供了convert_tf_checkpoint_to_pytorch.py 这个脚本来进行模型转换。使用方法如下:
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch
$BERT_BASE_DIR/bert_model.ckpt
$BERT_BASE_DIR/bert_config.json
$BERT_BASE_DIR/pytorch_model.bin
修改源码
这里是需要实现情感分类。只需要用到run_classifier_dataset_utils.py和run_classifier.py这两个文件。run_classifier_dataset_utils.py是用来处理文本的输入,我们只需要添加一个类用来处理输入即可。
class MyProcessor(DataProcessor):
'''Processor for the sentiment classification data set'''
def get_train_examples(self, data_dir):
"""See base class."""
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
return ["-1", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
label = line[1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
train.tsv和dev.tsv分别表示训练集和测试集。记得要在下面的代码加上之前定义的类。
def compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels)
if task_name == "cola":
return {"mcc": matthews_corrcoef(labels, preds)}
elif task_name == "sst-2":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mrpc":
return acc_and_f1(preds, labels)
elif task_name == "sts-b":
return pearson_and_spearman(preds, labels)
elif task_name == "qqp":
return acc_and_f1(preds, labels)
elif task_name == "mnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "mnli-mm":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "qnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "rte":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "my":
return acc_and_f1(preds, labels)
else:
raise KeyError(task_name)
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mnli-mm": MnliMismatchedProcessor,
"mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
"sts-b": StsbProcessor,
"qqp": QqpProcessor,
"qnli": QnliProcessor,
"rte": RteProcessor,
"wnli": WnliProcessor,
"my": MyProcessor
}
output_modes = {
"cola": "classification",
"mnli": "classification",
"mrpc": "classification",
"sst-2": "classification",
"sts-b": "regression",
"qqp": "classification",
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
"my": "classification"
}
运行bert
编辑shell脚本:
#!/bin/bash
export TASK_NAME=my
python run_classifier.py
--task_name $TASK_NAME
--do_train
--do_eval
--do_lower_case
--data_dir /home/garvey/Yuqinfenxi/
--bert_model /home/garvey/uncased_L-12_H-768_A-12
--max_seq_length 410
--train_batch_size 8
--learning_rate 2e-5
--num_train_epochs 3.0
--output_dir /home/garvey/bertmodel
运行即可。这里要注意max_seq_length和train_batch_size这两个参数,设置过大是很容易爆掉显存的,一般来说运行bert需要11G左右的显存。
备注
max_seq_length是指词的数量而不是指字符的数量。参考代码中的注释:
The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded.
对于sequence的理解,网上很多博客都把这个翻译为句子,我个人认为是不准确的,序列是可以包含多个句子的,而不只是单独一个句子。
注意
Bert开源的代码中,只提供了train和dev数据,也就是训练集和验证集。对于评测论文标准数据集的时候,只需要把训练集和测试集送进去就可以得到结果,这一过程是没有调参的(没有验证集),都是使用默认参数。但是如果用Bert来打比赛,注意这个时候的测试集是没有标签的,这就需要在源码中加上一个处理test数据集的部分,并且通过验证集来选择参数。
转载于:https://www.cnblogs.com/mlgjb/p/11158009.html
最后
以上就是平淡蜜粉为你收集整理的使用bert进行情感分类官方文档具体实现备注注意的全部内容,希望文章能够帮你解决使用bert进行情感分类官方文档具体实现备注注意所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复