# Source code for mindformers.models.llama.llama

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LLaMA models' APIs."""
import numpy as np
import mindspore.common.dtype as mstype

try:
    from mindspore._checkparam import Validator
except ImportError:
    import mindspore._checkparam as Validator
from mindspore import Tensor, nn
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
try:
    # pylint: disable=W0611
    from mindspore.nn.layer.flash_attention import FlashAttention
    FLASHATTENTION_VALID = True
except ImportError:
    FLASHATTENTION_VALID = False

from mindformers.core.loss.loss import CrossEntropyLoss
from mindformers.mindformer_book import MindFormerBook
from mindformers.models.base_model import BaseModel
from mindformers.modules.layers import Linear
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.modules.transformer.transformer import AttentionMask
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.pet.tuners.pet_adapter import PetAdapter
from mindformers.pet.tuners.lora_adapter import LoraAdapter

from .llama_config import LlamaConfig
from .llama_layer import LlamaEmbedding, LlamaRMSNorm, precompute_freqs_cis
from .llama_transformer import LLamaDecodeLayer
from ..utils import cell_reuse
from ...tools.logger import logger

__all__ = ['LlamaModel', 'LlamaForCausalLM', 'LlamaForCausalLMWithLora']

def layer_compute_dtype(layer, layer_id, offset, parallel_config, n_layers, select_recompute=False):
    r"""
        Assign the pipeline stage, the communication fusion tag and the recompute setting of one layer.

        Despite its name, this helper does not touch any dtype: it sets `layer.pipeline_stage`,
        calls `layer.set_comm_fusion(...)` and, when enabled, `layer.recompute(...)`.

        Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.

        Args:
            layer(Cell) - Represents the transformer block
            layer_id(int) - Means the layer index for the current module, counts from zero.
            offset(Union[int, List[int]]) - Means the layer_index needs a offset, if there are other modules in the net.
                When a list is given it must hold one entry per pipeline stage; the entry used for this
                layer is picked by `layer_id // pp_dis` (clamped to the last stage).
            parallel_config(dict) - Parallel Config. The attributes `pipeline_stage`,
                `gradient_aggregation_group` and `recompute` are read from it.
            n_layers(int) - The total layers used for the model.
            select_recompute(bool) - When True, the whole-layer `layer.recompute(...)` call is skipped.
                Default False.

        Raises:
            ValueError - `offset` is a list whose length differs from `parallel_config.pipeline_stage`.
            TypeError - `offset` is neither an `int` nor a `list`.
    """
    # Number of layers mapped onto a single pipeline stage (at least one).
    pp_dis = max(int((n_layers + 1) / parallel_config.pipeline_stage), 1)
    if isinstance(offset, list):
        if len(offset) != parallel_config.pipeline_stage:
            # Both halves must be f-strings, otherwise the stage count is printed as a raw placeholder.
            raise ValueError(f"The length of `offset` {len(offset)} do not match "
                             f"`pipeline stage` {parallel_config.pipeline_stage}.")
        i = min(layer_id // pp_dis, parallel_config.pipeline_stage - 1)
        offset_layer = offset[i]
    elif isinstance(offset, int):
        offset_layer = offset
    else:
        raise TypeError(f"`offset` must be `int` of list of `int`, but got {type(offset)}.")

    # Clamp so that a large offset can never point past the last stage.
    pp_id = min((layer_id + offset_layer) // pp_dis, parallel_config.pipeline_stage - 1)
    layer.pipeline_stage = pp_id

    # Used for optimizer's fusion tag
    dis = max(int((n_layers + 1) / parallel_config.gradient_aggregation_group), 1)
    if parallel_config.pipeline_stage > 1:
        layer.set_comm_fusion(2)
    else:
        layer.set_comm_fusion(int((layer_id + offset_layer) / dis) + 1)

    # `recompute` may be a plain bool or a config object exposing `.recompute`
    # and `.recompute_slice_activation`.
    if isinstance(parallel_config.recompute, bool):
        if parallel_config.recompute and not select_recompute:
            layer.recompute()
    else:
        if parallel_config.recompute.recompute and not select_recompute:
            layer.recompute(
                recompute_slice_activation=parallel_config.recompute.recompute_slice_activation)


class LlamaModel(BaseModel):
    r"""
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config(LlamaConfig): the config of network

    Inputs:
        input_ids: the tokenized inputs with datatype int32

    Returns:
        output: Tensor, the output of llama decoderlayer
    """
    _support_list = MindFormerBook.get_model_support_list()['llama']

    def __init__(self, config: LlamaConfig = None):
        """Build the embedding, the decoder layers and the final norm, and set their parallel strategy.

        Args:
            config(LlamaConfig): the config of network.
        """
        super().__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        if config.batch_size or config.use_past:
            Validator.check_positive_int(config.batch_size)
        self.dtype = config.compute_dtype
        self.num_layers = config.num_layers
        self.pad_token_id = config.pad_token_id
        # Selects the full-sequence branch of `construct`; this class never sets it to False itself.
        self.is_first_iteration = True
        self.use_past = config.use_past
        # Flash attention is only enabled when requested AND the import at module top succeeded.
        self.use_flash_attention = config.use_flash_attention and FLASHATTENTION_VALID
        if self.use_flash_attention:
            logger.info("Enable flash attention.")
        elif config.use_flash_attention:
            logger.info("Current MindSpore do not support flash attention.")

        self.freqs_cos, self.freqs_sin, self.swap_mask = precompute_freqs_cis(
            config.hidden_size // config.num_heads, config.seq_length, dtype=config.rotary_dtype,
            pretrain_seqlen=config.pretrain_seqlen, extend_method=config.extend_method)
        self.get_attention_mask = AttentionMask(
            config.seq_length, parallel_config=config.parallel_config.dp_mp_config).to_float(config.compute_dtype)
        self.multiply_data = Tensor([-10000.0], dtype=config.compute_dtype)
        self.one = Tensor([1.0], dtype=config.compute_dtype)
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.tile = P.Tile()
        self.mul_mask = P.Mul()
        self.sub = P.Sub()
        self.expand_dims = P.ExpandDims()
        self.not_equal = P.NotEqual()
        self.gather = P.Gather()

        self.tok_embeddings = LlamaEmbedding(
            config.vocab_size, config.hidden_size, param_init_type=config.param_init_type)
        self.layers = nn.CellList()
        # `layer_compute_dtype` accepts a plain-bool `recompute` setting, which has no
        # `select_recompute` attribute; fall back to False instead of raising AttributeError.
        select_recompute = getattr(config.parallel_config.recompute, "select_recompute", False)
        for layer_id in range(config.num_layers):
            layer = LLamaDecodeLayer(config.batch_size,
                                     config.seq_length,
                                     layer_id,
                                     dim=config.hidden_size,
                                     n_heads=config.num_heads,
                                     multiple_of=config.multiple_of,
                                     n_kv_heads=config.n_kv_heads,
                                     ffn_dim_multiplier=config.ffn_dim_multiplier,
                                     norm_eps=config.rms_norm_eps,
                                     compute_dtype=config.compute_dtype,
                                     layernorm_compute_dtype=config.layernorm_compute_type,
                                     softmax_compute_dtype=config.softmax_compute_type,
                                     rotary_dtype=config.rotary_dtype,
                                     param_init_type=config.param_init_type,
                                     use_past=config.use_past,
                                     use_flash_attention=config.use_flash_attention,
                                     compute_in_2d=config.compute_in_2d,
                                     use_past_shard=config.use_past_shard,
                                     parallel_config=config.parallel_config)
            layer_compute_dtype(layer, layer_id, config.offset, config.parallel_config,
                                config.num_layers, select_recompute=select_recompute)
            self.layers.append(layer)
        self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps,
                                     compute_type=config.layernorm_compute_type)

        dp = config.parallel_config.data_parallel
        # Manual strategies are skipped only for auto-parallel with sharding propagation.
        if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()):
            self.tok_embeddings.pipeline_stage = 0
            if config.parallel_config.pipeline_stage > 1:
                self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1
                self.tok_embeddings.set_comm_fusion(2)
                self.norm_out.set_comm_fusion(2)
            else:
                self.tok_embeddings.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
                self.norm_out.set_comm_fusion(config.parallel_config.gradient_aggregation_group)

            self.tok_embeddings.shard(config.parallel_config)

            self.tile.shard(((1, 1, 1, 1), ()))
            self.sub.shard(((1,), (dp, 1, 1)))
            self.mul_mask.shard(((dp, 1, 1, 1), (1,)))
            self.expand_dims.shard(((dp, 1, 1),))
            self.not_equal.shard(((dp, 1), ()))
            self.gather.shard(((dp, 1), (1,)))
            if config.compute_in_2d:
                self.norm_out.shard((dp, 1))
            else:
                self.norm_out.shard((dp, 1, 1))

        if self.use_past:
            # Position indices 0..seq_length-1, repeated per batch entry; compared against the
            # current position in the incremental branch of `construct`.
            seq_range = np.arange(config.seq_length).reshape(1, 1, -1)
            self.range = Tensor(np.tile(seq_range, (config.batch_size, 1, 1)), mstype.int32)
            self.gather_past = P.Gather()
            # NOTE(review): this replaces the `expand_dims` primitive created (and possibly
            # sharded) above with a fresh one - confirm that is intended.
            self.expand_dims = P.ExpandDims()
            self.le_past = P.LessEqual()

    # pylint: disable=W0613
    def construct(self, tokens: Tensor, input_position=None, init_reset=True, batch_valid_length=None):
        """Forward of llama model.

        Args:
            tokens(Tensor): token ids; unpacked as `(bs, seq_len)`.
            input_position: not used in this method.
            init_reset: forwarded unchanged to every decoder layer.
            batch_valid_length: forwarded to every decoder layer; also used to derive the
                current position when `self.is_first_iteration` is False.

        Returns:
            output: Tensor, the result of `norm_out` applied to the last decoder layer's output.
        """
        # preprocess
        bs, seq_len = tokens.shape
        if self.is_first_iteration:
            freqs_cis = (self.tile(self.reshape(self.freqs_cos, (1, 1, seq_len, -1)), (bs, 1, 1, 1)),
                         self.tile(self.reshape(self.freqs_sin, (1, 1, seq_len, -1)), (bs, 1, 1, 1)),
                         self.swap_mask)
            input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), self.dtype)
            mask = self.get_attention_mask(input_mask)
            # mask: [bs, seq, seq]
        else:
            # Uses `gather_past`, `le_past` and `range`, which only exist when `use_past` was set.
            cur_pos = batch_valid_length - 1
            valid_length = self.reshape(cur_pos, (-1, 1, 1))
            freqs_cis = (self.reshape(self.gather_past(self.freqs_cos, cur_pos, 0), (bs, 1, seq_len, -1)),
                         self.reshape(self.gather_past(self.freqs_sin, cur_pos, 0), (bs, 1, seq_len, -1)),
                         self.swap_mask)
            mask = self.cast(self.le_past(self.range, valid_length), self.dtype)
            # mask: [bs, 1, 1]
        # Invert the mask (1 - mask) ...
        mask = self.sub(self.one, self.cast(mask, self.dtype))
        if not self.use_flash_attention:
            # ... then add an axis at position 1 and scale by -10000.0.
            mask = self.expand_dims(mask, 1)
            mask = self.mul_mask(mask, self.multiply_data)

        # tokens: [bs, seq/1]
        h = self.tok_embeddings(tokens)
        # h: [bs, seq/1, hidden_dim]
        for i in range(self.num_layers):
            h, _ = self.layers[i](h, freqs_cis, mask,
                                  init_reset=init_reset, batch_valid_length=batch_valid_length)
        output = self.norm_out(h)
        return output
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLM(BaseModel):
    r"""
    Provide llama training loss or logits through network.

    Args:
        config (LlamaConfig): The config of llama model.

    Inputs:
        input_ids(Tensor): the tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
        labels(Tensor): the tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
        input_position(Tensor): current position, used by model.predict.
        position_ids(Tensor): Reserved param, not used.
        attention_mask(Tensor): Reserved param, not used.
        input_embeds(Tensor): Reserved param, not used.
        init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and
            past value parameter used in the incremental prediction. Default True.
        batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
            prediction. Tensor of shape :math:`(batch_size,)`. Default None.

    Returns:
        Tensor, the loss or logits of the network.

    Examples:
        >>> from mindformers.models.llama import LlamaConfig, LlamaForCausalLM
        >>> config = LlamaConfig(batch_size=2)
        >>> network = LlamaForCausalLM(config=config)
    """
    _support_list = MindFormerBook.get_model_support_list()['llama']

    @cell_reuse()
    def __init__(self, config: LlamaConfig = None):
        """Build the backbone, the LM head and the loss, set parallel strategies and load the checkpoint.

        Args:
            config (LlamaConfig): The config of llama model.
        """
        super(LlamaForCausalLM, self).__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        self.ignore_token_id = config.ignore_token_id
        self.pad_token_id = config.pad_token_id
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.slice = P.StridedSlice()
        self.not_equal = P.NotEqual()
        self.mul = P.Mul()
        self.add = P.Add()
        self.model = LlamaModel(config=config)
        self.lm_head = Linear(in_channels=config.hidden_size,
                              out_channels=config.vocab_size,
                              has_bias=False,
                              compute_dtype=config.compute_dtype,
                              param_init_type=config.param_init_type,
                              weight_init="normal")  # meta default: xavier_normal
        self.loss = CrossEntropyLoss(parallel_config=config.parallel_config)

        dp = config.parallel_config.data_parallel
        mp = config.parallel_config.model_parallel
        # Manual strategies are skipped only for auto-parallel with sharding propagation.
        if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()):
            self.slice.shard(((dp, 1),))
            self.not_equal.shard(((dp, 1), ()))
            self.mul.shard(((dp, 1), (dp, 1)))
            self.add.shard(((dp, 1), ()))
            if config.parallel_config.vocab_emb_dp:
                self.lm_head.shard(strategy_matmul=((dp, 1), (1, 1)))
            else:
                self.lm_head.shard(strategy_matmul=((dp, 1), (mp, 1)))
            if config.parallel_config.pipeline_stage > 1:
                self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1

        self.load_checkpoint(config)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Wrap `input_ids` as an int32 Tensor under the key "input_ids"; `kwargs` are ignored."""
        return {
            "input_ids": Tensor(input_ids, mstype.int32)
        }

    # pylint: disable=W0613
    def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
                  input_embeds=None, init_reset=True, batch_valid_length=None):
        """LlamaForCausalLM forward.

        In training mode the scalar/tensor produced by `self.loss` is returned; otherwise the tuple
        `(logits, tokens, input_mask)` is returned.
        """
        bsz, seqlen = input_ids.shape
        if self.training:
            # Drop the last position: the model input is input_ids[:, :-1].
            tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
        else:
            tokens = input_ids

        output = self.model(tokens, input_position, init_reset, batch_valid_length)
        logits = self.lm_head(output)

        # 1.0 where the token differs from the pad id, 0.0 elsewhere.
        input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
        if labels is None:
            # Targets are the inputs shifted by one: input_ids[:, 1:].
            labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
        else:
            if labels.ndim > 1:
                if self.training:
                    labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
                # Additionally mask out positions whose label equals `ignore_token_id`.
                label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32)
                input_mask = self.mul(input_mask, label_mask)

        logits = self.cast(logits, mstype.float32)
        if not self.training:
            logits = self.reshape(logits, (bsz, seqlen, -1))
            # makes cast effective to avoid allgather issue in Mindspore1.10
            input_mask = self.add(input_mask, 1)
            return logits, tokens, input_mask

        if logits.ndim > 2:
            logits = self.reshape(logits, (-1, logits.shape[-1]))
        labels = self.reshape(labels, (-1,))
        input_mask = self.reshape(input_mask, (-1,))
        loss = self.loss(logits, labels, input_mask)
        return loss
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLMWithLora(LlamaForCausalLM):
    """Llama Model for finetuning with LoRA

    Args:
        config (LlamaConfig): The config of network.
    """

    def __init__(self, config: LlamaConfig = None):
        """Build the base model without loading weights, wrap it with LoRA, then load and freeze.

        Args:
            config (LlamaConfig): The config of network. `config.pet_config.reg_rules` is
                overwritten here; `config.checkpoint_name_or_path` is cleared temporarily and
                restored before the checkpoint is loaded.
        """
        ckpt_cfg = config.checkpoint_name_or_path
        # Clear the path so the base-class constructor's `load_checkpoint(config)` call does not
        # see it; the checkpoint is loaded below, after the adapter has been applied.
        config.checkpoint_name_or_path = None
        try:
            super().__init__(config)
            # get Pet tuning model.
            config.pet_config.reg_rules = r'.*wq|.*wk|.*wv|.*wo'
            self.model = LoraAdapter.get_pet_model(self.model, config.pet_config)
        finally:
            # Restore the caller's config even when construction fails part-way.
            config.checkpoint_name_or_path = ckpt_cfg
        # load lora ckpt
        self.load_checkpoint(config)
        # freeze pretrained model
        PetAdapter.freeze_pretrained_model(self, config.pet_config.pet_type)