# Source code for mindformers.models.gpt2.gpt2

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""GPT model"""
import copy
import numpy as np

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import functional as F

from mindformers.modules.transformer.moe import default_moe_config
from mindformers.modules.layers import LayerNorm, Dropout
from mindformers.core.loss import CrossEntropyLoss
from mindformers.modules.transformer import AttentionMask, VocabEmbedding
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.models.base_model import BaseModel
from mindformers.mindformer_book import MindFormerBook
from mindformers.tools.logger import logger
from mindformers.pet import LoraAdapter, PetAdapter
from .gpt2_config import GPT2Config
from .gpt_modules import GPTTransformerDecoderLayer

# Names exported by ``from <this module> import *``; the backbone and helper
# cells defined below remain importable by name but are not re-exported.
__all__ = ['GPT2LMHeadModel', 'GPT2WithLora']


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class GPT2LMHeadModel(BaseModel):
    r"""
    Provide gpt training loss or logits through network.

    Args:
        config (GPT2Config): The config of Gpt2Model. When None, a default
            ``GPT2Config()`` is used.

    Returns:
        Tensor, the loss or logits of the network.
    """
    _support_list = MindFormerBook.get_model_support_list()['gpt2']

    def __init__(self, config: GPT2Config = None):
        config = config if config is not None else GPT2Config()
        super(GPT2LMHeadModel, self).__init__(config, auto_prefix=True)

        self.eos_token_id = self.config.eos_token_id
        parallel_config = self.config.parallel_config
        self.stridedslice = P.StridedSlice().shard(((parallel_config.data_parallel, 1),))
        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))

        self.get_attention_mask = AttentionMask(seq_length=config.seq_length,
                                                parallel_config=parallel_config.dp_mp_config)

        self.backbone = GPT2Model(config)
        self.head = GPTHead(hidden_size=config.hidden_size,
                            vocab_size=config.vocab_size,
                            parallel_config=self.config.parallel_config)
        if parallel_config.pipeline_stage > 1:
            # The head runs on the last pipeline stage; it consumes the word
            # embedding table returned by the backbone, so that table is also
            # registered on the head's stage.
            self.head.pipeline_stage = parallel_config.pipeline_stage - 1
            self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)

        mp = config.parallel_config.model_parallel
        vocab_size = config.vocab_size
        # The loss gets its own copy so the mp fallback below does not affect
        # the parallel config shared with the rest of the model.
        loss_parallel_config = copy.deepcopy(parallel_config)
        if vocab_size % mp != 0:
            logger.warning("The vocab size of GPT Loss is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of GPT Loss will be changed: mp = 1")
            loss_parallel_config.model_parallel = 1

        self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config)
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.load_checkpoint(config)
        self.add = P.Add().shard(((parallel_config.data_parallel, 1), ()))

    def construct(self, input_ids, input_mask=None):
        r"""
        construct function for Language Modeling

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sentences padding mask, where 0 indicates padding position.
                When None, the mask is derived as ``tokens != eos_token_id``.

        Returns:
            In the "train" phase: the loss computed by ``self.loss`` over next-token labels.
            Otherwise: a tuple ``(logits, tokens, input_mask)``, where ``logits`` is reshaped to
            ``(batch_size, seq_length, -1)`` and ``input_mask`` has had 1 added to it
            (a MindSpore 1.10 workaround, see the comment in the body).
        """
        batch_size, seq_length = input_ids.shape

        if self.phase == "train":
            # Next-token prediction: the model sees positions [0, seq_length - 1);
            # the labels sliced further below are positions [1, seq_length).
            tokens = self.stridedslice(input_ids, (0, 0), (batch_size, seq_length - 1), (1, 1))
            if input_mask is not None:
                input_mask = self.stridedslice(input_mask, (0, 0), (batch_size, seq_length - 1), (1, 1))
            else:
                input_mask = self.not_equal(tokens, self.eos_token_id)
        else:
            tokens = input_ids
            if input_mask is None:
                input_mask = self.not_equal(tokens, self.eos_token_id)
        input_mask = self.cast(input_mask, mstype.float32)
        attention_mask = self.get_attention_mask(input_mask)

        output_states, embedding_table = self.backbone(tokens, attention_mask)
        # NOTE(review): the backbone flattens hidden states to 2-D before its decoder
        # layers, so logits are presumably [batch_size * seq_length, vocab_size] here — confirm.
        logits = self.head(output_states, embedding_table)

        if self.phase != 'train':
            logits = self.reshape(logits, (batch_size, seq_length, -1))
            # makes cast effective to avoid allgather issue in Mindspore1.10
            # NOTE(review): as a consequence the returned mask values are shifted by +1.
            input_mask = self.add(input_mask, 1)
            return logits, tokens, input_mask

        labels = self.stridedslice(input_ids, (0, 1), (batch_size, seq_length), (1, 1))
        labels = self.reshape(labels, (-1,))
        input_mask = self.reshape(input_mask, (-1,))
        loss = self.loss(logits, labels, input_mask)
        return loss
class GPTEmbeddingLayer(nn.Cell):
    r"""The Embedding Layer of GPT-2 network.

    Adds a learned word embedding and a learned absolute position embedding,
    then applies dropout.

    Args:
        config (GPT2Config): required despite the ``None`` default; ``parallel_config``,
            ``vocab_size``, ``hidden_size``, ``seq_length`` and ``embedding_dropout_prob``
            are read unconditionally.

    Inputs:
        input_ids: token indices looked up in the word embedding table.
        input_position: position indices looked up in the position embedding table.

    Returns:
        Tuple ``(embedding, word_table)``: the dropout-applied sum of both embeddings,
        and the table returned by the word ``VocabEmbedding`` (the head reuses it).
    """
    def __init__(self, config: GPT2Config = None):
        super(GPTEmbeddingLayer, self).__init__()
        # Work on a copy so the mp fallback below does not mutate the caller's config.
        parallel_config = copy.deepcopy(config.parallel_config)
        embedding_mp = config.parallel_config.embedding_dp_mp_config.model_parallel
        vocab_size = config.vocab_size
        if vocab_size % embedding_mp != 0:
            # The vocab dimension cannot be split evenly across model-parallel ranks,
            # so the word embedding falls back to model_parallel = 1.
            logger.warning("The vocab size of embedding layer is: %s, it is not divide by model_parallel: %s",
                           vocab_size, embedding_mp)
            logger.warning("Now, model_parallel will be changed: mp = 1")
            parallel_config.embedding_dp_mp_config.model_parallel = 1

        self.word_embedding = VocabEmbedding(vocab_size=vocab_size,
                                             embedding_size=config.hidden_size,
                                             param_init=initializer('normal',
                                                                    [vocab_size, config.hidden_size],
                                                                    dtype=mstype.float32),
                                             parallel_config=parallel_config.embedding_dp_mp_config)
        # The position table is built from a second copy with vocab_emb_dp forced to True.
        # NOTE(review): presumably this keeps the (small) position table replicated across
        # data-parallel ranks instead of model-parallel split — confirm in VocabEmbedding.
        new_parallel_config = copy.deepcopy(parallel_config)
        new_parallel_config.vocab_emb_dp = True

        self.position_embedding = VocabEmbedding(vocab_size=config.seq_length,
                                                 embedding_size=config.hidden_size,
                                                 param_init=initializer('normal',
                                                                        [config.seq_length, config.hidden_size],
                                                                        dtype=mstype.float32),
                                                 parallel_config=new_parallel_config.embedding_dp_mp_config)
        # 3-D shard strategies: embeddings are presumably [batch, seq, hidden] — confirm.
        self.add = P.Add().shard(
            ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
        # NOTE(review): Dropout is passed 1 - drop prob, so it presumably takes a keep probability.
        self.dropout = Dropout(1 - config.embedding_dropout_prob)
        self.dropout.shard(((parallel_config.data_parallel, 1, 1),))

    def construct(self, input_ids, input_position):
        """The forward compute of Embedding Layer."""
        word_embedding, word_table = self.word_embedding(input_ids)
        position_embedding, _ = self.position_embedding(input_position)
        embedding = self.add(word_embedding, position_embedding)
        embedding = self.dropout(embedding)
        return embedding, word_table


def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
    r"""
    Assign a pipeline stage, a gradient comm-fusion group and recompute settings to one layer.

    The pipeline stage is ``min((layer_id + offset) // pp_dis, pipeline_stage - 1)`` with
    ``pp_dis = max(int((layers + 1) / pipeline_stage), 1)``.

    Args:
        network(Cell) - Represents the transformer block; its ``pipeline_stage`` is set and
            ``set_comm_fusion`` / ``recompute`` are called on it.
        layer_id(int) - Means the layer index for the current module, counts from zero.
        offset(int) - Means the layer_index needs a offset, if there are other modules in the net.
        parallel_config - Parallel config object exposing ``pipeline_stage``,
            ``gradient_aggregation_group`` and ``recompute`` attributes.
        layers(int) - The total layers used for the model.
    """
    # Layers per pipeline stage (at least 1); any overflow is clamped onto the last stage.
    pp_dis = max(int((layers + 1) / parallel_config.pipeline_stage), 1)
    pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
    network.pipeline_stage = pp_id

    # Used for optimizer's fusion tag
    dis = max(int((layers + 1) / parallel_config.gradient_aggregation_group), 1)
    if parallel_config.pipeline_stage > 1:
        # Under pipeline parallelism every layer uses the same fixed fusion group (2).
        network.set_comm_fusion(2)
    else:
        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
    # `recompute` is either a plain bool or a config object carrying its own flags.
    if isinstance(parallel_config.recompute, bool):
        if parallel_config.recompute:
            network.recompute()
    else:
        if parallel_config.recompute.recompute:
            network.recompute(recompute_slice_activation=parallel_config.recompute.recompute_slice_activation)
class GPT2Model(nn.Cell):
    """
    The backbone of GPT network

    Args:
        config(GPT2Config): the config of network. Note: if ``config.parallel_config``
            has no ``moe_config`` attribute, ``default_moe_config`` is assigned to it
            (the caller's config object is mutated).

    Inputs:
        input_ids: the tokenized inputs with datatype int32
        attention_mask: the attention mask passed unchanged to every decoder layer

    Returns:
        output_state: Tensor, the layer-normalized output of the decoder stack (hidden
            states are flattened to ``(-1, hidden_size)`` before entering the stack)
        embedding_table: Tensor, the embedding table for the vocabulary
    """

    def __init__(self, config):
        super(GPT2Model, self).__init__()

        self.embedding = GPTEmbeddingLayer(config)
        self.embedding.pipeline_stage = 0

        # Final LayerNorm lives on the last pipeline stage.
        self.layernorm = LayerNorm((config.hidden_size,)).to_float(config.layernorm_compute_type)
        if config.parallel_config.pipeline_stage > 1:
            self.layernorm.set_comm_fusion(2)
        else:
            self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
        self.layernorm.shard(((config.parallel_config.data_parallel, 1),))
        self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1

        # Backfill a default MoE config on the (shared) parallel config when absent.
        if not hasattr(config.parallel_config, "moe_config"):
            config.parallel_config.moe_config = default_moe_config
        moe_config = config.parallel_config.moe_config

        self.blocks = nn.CellList()
        for i in range(config.num_layers):
            block = GPTTransformerDecoderLayer(
                hidden_size=config.hidden_size,
                batch_size=config.batch_size,
                ffn_hidden_size=config.hidden_size * config.expand_ratio,
                seq_length=config.seq_length,
                num_heads=config.num_heads,
                attention_dropout_rate=config.attention_dropout_rate,
                hidden_dropout_rate=config.hidden_dropout_rate,
                hidden_act=config.hidden_act,
                param_init_type=config.param_init_type,
                layernorm_compute_type=config.layernorm_compute_type,
                softmax_compute_type=config.softmax_compute_type,
                parallel_config=config.parallel_config.dp_mp_config,
                moe_config=moe_config)
            set_parallel_configure_for_layer(
                block, layer_id=i, layers=config.num_layers,
                offset=0, parallel_config=config.parallel_config)
            self.blocks.append(block)

        self.cast = P.Cast()
        self.tile = P.Tile().shard(((config.parallel_config.data_parallel,),))
        # NOTE(review): compute dtype is hard-coded to float16, not read from config.
        self.dtype = mstype.float16
        self.num_layers = config.num_layers
        # Position ids 0 .. config.seq_length - 1.
        self.input_position = Tensor(np.arange(config.seq_length), mstype.int32)

    def construct(self, input_ids, attention_mask):
        """GPT model"""
        batch_size, seq_length = F.shape(input_ids)
        # Both branches presumably require seq_length == config.seq_length:
        # the reshape needs matching element counts, and the tiled table must
        # line up with the word embeddings it is added to.
        if batch_size == 1:
            input_position = F.reshape(self.input_position, (1, seq_length))
        else:
            input_position = self.tile(self.input_position, (batch_size, 1))

        input_embedding, embedding_table = self.embedding(input_ids, input_position)

        hidden_states = self.cast(input_embedding, self.dtype)
        hidden_shape = F.shape(hidden_states)
        # Flatten to 2-D (tokens, hidden) before the decoder stack.
        hidden_states = F.reshape(hidden_states, (-1, hidden_shape[-1]))

        for i in range(self.num_layers):
            hidden_states = self.blocks[i](hidden_states, attention_mask)
        output_state = self.layernorm(hidden_states)

        return output_state, embedding_table
class GPTHead(nn.Cell):
    r"""Head for GPT to get the logits of each token in the vocab.

    Computes ``state @ embedding_table^T`` after casting both operands to ``compute_type``.

    Args:
        hidden_size (int): hidden size of the incoming states (stored on the cell only).
        vocab_size (int): vocabulary size; if it is not divisible by ``model_parallel``
            the matmul falls back to ``model_parallel = 1``.
        compute_type: dtype both operands are cast to before the matmul. Default: float16.
        parallel_config: parallel config providing ``model_parallel``, ``data_parallel``,
            ``pipeline_stage`` and ``vocab_emb_dp``. Required despite the ``None`` default.

    Raises:
        ValueError: if ``parallel_config`` is None.
    """

    def __init__(self, hidden_size, vocab_size, compute_type=mstype.float16, parallel_config=None):
        super().__init__()
        if parallel_config is None:
            raise ValueError("GPTHead requires a parallel_config, but got None.")
        # Copy so the fallbacks below never leak into the caller's config.
        copied_parallel_config = copy.deepcopy(parallel_config)
        mp = copied_parallel_config.model_parallel
        if vocab_size % mp != 0:
            logger.warning("The vocab size of GPTHead MatMul is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of GPTHead MatMul will be changed: mp = 1")
            copied_parallel_config.model_parallel = 1

        # With pipeline parallelism the table is always split on its vocab dimension.
        if copied_parallel_config.pipeline_stage > 1:
            copied_parallel_config.vocab_emb_dp = False
        if copied_parallel_config.vocab_emb_dp:
            self.matmul = P.MatMul(transpose_b=True).shard(((copied_parallel_config.data_parallel, 1), (1, 1)))
        else:
            self.matmul = P.MatMul(transpose_b=True).shard(((copied_parallel_config.data_parallel, 1), (
                copied_parallel_config.model_parallel, 1)))
        self.hidden_size = hidden_size
        self.dtype = compute_type
        self.cast = P.Cast()

    def construct(self, state, embedding_table):
        logits = self.matmul(self.cast(state, self.dtype), self.cast(embedding_table, self.dtype))
        return logits


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class GPT2WithLora(GPT2LMHeadModel):
    """
    GPT2LMHeadModel with LoRA parameter-efficient tuning algorithm

    Args:
        config (GPT2Config): The config of Gpt2Model. Must provide ``pet_config``.

    Returns:
        Tensor, the loss or logits of the network.

    Raises:
        ValueError: if ``config`` is None or has no ``pet_config``.
    """

    def __init__(self, config: GPT2Config = None):
        # Fail fast instead of building the whole base model and then crashing on
        # ``config.pet_config`` (or on ``None.pop`` for the default argument).
        if config is None or getattr(config, "pet_config", None) is None:
            raise ValueError("GPT2WithLora requires a GPT2Config with `pet_config` set.")
        # Hide the checkpoint path while the parent builds the plain model, so the
        # parent's load_checkpoint call presumably loads nothing into the pre-LoRA
        # network; it is restored and loaded once the LoRA adapters are in place.
        checkpoint_name_or_path = config.pop("checkpoint_name_or_path", None)
        super().__init__(config)
        # Attach LoRA adapters to modules whose names match dense1 / dense3.
        config.pet_config.reg_rules = r'.*dense1.*|.*dense3.*'
        self.backbone = LoraAdapter.get_pet_model(self.backbone, config.pet_config)
        config.checkpoint_name_or_path = checkpoint_name_or_path
        self.load_checkpoint(config)
        # freeze pretrained model
        PetAdapter.freeze_pretrained_model(self, config.pet_config.pet_type)


class CrossEntropyCalculationWithMask(nn.Cell):
    """
    Cross Entropy loss

    Args:
        is_training (bool): when truthy, construct returns the loss; otherwise it
            returns the log-softmax of the logits.
        num_labels (int): depth of the one-hot label encoding; required when
            ``is_training`` is truthy.
    """

    def __init__(self, is_training=None, num_labels=None):
        super(CrossEntropyCalculationWithMask, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training
        self.num_labels = num_labels
        self.log_softmax = P.LogSoftmax(axis=-1)

    def construct(self, logits, label_ids, input_mask=None):
        """
        Calculate loss

        Args:
            logits (Tensor): scores over the vocabulary; log-softmax is applied here.
            label_ids (Tensor): the indices of target tokens in the vocabulary.
            input_mask (Tensor): input sentences padding mask, where 0 indicates padding position.

        Returns:
            return_value (Tensor): if is_training is truthy, the negative log-likelihood as
            mstype.float32 (mask-weighted mean when ``input_mask`` is given, plain mean
            otherwise); if is_training is falsy, the log-softmax of ``logits``.
        """
        # logits [batch * (seq_length-1), vocab_size] label_ids [batch, seq_length-1]
        logits = self.log_softmax(logits)

        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)  # label_ids [batch * (seq_length-1)]
            one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value,
                                         self.off_value)  # [batch * (seq_length-1), vocab_size]
            per_example_loss = self.neg(
                self.reduce_sum(one_hot_labels * logits, self.last_idx))  # [batch * (seq_length-1)]

            # for PPL calculation in evaluation
            if input_mask is not None:
                input_mask = self.cast(self.reshape(input_mask, self.last_idx),
                                       mstype.float32)  # [batch * (seq_length-1)]

                valid_loss_sum = self.reduce_sum(input_mask * per_example_loss, ())
                # The 1e-5 epsilon avoids dividing by zero when every position is masked.
                valid_element_sum = self.reduce_sum(input_mask, ()) + self.cast(F.tuple_to_array((1e-5,)),
                                                                               mstype.float32)
                loss = valid_loss_sum / valid_element_sum
            else:
                loss = self.reduce_mean(per_example_loss, self.last_idx)  # a number
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0  # [batch * (seq_length-1), vocab_size]

        return return_value