mindformers.trainer.MaskedLanguageModelingTrainer

class mindformers.trainer.MaskedLanguageModelingTrainer(model_name: str = None)[source]

Trainer for the masked language modeling task.

Parameters
  • model_name (str) – The model name of the task trainer. Default: None.

Examples

>>> import numpy as np
>>> from mindspore.dataset import GeneratorDataset
>>> from mindformers import MaskedLanguageModelingTrainer
>>> def generator():
...     # One synthetic sample: 128 random token ids plus the BERT-style auxiliary columns.
...     data = np.random.randint(low=0, high=15, size=(128,)).astype(np.int32)
...     input_mask = np.ones_like(data)
...     token_type_id = np.zeros_like(data)
...     next_sentence_labels = np.array([1]).astype(np.int32)
...     masked_lm_positions = np.array([1, 2]).astype(np.int32)
...     masked_lm_ids = np.array([1, 2]).astype(np.int32)
...     masked_lm_weights = np.ones_like(masked_lm_ids)
...     train_data = (data, input_mask, token_type_id, next_sentence_labels,
...                   masked_lm_positions, masked_lm_ids, masked_lm_weights)
...     # Yield the same sample 512 times.
...     for _ in range(512):
...         yield train_data
...
>>> dataset = GeneratorDataset(generator, column_names=["input_ids", "input_mask", "segment_ids",
...                                                     "next_sentence_labels", "masked_lm_positions",
...                                                     "masked_lm_ids", "masked_lm_weights"])
>>> dataset = dataset.batch(batch_size=16)
>>> mlm_trainer = MaskedLanguageModelingTrainer(model_name="bert_tiny_uncased")
>>> mlm_trainer.train(dataset=dataset)
>>> res = mlm_trainer.predict(input_data="hello world [MASK]")
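
The example above feeds seven columns per sample. As a rough check of the data layout, the following NumPy-only sketch rebuilds one batch without MindSpore. It assumes that batching stacks samples along a new leading axis; the names make_sample and stack_batch are hypothetical helpers for illustration and are not part of mindformers.

```python
import numpy as np

SEQ_LEN = 128      # sequence length used in the example above
NUM_SAMPLES = 512  # number of samples the generator yields
BATCH_SIZE = 16    # value passed to dataset.batch in the example

COLUMNS = ["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
           "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]


def make_sample(rng):
    """Hypothetical helper: build one sample with the same columns as the example."""
    input_ids = rng.integers(low=0, high=15, size=(SEQ_LEN,)).astype(np.int32)
    masked_lm_ids = np.array([1, 2]).astype(np.int32)
    return {
        "input_ids": input_ids,
        "input_mask": np.ones_like(input_ids),
        "segment_ids": np.zeros_like(input_ids),
        "next_sentence_labels": np.array([1]).astype(np.int32),
        "masked_lm_positions": np.array([1, 2]).astype(np.int32),
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_weights": np.ones_like(masked_lm_ids),
    }


def stack_batch(samples):
    """Hypothetical helper: stack samples column by column (assumed batching behaviour)."""
    return {name: np.stack([s[name] for s in samples]) for name in COLUMNS}


rng = np.random.default_rng(0)
batch = stack_batch([make_sample(rng) for _ in range(BATCH_SIZE)])
batch_shapes = {name: batch[name].shape for name in COLUMNS}
num_batches = NUM_SAMPLES // BATCH_SIZE  # 512 samples split evenly into batches of 16

for name in COLUMNS:
    print(name, batch_shapes[name])
print("batches per epoch:", num_batches)
```

Under these assumptions, the token-level columns have one row per sample and one column per token, while the label and masked-position columns keep their short per-sample length.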
Raises

NotImplementedError – If the train, evaluate, or predict method is not implemented.

predict(config: Union[dict, mindformers.tools.register.config.MindFormerConfig, mindformers.trainer.config_args.ConfigArguments, mindformers.trainer.training_args.TrainingArguments, None] = None, input_data: Union[str, list, None] = None, network: Union[str, mindformers.models.base_model.BaseModel, None] = None, tokenizer: Optional[mindformers.models.base_tokenizer.BaseTokenizer] = None, **kwargs)[source]

Runs prediction with the trainer.

Parameters
  • config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]) – The task config, which configures the dataset, hyperparameters, optimizer, etc. It accepts a config dict or a MindFormerConfig, TrainingArguments, or ConfigArguments instance. Default: None.

  • input_data (Optional[Union[Tensor, str, list]]) – The input data for prediction. Default: None.

  • network (Optional[Union[str, BaseModel]]) – The network for the trainer. It accepts a supported model name or a BaseModel instance. For supported model names, refer to the model support list. Default: None.

  • tokenizer (Optional[BaseTokenizer]) – The tokenizer for tokenizing the input text. Default: None.

Returns

A list of predictions.
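
Since input_data may be a single string or a list, it can help to normalize it before calling predict. The sketch below is a hypothetical pre-processing helper, not part of mindformers: prepare_mlm_inputs and the check for a "[MASK]" marker are assumptions based on the example input "hello world [MASK]", and the library itself may not require them.

```python
from typing import List, Union

MASK_TOKEN = "[MASK]"  # mask marker taken from the example input above


def prepare_mlm_inputs(input_data: Union[str, List[str]]) -> List[str]:
    """Hypothetical helper: return a list of strings that each contain a mask marker."""
    # Wrap a single string so that callers always get a list back.
    texts = [input_data] if isinstance(input_data, str) else list(input_data)
    for text in texts:
        if not isinstance(text, str):
            raise TypeError(f"expected str, got {type(text).__name__}")
        if MASK_TOKEN not in text:
            raise ValueError(f"no {MASK_TOKEN} marker in: {text!r}")
    return texts


single = prepare_mlm_inputs("hello world [MASK]")
several = prepare_mlm_inputs(["hello [MASK]", "the [MASK] sat"])
print(single)
print(several)
```

The resulting list could then be passed as input_data, for example mlm_trainer.predict(input_data=several), assuming a trainer constructed as in the class example.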

train(config: Union[dict, mindformers.tools.register.config.MindFormerConfig, mindformers.trainer.config_args.ConfigArguments, mindformers.trainer.training_args.TrainingArguments, None] = None, network: Union[mindspore.nn.cell.Cell, mindformers.models.base_model.BaseModel, None] = None, dataset: Union[mindformers.dataset.base_dataset.BaseDataset, mindspore.dataset.engine.datasets_user_defined.GeneratorDataset, None] = None, wrapper: Optional[mindspore.nn.wrap.cell_wrapper.TrainOneStepCell] = None, optimizer: Optional[mindspore.nn.optim.optimizer.Optimizer] = None, callbacks: Union[mindspore.train.callback._callback.Callback, List[mindspore.train.callback._callback.Callback], None] = None, **kwargs)[source]

Training task for the MaskedLanguageModeling trainer. This method trains or fine-tunes the network.

The trainer interface provides a quick way to start training for general tasks. It also lets users customize the network, optimizer, dataset, wrapper, and callbacks.

Parameters
  • config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]) – The task config, which configures the dataset, hyperparameters, optimizer, etc. It accepts a config dict or a MindFormerConfig, TrainingArguments, or ConfigArguments instance. Default: None.

  • network (Optional[Union[Cell, BaseModel]]) – The network for the trainer. It supports a model name, a BaseModel, or a MindSpore Cell. Default: None.

  • dataset (Optional[Union[BaseDataset, GeneratorDataset]]) – The training dataset. It supports a real dataset path, a BaseDataset, or a MindSpore Dataset. Default: None.

  • optimizer (Optional[Optimizer]) – The optimizer for the training network. It supports the MindSpore Optimizer class. Default: None.

  • wrapper (Optional[TrainOneStepCell]) – Wraps the network with the optimizer. It supports the MindSpore TrainOneStepCell class. Default: None.

  • callbacks (Optional[Union[Callback, List[Callback]]]) – The training callbacks. It supports a single MindSpore Callback or a list of Callbacks. Default: None.
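
The callbacks parameter accepts nothing, a single callback, or a list of callbacks. One way to handle that shape in calling code is to normalize it to a list first. The sketch below is a hypothetical illustration in plain Python: as_callback_list and the PrintStep stand-in class are assumptions made for this example and do not describe how mindformers handles the argument internally.

```python
from typing import Any, List, Optional, Union


def as_callback_list(callbacks: Optional[Union[Any, List[Any]]]) -> List[Any]:
    """Hypothetical helper: normalize None, one callback, or a list to a list."""
    if callbacks is None:
        return []
    if isinstance(callbacks, (list, tuple)):
        return list(callbacks)
    return [callbacks]


class PrintStep:
    """Stand-in for a MindSpore Callback subclass (assumption for this sketch)."""


first = PrintStep()
second = PrintStep()

print(len(as_callback_list(None)))             # no callbacks supplied
print(len(as_callback_list(first)))            # single callback
print(len(as_callback_list([first, second])))  # list of callbacks
```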

Raises

NotImplementedError – If the wrapper is not implemented.