Bert 下游任务微调¶
Bert 模型介绍¶
BERT:全名Bidirectional Encoder Representations from Transformers模型是谷歌在2018年基于Wiki数据集训练的Transformer模型。
论文: J Devlin,et al., Pre-training of Deep Bidirectional Transformers for Language Understanding, 2019
Bert 下游任务微调¶
下面以question_answering任务为例介绍Bert下游任务微调的流程。
数据集
SQuAD v1.1数据集:该数据集包含 10 万个(问题,原文,答案)三元组,原文来自于 536 篇维基百科文章,而问题和答案的构建主要是通过众包的方式,让标注人员提出最多 5 个基于文章内容的问题并提供正确答案,且答案出现在原文中。
下载地址:SQuAD v1.1训练集,SQuAD v1.1验证集
新建名为squad文件夹,将下载的json格式数据集文件放入文件夹中。
└─squad ├─train-v1.1.json └─dev-v1.1.json
初始化question_answering任务trainer
使用mindformers.trainer.Trainer类,初始化question_answering任务的trainer。
from mindformers.trainer import Trainer # 初始化question_answering任务trainer trainer = Trainer(task='question_answering', model='qa_bert_base_uncased', train_dataset='./squad/', eval_dataset='./squad/')
参数含义如下:
task(str) - 任务名称,’question_answering’为问答任务。
model(str) - 模型名称, ‘qa_bert_base_uncased’为Bert接question_answering下游任务模型。
train_dataset(str) - 训练数据集所在路径。
eval_dataset(str) - 评估数据集所在路径。
使用现有的预训练权重进行finetune微调
从obs上下载bert_base_uncased预训练权重,加载预训练权重,并在下游任务qa_bert_base_uncased模型上进行微调。
# 使用现有的预训练权重进行finetune微调 trainer.finetune(finetune_checkpoint="qa_bert_base_uncased")
参数含义如下:
finetune_checkpoint(str) - 权重名称,’qa_bert_base_uncased’为问答任务对应的Bert预训练权重。
训练过程中会实时打印训练时长、Loss等信息。
使用finetune获得的权重进行eval评估
从finetune保存的权重文件中,取最后一次保存的checkpoint文件的权重加载进网络中,并进行评估。
# 使用finetune获得的最新权重进行eval评估 trainer.evaluate(eval_checkpoint=True)
参数含义如下:
eval_checkpoint(bool) - 是否加载最后一次保存的权重进行评估,True表示加载最后一次保存的权重文件中的权重进网络中。
obs上训练好的权重评估结果如下:
INFO - QA Metric = {'QA Metric': {'exact_match': 80.74739829706716, 'f1': 88.33552874684968}}使用finetune获得的权重进行predict推理
从finetune保存的权重文件中,取最后一次保存的checkpoint文件的权重加载进网络中,并进行推理。推理输入的文本包括context和question两部分,两者以短横线“-”为标志分隔开。
# 使用finetune获得的最新权重进行predict推理 # 测试数据,测试数据分为context和question两部分,两者以 “-” 分隔 input_data = ["My name is Wolfgang and I live in Berlin - Where do I live?"] trainer.predict(predict_checkpoint=True, input_data=input_data)
参数含义如下:
predict_checkpoint(bool) - 是否加载最后一次保存的权重进行推理,True表示加载最后一次保存的权重文件中的权重进网络中。
input_data(str) - 输入文本,分为context和question两部分,两者以 “-” 分隔。
得到的输出为:
[{'text': 'Berlin', 'score': 0.9941, 'start': 34, 'end': 40}]