mindformers.models.bert.BertForPreTraining
class mindformers.models.bert.BertForPreTraining(config={'assessment_method': '', 'attention_probs_dropout_prob': 0.1, 'batch_size': 16, 'checkpoint_name_or_path': '', 'compute_dtype': mindspore.float16, 'dropout_prob': 0.1, 'dtype': mindspore.float32, 'hidden_act': 'gelu', 'hidden_dropout_prob': 0.1, 'hidden_size': 768, 'initializer_range': 0.02, 'intermediate_size': 3072, 'is_training': True, 'layernorm_dtype': mindspore.float32, 'max_position_embeddings': 128, 'model_type': 'bert', 'moe_config': <mindformers.modules.transformer.moe.MoEConfig object>, 'num_attention_heads': 12, 'num_hidden_layers': 12, 'num_labels': 1, 'parallel_config': <mindformers.modules.transformer.transformer.TransformerOpParallelConfig object>, 'post_layernorm_residual': True, 'seq_length': 128, 'softmax_dtype': mindspore.float32, 'type_vocab_size': 2, 'use_one_hot_embeddings': False, 'use_past': False, 'use_relative_positions': False, 'vocab_size': 30522})
Provides the BERT pre-training loss. The network computes it from the masked language model (MLM) inputs and the next-sentence prediction (NSP) labels.
- Parameters
config (BertConfig) – The configuration of BertForPreTraining.
is_training (bool) – Whether to build the network in training mode. Set through config. Default: True.
use_one_hot_embeddings (bool) – Whether to use one-hot encoding for the embedding lookup. Set through config. Default: False.
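This page does not say what use_one_hot_embeddings changes. The sketch below assumes the common BERT convention. Under that convention, the flag chooses between two equivalent lookups: a one-hot matrix multiplied by the embedding table, or a direct index (gather) into the table. The pure-Python sketch uses a hypothetical 4 x 3 table and hypothetical helper names. It is not MindFormers code. It only checks that both lookups return the same rows.

```python
# Minimal sketch, ASSUMING use_one_hot_embeddings selects one-hot matmul vs. gather.
# Helper names and the toy table are hypothetical, not part of MindFormers.

def one_hot(ids, depth):
    """Turn a list of token ids into one-hot rows of length `depth`."""
    return [[1.0 if j == i else 0.0 for j in range(depth)] for i in ids]

def matmul(a, b):
    """Plain matrix product of two nested lists (rows of a times columns of b)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def lookup_gather(table, ids):
    """Lookup when use_one_hot_embeddings=False (assumed): index the table directly."""
    return [list(table[i]) for i in ids]

def lookup_one_hot(table, ids):
    """Lookup when use_one_hot_embeddings=True (assumed): one_hot(ids) @ table."""
    return matmul(one_hot(ids, len(table)), table)

# Toy embedding table: vocabulary of 4 tokens, hidden size 3.
table = [[0.1, 0.2, 0.3],
         [0.4, 0.5, 0.6],
         [0.7, 0.8, 0.9],
         [1.0, 1.1, 1.2]]
ids = [2, 0, 3]
print(lookup_gather(table, ids) == lookup_one_hot(table, ids))  # True
```

Both paths produce the same embeddings. The one-hot form does work proportional to the vocabulary size for every token, so the gather is usually cheaper. The one-hot form can still pay off on accelerators that favor dense matrix products.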
- Returns
Tensor, the pre-training loss of the network.
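The page states only that the network returns a loss tensor. In the standard BERT pre-training recipe, this loss is the sum of two cross-entropies:

- A masked language model term, computed at the positions in masked_lm_positions against the target ids in masked_lm_ids.
- A next-sentence prediction term, computed over two classes against next_sentence_labels.

The pure-Python sketch below assumes that recipe. It also assumes the model heads have already produced the logits. Two further points are assumptions about how padded prediction slots are ignored: averaging the MLM term with masked_lm_weights, and the zero-weight guard. The function names are hypothetical, and this is not the MindFormers implementation.

```python
import math

# Minimal sketch of the standard BERT pre-training objective (MLM + NSP).
# Function names are hypothetical; logits are assumed to be precomputed.

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def masked_lm_loss(mlm_logits, masked_lm_ids, masked_lm_weights):
    """Weighted mean cross-entropy over the masked positions.

    mlm_logits[k] holds vocabulary logits for the k-th masked position;
    slots with weight 0 (padding) contribute nothing.
    """
    numerator = 0.0
    for logits, label, weight in zip(mlm_logits, masked_lm_ids, masked_lm_weights):
        numerator += weight * -log_softmax(logits)[label]
    denominator = sum(masked_lm_weights)
    return numerator / denominator if denominator > 0 else 0.0

def next_sentence_loss(nsp_logits, next_sentence_labels):
    """Mean cross-entropy of the 2-way next-sentence classifier."""
    losses = [-log_softmax(l)[y] for l, y in zip(nsp_logits, next_sentence_labels)]
    return sum(losses) / len(losses)

def pretraining_loss(mlm_logits, masked_lm_ids, masked_lm_weights,
                     nsp_logits, next_sentence_labels):
    """Total objective: masked-LM loss plus next-sentence loss."""
    return (masked_lm_loss(mlm_logits, masked_lm_ids, masked_lm_weights)
            + next_sentence_loss(nsp_logits, next_sentence_labels))

# Toy case: vocabulary of 4 tokens, one real masked slot and one padded slot.
mlm_logits = [[0.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]]
loss = pretraining_loss(mlm_logits, [2, 0], [1.0, 0.0], [[0.0, 0.0]], [1])
print(round(loss, 4))  # 2.0794, i.e. log(4) + log(2)
```

With uniform logits, the MLM term is log 4 and the NSP term is log 2. The padded slot has weight 0, so it does not change the result.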
Examples
>>> from mindspore import Tensor
>>> import mindspore.common.dtype as mstype
>>> from mindformers import BertForPreTraining, BertTokenizer
>>> model = BertForPreTraining.from_pretrained('bert_base_uncased')
>>> tokenizer = BertTokenizer.from_pretrained('bert_base_uncased')
>>> data = tokenizer(["Paris is the [MASK] of France."],
...                  return_tensors='ms', max_length=128, padding="max_length")
>>> input_ids = data['input_ids']
>>> attention_mask = data['attention_mask']
>>> token_type_ids = data['token_type_ids']
>>> # index of [MASK] in the tokenized sequence: [CLS] paris is the [MASK] ...
>>> masked_lm_positions = Tensor([[4]], mstype.int32)
>>> next_sentence_labels = Tensor([[1]], mstype.int32)
>>> masked_lm_weights = Tensor([[1]], mstype.int32)
>>> # vocabulary id of the original token at the masked position
>>> masked_lm_ids = Tensor([[3007]], mstype.int32)
>>> output = model(input_ids, attention_mask, token_type_ids, next_sentence_labels,
...                masked_lm_positions, masked_lm_ids, masked_lm_weights)
>>> print(output)
[0.6706]
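In the example above, masked_lm_positions, masked_lm_ids and masked_lm_weights are written by hand for a single [MASK]. When preparing pre-training data, they are usually derived from an unmasked token sequence.

The pure-Python sketch below shows one way to derive them. It rests on several assumptions:

- [MASK] has id 103 in the bert_base_uncased vocabulary.
- Unused prediction slots are padded with position 0, id 0 and weight 0.
- The helper name build_mlm_inputs and its max_predictions parameter are hypothetical.

The sketch always substitutes [MASK]. It omits the random-replacement rule of the original BERT recipe, where some selected tokens are swapped for a random token or left unchanged.

```python
# Hypothetical helper; MASK_ID and the zero-padding scheme are assumptions.
MASK_ID = 103  # assumed id of [MASK] in the bert_base_uncased vocabulary

def build_mlm_inputs(token_ids, positions_to_mask, max_predictions):
    """Mask the chosen positions and return the MLM-related model inputs.

    Returns (input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights);
    the last three are zero-padded up to `max_predictions` entries.
    """
    if len(positions_to_mask) > max_predictions:
        raise ValueError("more positions to mask than max_predictions")
    input_ids = list(token_ids)
    masked_lm_ids = []
    for pos in positions_to_mask:
        masked_lm_ids.append(token_ids[pos])  # remember the original token id
        input_ids[pos] = MASK_ID              # replace it with [MASK]
    pad = max_predictions - len(positions_to_mask)
    masked_lm_positions = list(positions_to_mask) + [0] * pad
    masked_lm_ids += [0] * pad
    masked_lm_weights = [1.0] * len(positions_to_mask) + [0.0] * pad
    return input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights

# Toy ids: 101/102 stand in for [CLS]/[SEP]; 11..15 are placeholder word ids.
tokens = [101, 11, 12, 13, 14, 15, 102]
print(build_mlm_inputs(tokens, [4], max_predictions=2))
# ([101, 11, 12, 13, 103, 15, 102], [4, 0], [14, 0], [1.0, 0.0])
```

The zero weights let a loss that averages over masked_lm_weights ignore the padded slots. Batches can then share a fixed number of prediction slots.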