# Source code for mindformers.models.llama.llama

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LLaMA models' APIs."""

import mindspore.common.dtype as mstype
try:
    from mindspore._checkparam import Validator
except ImportError:
    import mindspore._checkparam as Validator
from mindspore import nn, ops
from mindspore.common.tensor import Tensor
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation

from mindformers.core.loss.loss import CrossEntropyLoss
from mindformers.mindformer_book import MindFormerBook
from mindformers.models.base_model import BaseModel
from mindformers.modules.layers import Linear
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.modules.transformer.transformer import AttentionMask
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.pet.tuners.pet_adapter import PetAdapter
from mindformers.pet.tuners.lora_adapter import LoraAdapter

from .llama_config import LlamaConfig
from .llama_layer import LlamaEmbedding, LlamaRMSNorm, precompute_freqs_cis
from .llama_transformer import LLamaDecodeLayer

__all__ = ['LlamaModel', 'LlamaForCausalLM', 'LlamaForCausalLMWithLora']

def layer_compute_dtype(layer, layer_id, offset, parallel_config, n_layers):
    r"""
        Configure the pipeline stage, the communication fusion id and the recompute
        policy of a single transformer block.

        The pipeline stage is `(layer_id + offset) // max(int((n_layers + 1) / pipeline_stage), 1)`,
        capped at `pipeline_stage - 1`.

        Args:
            layer(Cell) - Represents the transformer block
            layer_id(int) - Means the layer index for the current module, counts from zero.
            offset(int) - Means the layer_index needs a offset, if there are other modules in the net.
            parallel_config(dict) - Parallel Config
            n_layers(int) - The total layers used for the model.
    """
    stage_count = parallel_config.pipeline_stage
    global_index = layer_id + offset

    # Pipeline placement: spread the layers evenly, never past the last stage.
    layers_per_stage = max(int((n_layers + 1) / stage_count), 1)
    layer.pipeline_stage = min(global_index // layers_per_stage, stage_count - 1)

    # Used for optimizer's fusion tag
    layers_per_group = max(int((n_layers + 1) / parallel_config.gradient_aggregation_group), 1)
    fusion_id = 2 if stage_count > 1 else int(global_index / layers_per_group) + 1
    layer.set_comm_fusion(fusion_id)

    # `recompute` is either a plain bool or a config object carrying its own switches.
    recompute_cfg = parallel_config.recompute
    if isinstance(recompute_cfg, bool):
        if recompute_cfg:
            layer.recompute()
    elif recompute_cfg.recompute:
        layer.recompute(
            recompute_slice_activation=recompute_cfg.recompute_slice_activation)

class LlamaModel(BaseModel):
    r"""
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`LLamaDecodeLayer`]

    Args:
        config(LlamaConfig): the config of network

    Inputs:
        input_ids: the tokenized inputs with datatype int32
        input_position: positions used to gather the rotary frequencies; only read when
            `is_first_iteration` is False (incremental predict). Default: None.
        init_reset: forwarded unchanged to every decoder layer. Default: True.
        batch_valid_length: forwarded unchanged to every decoder layer. Default: None.

    Returns:
        output: Tensor, the output of llama decoderlayer
    """
    _support_list = MindFormerBook.get_model_support_list()['llama']

    def __init__(self, config: LlamaConfig = None):
        super().__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        if config.batch_size or config.use_past:
            Validator.check_positive_int(config.batch_size)
        self.parallel_config = config.parallel_config
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        self.pad_token_id = config.pad_token_id

        self.slice = P.StridedSlice().shard(((1, 1),))
        self.reshape = P.Reshape()
        self.cast = P.Cast()

        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
            # No explicit shard strategy is set on the final norm in this mode.
            self._build_backbone(config)
        elif parallel_mode not in (ParallelMode.AUTO_PARALLEL,):
            self._build_backbone(config)
            self.norm_out.shard(((config.parallel_config.data_parallel, 1, 1),))
        # NOTE(review): AUTO_PARALLEL without sharding propagation builds no
        # embedding/layers/norm at all (same as before) - confirm this is intended.

        self.freqs_cos, self.freqs_sin, self.mins_mask, self.rotary_mask = precompute_freqs_cis(
            config.hidden_size // config.num_heads, config.seq_length, dtype=config.compute_dtype
        )
        self.get_attention_mask = AttentionMask(
            config.seq_length, parallel_config=config.parallel_config.dp_mp_config).to_float(config.compute_dtype)
        self.not_equal = P.NotEqual().shard(((config.parallel_config.data_parallel, 1), ()))
        self.freqs_size = config.hidden_size // config.num_heads
        # used for increased predict
        self.gather = P.Gather().shard(((1, 1), (1,)))
        # when in train process,it's always True;when in predict process,only first iteration is True.
        self.is_first_iteration = True
        self.all_ones_attention_mask = ops.ones((1, 1, 1), mstype.float32)
        self.use_past = config.use_past

    @staticmethod
    def _set_edge_comm_fusion(cell, parallel_config):
        """Set the comm fusion id used for the cells outside the decoder stack (embedding, final norm)."""
        if parallel_config.pipeline_stage > 1:
            cell.set_comm_fusion(2)
        else:
            cell.set_comm_fusion(parallel_config.gradient_aggregation_group)

    def _build_backbone(self, config):
        """Create the token embedding, the decoder layer stack and the final RMSNorm."""
        parallel_config = config.parallel_config

        self.tok_embeddings = LlamaEmbedding(config.vocab_size, config.hidden_size,
                                             param_init_type=config.param_init_type,
                                             parallel_config=parallel_config)
        self.tok_embeddings.pipeline_stage = 0
        self._set_edge_comm_fusion(self.tok_embeddings, parallel_config)

        self.layers = nn.CellList()
        for layer_id in range(config.num_layers):
            layer = LLamaDecodeLayer(config.batch_size,
                                     config.seq_length,
                                     layer_id,
                                     dim=config.hidden_size,
                                     # Head count, not layer count: must agree with the
                                     # `hidden_size // num_heads` rotary frequency size.
                                     n_heads=config.num_heads,
                                     multiple_of=config.multiple_of,
                                     norm_eps=config.rms_norm_eps,
                                     compute_dtype=config.compute_dtype,
                                     layernorm_compute_dtype=config.layernorm_compute_type,
                                     softmax_compute_dtype=config.softmax_compute_type,
                                     param_init_type=config.param_init_type,
                                     use_past=config.use_past,
                                     parallel_config=parallel_config)
            layer_compute_dtype(layer, layer_id, config.offset, parallel_config, self.num_layers)
            self.layers.append(layer)

        self.norm_out = LlamaRMSNorm(
            config.hidden_size, config.rms_norm_eps,
            param_init_type=config.param_init_type).to_float(config.layernorm_compute_type)
        # The final norm lives on the last pipeline stage.
        self.norm_out.pipeline_stage = parallel_config.pipeline_stage - 1
        self._set_edge_comm_fusion(self.norm_out, parallel_config)

    def construct(self, input_ids: Tensor, input_position=None, init_reset=True, batch_valid_length=None):
        """Forward of llama model."""
        bs, _ = input_ids.shape
        # (b, t, d) , dp, 1, 1
        h = self.tok_embeddings(input_ids)
        mask = None
        if self.is_first_iteration is False:
            # for increase predict
            freqs_cis = (self.gather(self.freqs_cos, input_position, 0),
                         self.gather(self.freqs_sin, input_position, 0),
                         self.mins_mask, self.rotary_mask)
            mask = P.Tile()(self.all_ones_attention_mask, (bs, 1, 1))
        else:
            # first iteration of predict; all iterations of train
            freqs_cis = (self.freqs_cos, self.freqs_sin, self.mins_mask, self.rotary_mask)
            input_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), mstype.float32)
            mask = self.get_attention_mask(input_mask)

        # dp,1,1 -> dp,1,1
        for i in range(self.num_layers):
            h, _ = self.layers[i](h, freqs_cis, mask,
                                  init_reset=init_reset, batch_valid_length=batch_valid_length)
        # dp,1,1 -> dp,1,1
        output = self.norm_out(h)
        return output


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLM(BaseModel):
    r"""
    Provide llama training loss or logits through network.

    Args:
        config (LlamaConfig): The config of llama model.

    Inputs:
        input_ids(Tensor): the tokenized inputs with datatype int32,
            Tensor of shape :math:`(batch, seq\_length)`.
        label_ids(Tensor): the tokenized labels with datatype int32,
            Tensor of shape :math:`(batch, seq\_length)`
        input_position(Tensor): current position, used by model.predict
        position_ids(Tensor): Reserved param, not used.
        attention_mask(Tensor): Reserved param, not used.
        init_reset(bool, optional): Forwarded to the llama model. Default: True.
        batch_valid_length(Tensor): Forwarded to the llama model. Default: None.

    Returns:
        Tensor, the loss when `phase` is "train"; otherwise the tuple
        `(logits, tokens, input_mask)`.

    Examples:
        >>> from mindformers.models.llama import LlamaConfig, LlamaForCausalLM
        >>> config = LlamaConfig(batch_size=2)
        >>> network = LlamaForCausalLM(config=config)
    """
    _support_list = MindFormerBook.get_model_support_list()['llama']

    def __init__(self, config: LlamaConfig = None):
        super().__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        self.model = LlamaModel(config=config)

        parallel_config = config.parallel_config
        dp = parallel_config.data_parallel

        use_propagation = (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,)
                           and _is_sharding_propagation())
        self.lm_head = Linear(in_channels=config.hidden_size,
                              out_channels=config.vocab_size,
                              has_bias=False,
                              compute_dtype=config.compute_dtype,
                              param_init_type=config.param_init_type,
                              weight_init="normal")  # meta default: xavier_normal
        if not use_propagation:
            # An explicit matmul strategy is only set outside sharding propagation.
            if parallel_config.vocab_emb_dp:
                head_strategy = ((dp, 1), (1, 1))
            else:
                head_strategy = ((dp, 1), (parallel_config.model_parallel, 1))
            self.lm_head.shard(strategy_matmul=head_strategy)
        if parallel_config.pipeline_stage > 1:
            self.lm_head.pipeline_stage = parallel_config.pipeline_stage - 1

        self.ignore_token_id = config.ignore_token_id
        self.pad_token_id = config.pad_token_id
        self.loss = CrossEntropyLoss(parallel_config=parallel_config)

        self.slice = P.StridedSlice().shard(((dp, 1),))
        self.not_equal = P.NotEqual().shard(((dp, 1), ()))
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
        self.add = P.Add().shard(((dp, 1), ()))

        # used for increased predict
        self.is_first_iteration = True

        self.load_checkpoint(config)

    # pylint: disable=W0613
    def construct(self, input_ids, label_ids=None, input_position=None, position_ids=None,
                  attention_mask=None, init_reset=True, batch_valid_length=None):
        """LlamaForCausalLM forward."""
        batch, seq_len = input_ids.shape
        if self.phase == "train":
            # Drop the last token: it has no next-token target.
            tokens = self.slice(input_ids, (0, 0), (batch, seq_len - 1), (1, 1))
        else:
            tokens = input_ids

        hidden = self.model(tokens, input_position, init_reset, batch_valid_length)
        logits = self.lm_head(hidden)

        input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
        if label_ids is None:
            # Labels default to the inputs shifted left by one position.
            label_ids = self.slice(input_ids, (0, 1), (batch, seq_len), (1, 1))
        else:
            label_ids = self.slice(label_ids, (0, 1), (batch, seq_len), (1, 1))
            # Explicit labels additionally mask out positions equal to ignore_token_id.
            valid_labels = self.cast(self.not_equal(label_ids, self.ignore_token_id), mstype.float32)
            input_mask = self.mul(input_mask, valid_labels)

        logits = self.cast(logits, mstype.float32)
        if self.phase != "train":
            logits = self.reshape(logits, (batch, seq_len, -1))
            # makes cast effective to avoid allgather issue in Mindspore1.10
            input_mask = self.add(input_mask, 1)
            return logits, tokens, input_mask

        if logits.ndim > 2:
            logits = self.reshape(logits, (-1, logits.shape[-1]))
        flat_labels = self.reshape(label_ids, (-1,))
        flat_mask = self.reshape(input_mask, (-1,))
        return self.loss(logits, flat_labels, flat_mask)


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLMWithLora(LlamaForCausalLM):
    """Llama Model for finetuning with LoRA

    Args:
        config (LlamaConfig): The config of network.
        pet: The parameter-efficient-tuning settings. Must expose `pet_config`
            (passed to `LoraAdapter.get_pet_model`) and `pet_type`
            (passed to `PetAdapter.freeze_pretrained_model`).

    Raises:
        ValueError: If `pet` is None.
    """

    def __init__(self, config: LlamaConfig = None, pet=None):
        # Fail fast: `pet` is dereferenced below, so reject None before the
        # whole base network gets built.
        if pet is None:
            raise ValueError("`pet` must be provided for LlamaForCausalLMWithLora, but got None.")

        # Hide the checkpoint from the base constructor so that weights are loaded
        # only once, after the LoRA cells have been inserted.
        ckpt_cfg = config.checkpoint_name_or_path
        config.checkpoint_name_or_path = None
        try:
            super().__init__(config)
            # get Pet tuning model.
            self.pet = pet
            self.pet.pet_config.reg_rules = r'.*wq|.*wk|.*wv|.*wo'
            self.model = LoraAdapter.get_pet_model(self.model, self.pet.pet_config)
        finally:
            # Always hand the caller's config back unchanged, even if construction failed.
            config.checkpoint_name_or_path = ckpt_cfg

        # load lora ckpt
        self.load_checkpoint(config)
        # freeze pretrained model
        PetAdapter.freeze_pretrained_model(self, self.pet.pet_type)