mindformers.trainer.TranslationTrainer
- class mindformers.trainer.TranslationTrainer(model_name: Optional[str] = None)[source]
Trainer for the translation task.
- Args:
model_name (Optional[str]): The name of the model used by the task trainer. Default: None.
- Examples:
>>> from mindformers.trainer import TranslationTrainer
>>> from mindformers import T5ForConditionalGeneration
>>> # Follow the instructions in the T5 section of README.md and download the WMT16 dataset.
>>> # Change the dataset_files path in configs/t5/wmt16_dataset.yaml.
>>> trans_trainer = TranslationTrainer(model_name='t5_small')
>>> trans_trainer.train()
>>> model = T5ForConditionalGeneration.from_pretrained('t5_small')
>>> trans_trainer = TranslationTrainer(model_name="t5_small")
>>> res = trans_trainer.predict(input_data="translate the English to Romanian: a good boy!", network=model)
[{'translation_text': ['hello world']}]
- Raises:
NotImplementedError: If the train, evaluate, or predict method is not implemented.
- predict(config: Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]] = None, input_data: Optional[Union[str, list, GeneratorDataset]] = None, network: Optional[Union[Cell, BaseModel]] = None, tokenizer: Optional[BaseTokenizer] = None, **kwargs)[source]
Runs prediction with the trainer.
- Args:
- config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]):
The task config which is used to configure the dataset, the hyper-parameter, optimizer, etc. It supports config dict or MindFormerConfig or TrainingArguments or ConfigArguments class. Default: None.
- input_data (Optional[Union[str, list, GeneratorDataset]]): The data to predict on. It supports 1) a text string to be
translated, 2) a file name where each line is a text to be translated, and 3) a generator dataset. Default: None.
- network (Optional[Union[Cell, BaseModel]]): The network for trainer.
It supports model name or BaseModel or MindSpore Cell class. Default: None.
- tokenizer (Optional[BaseTokenizer]): The tokenizer for tokenizing the input text.
Default: None.
- Returns:
A list of predictions.
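The string and file-name forms of input_data, and the result shape shown in the class example, can be illustrated with a small sketch in plain Python. It is not MindFormers code: `load_texts` and `flatten_translations` are hypothetical helper names, the file is assumed to be UTF-8 with one text per line, and the result is assumed to have the shape printed in the Examples section. The generator dataset form is not covered.

```python
import os
import tempfile
from typing import Dict, List, Union


def load_texts(input_data: Union[str, List[str]]) -> List[str]:
    """Hypothetical helper: normalize a string, a file name, or a list to a list of texts."""
    if isinstance(input_data, list):
        return list(input_data)
    if os.path.isfile(input_data):
        # File form: one text to translate per line; blank lines are skipped.
        with open(input_data, encoding="utf-8") as handle:
            return [line.strip() for line in handle if line.strip()]
    # Otherwise treat the string itself as the text to translate.
    return [input_data]


def flatten_translations(results: List[Dict[str, List[str]]]) -> List[str]:
    """Collect every string under 'translation_text' (shape assumed from the class example)."""
    return [text for item in results for text in item["translation_text"]]


prompt = "translate the English to Romanian: a good boy!"

# Write a two-line input file to show the file-name form.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write(prompt + "\n" + prompt + "\n")
texts_from_file = load_texts(tmp.name)
os.remove(tmp.name)

print(load_texts(prompt))
print(texts_from_file)
print(flatten_translations([{"translation_text": ["hello world"]}]))
```

Helpers like these can be handy for checking an input file before passing its name to predict, and for turning the returned list of dicts into plain strings afterwards.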
- train(config: Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]] = None, network: Optional[Union[Cell, BaseModel]] = None, dataset: Optional[Union[BaseDataset, GeneratorDataset]] = None, wrapper: Optional[TrainOneStepCell] = None, optimizer: Optional[Optimizer] = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, **kwargs)[source]
Train task for Translation Trainer. This function is used to train or fine-tune the network.
The trainer interface is used to quickly start training for general tasks. It also allows users to customize the network, optimizer, dataset, wrapper, and callbacks.
- Args:
- config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]):
The task config which is used to configure the dataset, the hyper-parameter, optimizer, etc. It supports config dict or MindFormerConfig or TrainingArguments or ConfigArguments class. Default: None.
- network (Optional[Union[Cell, BaseModel]]): The network for trainer.
It supports model name or BaseModel or MindSpore Cell class. Default: None.
- dataset (Optional[Union[BaseDataset, GeneratorDataset]]): The training dataset.
It supports a real dataset path, a BaseDataset class, or a MindSpore Dataset class. Default: None.
- optimizer (Optional[Optimizer]): The training network’s optimizer. It supports the Optimizer class of MindSpore.
Default: None.
- wrapper (Optional[TrainOneStepCell]): Wraps the network with the optimizer.
It supports the TrainOneStepCell class of MindSpore. Default: None.
- callbacks (Optional[Union[Callback, List[Callback]]]): The training callbacks.
It supports a MindSpore Callback or a list of Callbacks. Default: None.
- Raises:
NotImplementedError: If the wrapper is not implemented.
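The customization described for train can be sketched as follows: a pre-built network and a MindSpore optimizer are passed in place of the defaults from the config. This is a sketch under stated assumptions, not a verified recipe: `train_with_custom_optimizer` is a hypothetical wrapper, the learning rate is an arbitrary placeholder, and the call presumes that MindSpore and MindFormers are installed and that the WMT16 dataset is configured as in the class example. The imports are deferred into the function body so that defining the function does not itself require either library.

```python
def train_with_custom_optimizer(model_name: str = "t5_small", learning_rate: float = 1e-4):
    """Hypothetical wrapper: fine-tune with a user-supplied network and optimizer."""
    # Deferred imports: only needed when the function is actually called.
    from mindspore.nn import AdamWeightDecay
    from mindformers import T5ForConditionalGeneration
    from mindformers.trainer import TranslationTrainer

    # Load the pretrained network, as in the class example.
    network = T5ForConditionalGeneration.from_pretrained(model_name)

    # Build a MindSpore optimizer over the network's trainable parameters.
    # The learning rate here is a placeholder, not a recommended value.
    optimizer = AdamWeightDecay(network.trainable_params(), learning_rate=learning_rate)

    # Assumes the dataset is still taken from the task config, since
    # no dataset argument is passed.
    trainer = TranslationTrainer(model_name=model_name)
    trainer.train(network=network, optimizer=optimizer)
    return trainer
```

Calling `train_with_custom_optimizer()` would start training; anything not passed explicitly (dataset, wrapper, callbacks) is left at its default of None.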