mindformers.trainer.TranslationTrainer
class mindformers.trainer.TranslationTrainer(model_name: str = None) [source]
Trainer for the translation task.
- Parameters
model_name (str) – The model name of the task trainer. Default: None.
Examples
>>> from mindformers import T5ForConditionalGeneration, TranslationTrainer
>>> # Follow the instructions in the T5 section of README.md to download the WMT16 dataset,
>>> # then change dataset_files in configs/t5/wmt16_dataset.yaml to point to it.
>>> trans_trainer = TranslationTrainer(model_name='t5_small')
>>> trans_trainer.train()
>>> model = T5ForConditionalGeneration.from_pretrained('t5_small')
>>> res = trans_trainer.predict(input_data="translate the English to Romanian: a good boy!", network=model)
>>> print(res)
[{'translation_text': ['hello world']}]
- Raises
NotImplementedError – If the train, evaluate, or predict method is not implemented.
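The Raises entry above follows a common Python pattern: a base trainer defines every task entry point, and any entry point that a subclass does not override raises NotImplementedError. The minimal sketch below shows that pattern in plain Python. `BaseTaskTrainer` and `ToyTranslationTrainer` are hypothetical names, not MindFormers classes, and the sketch does not claim to reproduce how MindFormers implements its trainers.

```python
class BaseTaskTrainer:
    """Hypothetical base class: each task entry point must be overridden."""

    def __init__(self, model_name=None):
        self.model_name = model_name

    def train(self, *args, **kwargs):
        raise NotImplementedError(f"{type(self).__name__} does not implement train()")

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError(f"{type(self).__name__} does not implement evaluate()")

    def predict(self, *args, **kwargs):
        raise NotImplementedError(f"{type(self).__name__} does not implement predict()")


class ToyTranslationTrainer(BaseTaskTrainer):
    """Hypothetical subclass that overrides train() and predict() only."""

    def train(self, *args, **kwargs):
        return "trained"

    def predict(self, input_data=None, **kwargs):
        # Placeholder "translation": echo the input back in a result list.
        return [{"translation_text": [input_data]}]


trainer = ToyTranslationTrainer(model_name="t5_small")
print(trainer.predict(input_data="a good boy!"))
try:
    trainer.evaluate()  # not overridden, so the base-class default raises
except NotImplementedError as err:
    print("NotImplementedError:", err)
```

Here `evaluate()` raises only because the hypothetical subclass never overrides it; the implemented methods behave normally.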
predict(config: Union[dict, mindformers.tools.register.config.MindFormerConfig, mindformers.trainer.config_args.ConfigArguments, mindformers.trainer.training_args.TrainingArguments, None] = None, input_data: Union[str, list, mindspore.dataset.engine.datasets_user_defined.GeneratorDataset, None] = None, network: Union[mindspore.nn.cell.Cell, mindformers.models.base_model.BaseModel, None] = None, tokenizer: Optional[mindformers.models.base_tokenizer.BaseTokenizer] = None, **kwargs) [source]
Runs prediction with the trainer.
- Parameters
config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]) – The task config used to configure the dataset, hyperparameters, optimizer, and so on. It accepts a config dict or a MindFormerConfig, TrainingArguments, or ConfigArguments instance. Default: None.
input_data (Optional[Union[str, list, GeneratorDataset]]) – The data to predict on. It supports 1) a text string to be translated, 2) a file name where each line is a text to be translated, and 3) a generator dataset. Default: None.
network (Optional[Union[Cell, BaseModel]]) – The network for the trainer. It accepts a model name, a BaseModel, or a MindSpore Cell. Default: None.
tokenizer (Optional[BaseTokenizer]) – The tokenizer for tokenizing the input text. Default: None.
- Returns
A list of prediction results.
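The three input_data forms listed above have to be turned into a sequence of texts before tokenization. The following is a minimal sketch of that dispatch under stated assumptions: a string is treated as a file name when such a file exists and as a single text otherwise, and a list is assumed to hold texts. `normalize_input_data` is a hypothetical helper, not part of MindFormers, and a plain Python iterable stands in for a MindSpore GeneratorDataset.

```python
import os
import tempfile
from typing import Iterable, List, Union


def normalize_input_data(input_data: Union[str, List[str], Iterable[str]]) -> List[str]:
    """Hypothetical helper: collect the texts to translate into a list."""
    if isinstance(input_data, str):
        if os.path.isfile(input_data):
            # Assumed file format: one text to translate per line; blank lines skipped.
            with open(input_data, encoding="utf-8") as f:
                return [line.rstrip("\n") for line in f if line.strip()]
        # Otherwise the string itself is the text to translate.
        return [input_data]
    if isinstance(input_data, list):
        return list(input_data)
    # Stand-in for a dataset: any iterable that yields texts.
    return [str(item) for item in input_data]


# Usage: a single text, a file of texts, and a generator of texts.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write("a good boy!\nhello world\n")
print(normalize_input_data("translate the English to Romanian: a good boy!"))
print(normalize_input_data(tmp.name))
os.remove(tmp.name)
print(normalize_input_data(t for t in ["one", "two"]))
```

Checking for an existing file first keeps the single-string case unambiguous for ordinary sentences, at the cost of misreading a text that happens to match a local file name.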
train(config: Union[dict, mindformers.tools.register.config.MindFormerConfig, mindformers.trainer.config_args.ConfigArguments, mindformers.trainer.training_args.TrainingArguments, None] = None, network: Union[mindspore.nn.cell.Cell, mindformers.models.base_model.BaseModel, None] = None, dataset: Union[mindformers.dataset.base_dataset.BaseDataset, mindspore.dataset.engine.datasets_user_defined.GeneratorDataset, None] = None, wrapper: Optional[mindspore.nn.wrap.cell_wrapper.TrainOneStepCell] = None, optimizer: Optional[mindspore.nn.optim.optimizer.Optimizer] = None, callbacks: Union[mindspore.train.callback._callback.Callback, List[mindspore.train.callback._callback.Callback], None] = None, **kwargs) [source]
Runs the training task of the translation trainer. This method is used to train or fine-tune the network.
The trainer interface lets users quickly start training for general tasks, and also lets them customize the network, optimizer, dataset, wrapper, and callbacks.
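The network, optimizer, and wrapper arguments divide one training step into three roles: the network computes the loss, the wrapper runs the forward and backward passes and hands the gradients to the optimizer, and the optimizer updates the parameters. The sketch below illustrates that division of labor in plain Python on a one-parameter linear model with a hand-derived gradient. It is not MindSpore's TrainOneStepCell, and `LinearNet`, `SGD`, and `OneStepWrapper` are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class LinearNet:
    """Hypothetical network: y_hat = w * x, trained with mean squared error."""
    w: float = 0.0

    def loss(self, batch: List[Tuple[float, float]]) -> float:
        return sum((self.w * x - y) ** 2 for x, y in batch) / len(batch)

    def grad(self, batch: List[Tuple[float, float]]) -> float:
        # d/dw of the mean squared error: mean of 2 * (w*x - y) * x.
        return sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)


@dataclass
class SGD:
    """Hypothetical optimizer: plain gradient descent."""
    learning_rate: float = 0.1

    def apply(self, net: LinearNet, grad: float) -> None:
        net.w -= self.learning_rate * grad


class OneStepWrapper:
    """Hypothetical wrapper: one call = forward, backward, optimizer update."""

    def __init__(self, net: LinearNet, optimizer: SGD):
        self.net = net
        self.optimizer = optimizer

    def __call__(self, batch: List[Tuple[float, float]]) -> float:
        loss = self.net.loss(batch)           # forward pass
        grad = self.net.grad(batch)           # backward pass
        self.optimizer.apply(self.net, grad)  # parameter update
        return loss


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # samples of y = 2x
step = OneStepWrapper(LinearNet(), SGD(learning_rate=0.05))
for _ in range(50):
    last_loss = step(data)
print(round(step.net.w, 3), last_loss)
```

Keeping the update logic in the wrapper means a different optimizer or network can be swapped in without touching the training loop, which is the kind of customization the paragraph above describes.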
- Parameters
config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]) – The task config used to configure the dataset, hyperparameters, optimizer, and so on. It accepts a config dict or a MindFormerConfig, TrainingArguments, or ConfigArguments instance. Default: None.
network (Optional[Union[Cell, BaseModel]]) – The network for the trainer. It accepts a model name, a BaseModel, or a MindSpore Cell. Default: None.
dataset (Optional[Union[BaseDataset, GeneratorDataset]]) – The training dataset. It accepts a dataset path, a BaseDataset instance, or a MindSpore dataset. Default: None.
optimizer (Optional[Optimizer]) – The optimizer of the training network. It accepts a MindSpore Optimizer. Default: None.
wrapper (Optional[TrainOneStepCell]) – Wraps the network with the optimizer. It accepts a MindSpore TrainOneStepCell. Default: None.
callbacks (Optional[Union[Callback, List[Callback]]]) – The training callbacks. It accepts a single MindSpore Callback or a list of Callbacks. Default: None.
- Raises
NotImplementedError – If the wrapper is not implemented.
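The callbacks parameter accepts either a single Callback or a list of them. A common way to support both is to normalize the argument to a list once and then invoke every callback at each hook. The sketch below shows that idea in plain Python. `Callback`, `LossRecorder`, `as_callback_list`, `run_steps`, and the `step_end(step, loss)` hook signature are all hypothetical and do not reproduce MindSpore's Callback interface, whose hooks take a run-context argument instead.

```python
from typing import List, Sequence, Tuple, Union


class Callback:
    """Hypothetical minimal callback base with a single hook."""

    def step_end(self, step: int, loss: float) -> None:
        pass


class LossRecorder(Callback):
    """Hypothetical callback that records the loss of every step."""

    def __init__(self) -> None:
        self.history: List[Tuple[int, float]] = []

    def step_end(self, step: int, loss: float) -> None:
        self.history.append((step, loss))


def as_callback_list(callbacks: Union[Callback, Sequence[Callback], None]) -> List[Callback]:
    """Normalize None, a single callback, or a sequence of callbacks to a list."""
    if callbacks is None:
        return []
    if isinstance(callbacks, Callback):
        return [callbacks]
    return list(callbacks)


def run_steps(losses: Sequence[float], callbacks=None) -> None:
    """Hypothetical loop: report each precomputed loss to every callback."""
    cbs = as_callback_list(callbacks)
    for step, loss in enumerate(losses, start=1):
        for cb in cbs:
            cb.step_end(step, loss)


recorder = LossRecorder()
run_steps([0.9, 0.5, 0.2], callbacks=recorder)          # single callback
run_steps([0.1], callbacks=[LossRecorder(), recorder])  # list of callbacks
print(recorder.history)
```

Normalizing once at the start keeps the per-step loop free of type checks, whichever form the caller passes.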