断点续训

介绍

在长时间的训练当中,如果遇到意外情况导致训练中断,可以使用断点续训的方式恢复之前的状态继续训练。

恢复训练时,从之前训练的checkpoint文件或文件夹中,加载网络权重,并恢复训练时的epoch、step、loss_scale_value等信息。

局限:由于数据集暂不支持跳过已训练的数据,当前仅支持epoch级的断点续训,即仅支持恢复到某个epoch开始的状态。

使用

脚本启动场景

在run_xxx.yaml中配置load_checkpoint为checkpoint文件或文件夹路径,并将resume_training改为True。

load_checkpoint: checkpoint_file_or_dir_path
resume_training: True

注意:如果load_checkpoint配置为文件路径,则认为是完整权重,如果配置为文件夹,则认为是分布式权重。下同。

Trainer高阶接口启动场景

  • 方式1:在Trainer.train()或Trainer.finetune()中配置

在Trainer.train()或Trainer.finetune()中,配置train_checkpoint或finetune_checkpoint参数为checkpoint文件或文件夹路径,并将resume_training参数设置为True。

from mindformers import Trainer

cls_trainer = Trainer(task='image_classification', # 已支持的任务名
                      model='vit_base_p16', # 已支持的模型名
                      train_dataset="/data/imageNet-1k/train", # 传入标准的训练数据集路径,默认支持ImageNet数据集格式
                      eval_dataset="/data/imageNet-1k/val") # 传入标准的评估数据集路径,默认支持ImageNet数据集格式

cls_trainer.train(train_checkpoint="", resume_training=True) # 启动训练
cls_trainer.finetune(finetune_checkpoint="", resume_training=True) # 启动微调
  • 方式2:在TrainingArguments中配置

在TrainingArguments中配置resume_from_checkpoint为checkpoint文件或文件夹路径,并将resume_training参数设置为True。

from mindformers import Trainer, TrainingArguments

training_args = TrainingArguments(resume_from_checkpoint="", resume_training=True)

cls_trainer = Trainer(task='image_classification', # 已支持的任务名
                      args=training_args,
                      model='vit_base_p16', # 已支持的模型名
                      train_dataset="/data/imageNet-1k/train", # 传入标准的训练数据集路径,默认支持ImageNet数据集格式
                      eval_dataset="/data/imageNet-1k/val") # 传入标准的评估数据集路径,默认支持ImageNet数据集格式

cls_trainer.train() # 启动训练
cls_trainer.finetune() # 启动微调