mindformers.models.t5.t5 源代码

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""T5 model."""

import math
import numpy as np
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import context
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from mindspore.ops import operations as P

import mindspore.numpy
from mindspore.context import ParallelMode
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal, initializer

from mindspore.parallel._utils import _get_parallel_mode

from mindformers.modules.layers import Linear, _check_past_none_input_none, _check_input_dtype
from mindformers.modules.transformer.moe import MoE, _check_moe_config
from mindformers.modules.transformer.transformer import default_transformer_config, default_moe_config, \
    default_dpmp_config, \
    EmbeddingOpParallelConfig, OpParallelConfig
from mindformers.core.loss import CrossEntropyLoss
from mindformers.modules import VocabEmbedding

from .t5_config import T5Config

from ..base_model import BaseModel
from ...tools import logger
from ...tools.register import MindFormerRegister, MindFormerModuleType
from ...mindformer_book import MindFormerBook

__all__ = ['T5ForConditionalGeneration']


class LayerNorm(nn.Cell):
    """
    Scale-only layer normalisation cell used by T5.

    The input is divided by ``sqrt(mean(x ** 2, axis=-1) + eps)`` and multiplied by the
    learnable ``gamma``. No mean is subtracted and no bias is added.

    Args:
        normalized_shape: Shape passed to the initializer of ``gamma``.
        eps (float): Constant added to the mean of squares before the square root. Default: 1e-5.
        param_init_type: Dtype of ``gamma``; must be ``mstype.float32`` or ``mstype.float16``.

    Raises:
        TypeError: If ``param_init_type`` is neither float32 nor float16.
    """

    def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32):
        super(LayerNorm, self).__init__()
        supported_types = (mstype.float32, mstype.float16)
        if param_init_type not in supported_types:
            raise TypeError("The type of parameter 'param_init_type' should in [float32, float16], "
                            "but got the type : {}.".format(type(param_init_type)))
        self.eps = eps
        # Learnable scale, initialised to ones.
        scale_init = initializer('ones', normalized_shape, param_init_type)
        self.gamma = Parameter(scale_init, name="gamma", parallel_optimizer=False)

        # Operators evaluated by construct().
        self.square = P.Square()
        self.mean = P.ReduceMean(keep_dims=True)
        self.add = P.Add()
        self.sqrt = P.Sqrt()
        self.real_div = P.RealDiv()
        self.mul = P.Mul()

        # Operators that construct() does not call; they are kept as attributes and
        # still receive a strategy in shard().
        self.sub1 = P.Sub()
        self.sub2 = P.Sub()
        self.add2 = P.Add()

    def construct(self, x):
        """
        Normalise ``x`` over its last axis and scale it by ``gamma``.

        Args:
            x: Input tensor; the original note describes it as batch x seq_length x hidden_size.

        Returns:
            Tensor, ``x / sqrt(mean(x ** 2, -1) + eps) * gamma``.
        """
        mean_square = self.mean(self.square(x), -1)
        denominator = self.sqrt(self.add(mean_square, self.eps))
        normalized = self.real_div(x, denominator)
        return self.mul(normalized, self.gamma)

    def shard(self, strategy):
        """
        Set the parallel strategy of every operator held by this cell.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy (tuple): The strategy for the layer norm input. ``strategy[0]`` is reused as
                the per-tensor strategy of the binary operators.

        Returns:
            This cell, so that calls can be chained.
        """
        input_strategy = strategy[0]
        # Unary operators take the strategy tuple unchanged.
        self.mean.shard(strategy)
        self.square.shard(strategy)
        self.sqrt.shard(strategy)
        # Binary operators whose two operands are split like the input.
        pairwise = (input_strategy, input_strategy)
        self.sub1.shard(pairwise)
        self.sub2.shard(pairwise)
        self.real_div.shard(pairwise)
        # eps is a scalar operand, hence the empty strategy.
        self.add.shard((input_strategy, ()))
        # The second operand of these operators is one-dimensional and not split.
        self.mul.shard((input_strategy, (1,)))
        self.add2.shard((input_strategy, (1,)))
        return self


class T5FeedFoward(nn.Cell):
    """
    T5 feed-forward cell: a bias-free projection to ``ffn_hidden_size`` followed by the
    activation ``hidden_act``, a bias-free projection back to ``hidden_size`` and dropout.

    Args:
        hidden_size (int): Feature size of the input and of the output.
        ffn_hidden_size (int): Feature size of the intermediate projection.
        dropout_rate (float): Dropout rate applied to the output, in the range [0, 1.0).
        hidden_act (str): Activation passed to the first ``Linear`` layer. Default: 'gelu'.
        expert_num (int): Number of experts passed to the ``Linear`` layers. When it is greater
            than 1 the expert-parallel shard strategies are used. Default: 1.
        expert_group_size: Passed through to the ``Linear`` layers. Default: None.
        param_init_type: Parameter dtype passed to the ``Linear`` layers. Default: mstype.float32.
        parallel_config: Object providing ``model_parallel``, ``data_parallel`` and, when
            ``expert_num > 1``, ``expert_parallel``. Default: ``default_dpmp_config``.

    Raises:
        ValueError: If ``ffn_hidden_size`` or ``hidden_size`` is not a multiple of
            ``parallel_config.model_parallel``, or if ``dropout_rate`` is outside [0, 1.0).
    """

    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 dropout_rate,
                 hidden_act='gelu',
                 expert_num=1,
                 expert_group_size=None,
                 param_init_type=mstype.float32,
                 parallel_config=default_dpmp_config):
        super(T5FeedFoward, self).__init__()
        mp = parallel_config.model_parallel
        if expert_num > 1:
            ep = parallel_config.expert_parallel
        else:
            ep = 1
        # ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
        dp = int(parallel_config.data_parallel / ep)
        if ffn_hidden_size % mp != 0:
            # The trailing space after "the" keeps the concatenated message readable.
            raise ValueError("For 'T5FeedFoward', the class variable 'ffn_hidden_size' must be a multiple of the "
                             "num of model parallel, but got the ffn_hidden_size is {} and the num of model "
                             "parallel is {}.".format(ffn_hidden_size, mp))
        if hidden_size % mp != 0:
            raise ValueError("For 'T5FeedFoward', the class variable 'hidden_size' must be a multiple of the num of "
                             "model parallel, but got the hidden_size is {} and the num of model parallel is {}."
                             .format(hidden_size, mp))
        if dropout_rate < 0 or dropout_rate >= 1:
            raise ValueError("For 'T5FeedFoward', the class variable 'dropout_rate' must be in the range [0, 1.0), "
                             "but got the value : {}.".format(dropout_rate))
        input_size = hidden_size
        output_size = ffn_hidden_size

        # Project to ffn_hidden_size
        self.mapping = Linear(in_channels=input_size,
                              out_channels=output_size,
                              activation=hidden_act,
                              transpose_b=False,
                              has_bias=False,
                              expert_num=expert_num,
                              expert_group_size=expert_group_size,
                              outer_batch=dp,
                              param_init_type=param_init_type)

        if expert_num > 1:
            self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)),
                               strategy_bias=((dp, ep, 1, mp), (1, ep, 1, mp)),
                               strategy_activation=((dp, ep, 1, mp),))
        else:
            self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
                               strategy_bias=((dp, mp), (mp,)),
                               strategy_activation=((dp, mp),))
        # Project back to hidden_size
        self.projection = Linear(in_channels=output_size,
                                 out_channels=input_size,
                                 transpose_b=False,
                                 expert_num=expert_num,
                                 has_bias=False,
                                 expert_group_size=expert_group_size,
                                 outer_batch=dp,
                                 param_init_type=param_init_type)
        if expert_num > 1:
            self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)),
                                  strategy_bias=((dp, ep, 1, 1), (1, ep, 1, 1)))
        else:
            self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
                                  strategy_bias=((dp, 1), (1,)))
        # One dropout cell per supported output rank, each with a matching shard strategy.
        self.dropout = nn.Dropout(1 - dropout_rate)
        self.dropout.dropout.shard(((dp, 1),))
        self.dropout_3d = nn.Dropout(1 - dropout_rate)
        self.dropout_3d.dropout.shard(((dp, 1, 1),))
        self.dropout_4d = nn.Dropout(1 - dropout_rate)
        self.dropout_4d.dropout.shard(((dp, ep, 1, 1),))
        self.cast = P.Cast()

    def construct(self, x):
        """
        The forward function of FFN.

        Args:
            x: Input tensor of dtype float32 or float16.

        Returns:
            Tensor, the result of mapping, projection and dropout.
        """
        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
        # The input is always cast to float16 before the two projections.
        x = self.cast(x, mstype.float16)
        # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
        hidden = self.mapping(x)
        # the projection maps the last dimension back from ffn_hidden_size to hidden_size
        output = self.projection(hidden)
        # Pick the dropout cell whose shard strategy matches the rank of the output.
        if len(F.shape(output)) == 3:
            output = self.dropout_3d(output)
        elif len(F.shape(output)) == 2:
            output = self.dropout(output)
        else:
            output = self.dropout_4d(output)
        return output


class RelaPosMatrixGenerator(nn.Cell):
    """
    Map relative positions to bucket indices.

    The result of the cell is meant to be used as the index into the bias embedding table.

    Args:
        max_relative_position: Stored, together with its negation, as private attributes.
        log_relative_distance: Divisor applied to the logarithm that spreads large offsets
            over the remaining buckets.
    """

    def __init__(self, max_relative_position, log_relative_distance):
        super(RelaPosMatrixGenerator, self).__init__()
        self.log_relative_distance = log_relative_distance
        self._max_relative_position = max_relative_position
        self._min_relative_position = -max_relative_position

        # Operators kept as attributes of the cell; construct() does not call them.
        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self, relative_position, bidirectional=True, num_buckets=32):
        """
        Compute the bucket index of every entry of ``relative_position``.

        Args:
            relative_position: Integer tensor of relative offsets.
            bidirectional (bool): When True, half of the buckets is used for positive offsets
                and the absolute offset is bucketed. When False, positive offsets are clamped
                to zero and the negated offset is bucketed. Default: True.
            num_buckets (int): Total number of buckets. Default: 32.

        Returns:
            Tensor of bucket indices with the same shape as ``relative_position``.
        """
        bucket_offset = 0
        if bidirectional:
            num_buckets = num_buckets // 2
            positive_flag = (relative_position > 0).astype(mstype.int32)
            bucket_offset = bucket_offset + positive_flag * num_buckets
            distance = P.Abs()(relative_position)
        else:
            clamped = P.Minimum()(relative_position, P.ZerosLike()(relative_position))
            distance = -clamped

        # Offsets below max_exact keep their own bucket.
        max_exact = num_buckets // 2
        is_small = distance < max_exact

        # Larger offsets are placed on a logarithmic scale in the remaining buckets.
        log_ratio = P.Log()(distance.astype(mstype.float32) / max_exact) / self.log_relative_distance
        large_bucket = max_exact + (log_ratio * (num_buckets - max_exact))
        large_bucket = large_bucket.astype(mstype.int32)
        bucket_limit = mindspore.numpy.full_like(large_bucket, num_buckets - 1)
        large_bucket = P.Minimum()(large_bucket, bucket_limit)

        bucket_offset += mindspore.numpy.where(is_small, distance, large_bucket)
        return bucket_offset


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    The relative position embedding generator.

    Builds the matrix of relative offsets between query and key positions, converts it to
    bucket indices with ``RelaPosMatrixGenerator`` and gathers one row of the embedding
    table per bucket.

    Args:
        depth (int): Width of each embedding row. The attention cell passes the number of heads.
        max_relative_position (int): Number of rows of the embedding table.
        initializer_range (float): Argument of the ``TruncatedNormal`` initializer of the table.
        is_decoder (bool): When False the buckets are computed bidirectionally.
    """

    def __init__(self,
                 depth,
                 max_relative_position,
                 initializer_range,
                 is_decoder):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position
        self.embeddings_table = Parameter(initializer(TruncatedNormal(initializer_range),
                                                      [self.vocab_size, self.depth]))
        self.reshape = P.Reshape()
        self.one_hot = nn.OneHot(depth=self.vocab_size)
        self.shape = P.Shape()
        self.gather = P.Gather()
        self.matmul = P.BatchMatMul()
        # Bucket count and maximum distance are fixed here rather than read from a config.
        self.relative_attention_num_buckets = 32
        self.relative_attention_max_distance = 128
        self.is_decoder = is_decoder

        # max_exact mirrors the value RelaPosMatrixGenerator derives at run time: the bucket
        # count is halved once, and halved again in the bidirectional (non-decoder) case.
        max_exact = self.relative_attention_num_buckets // 2
        if not self.is_decoder:
            max_exact = max_exact // 2

        self.log_relative_distance = math.log(self.relative_attention_max_distance / max_exact)
        self.relative_position_matrix = RelaPosMatrixGenerator(max_relative_position=max_relative_position,
                                                               log_relative_distance=self.log_relative_distance)

    def construct(self, query_length, key_length):
        """
        The forward function.

        Args:
            query_length (int): Number of query positions.
            key_length (int): Number of key positions.

        Returns:
            Tensor of shape (1, depth, query_length, key_length) holding the position bias.
        """
        # Column vector of query positions and row vector of key positions.
        context_position = mindspore.numpy.arange(query_length, dtype=mstype.int32).expand_dims(-1)
        memory_position = mindspore.numpy.arange(key_length, dtype=mstype.int32).expand_dims(0)
        # Broadcasts to (query_length, key_length).
        relative_position = memory_position - context_position
        relative_position_bucket = self.relative_position_matrix(
            relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets)
        # (query_length, key_length, depth) after the lookup along axis 0 of the table.
        embeddings = self.gather(self.embeddings_table,
                                 relative_position_bucket, 0)
        # Move depth in front and add a leading axis of size one.
        embeddings = embeddings.transpose((2, 0, 1)).expand_dims(0)
        return embeddings


class T5MultiHeadAttention(nn.Cell):
    """
    T5 multi head attention.

    Query, key and value are projected by bias-free ``Linear`` layers to ``num_heads * kv_size``
    features. Scores are the batched product of query and key (``_attn`` applies no scaling
    factor), a position bias is added, positions whose mask value is 0 receive ``-10000``
    before the softmax, and the weighted values are projected back to ``hidden_size``.

    Args:
        batch_size (int): Checked against ``parallel_config.data_parallel`` in (semi) auto
            parallel mode and used to build the ``range`` tensor when ``use_past`` is True.
        src_seq_length (int): Used for the ``use_past`` helper tensors and as the second
            dimension of ``cross_bias``.
        tgt_seq_length (int): Used as the last dimension of ``cross_bias``.
        hidden_size (int): Feature size of the inputs and of the output.
        num_heads (int): Number of attention heads.
        kv_size (int): Feature size of each head. Default: 64.
        hidden_dropout_rate (float): Dropout rate applied to the projected output. Default: 0.1.
        attention_dropout_rate (float): Dropout rate applied to the attention probabilities.
            Default: 0.1.
        compute_dtype: Dtype the inputs are cast to and the ``Linear`` layers compute in.
            Default: mstype.float16.
        softmax_compute_type: Dtype the scores are cast to for the softmax. Default: mstype.float32.
        param_init_type: Parameter dtype of the ``Linear`` layers. Default: mstype.float32.
        use_past (bool): Enables the key/value state reuse branches. Default: False.
        has_relative_bias (bool): When True this cell owns a bias source (a relative position
            bias generator, or ``cross_bias`` for cross attention). Default: False.
        is_decoder (bool): Forwarded to the relative position bias generator. Default: False.
        is_cross_atten (bool): Selects ``cross_bias`` instead of the generator. Default: False.
        parallel_config: Object providing ``data_parallel`` and ``model_parallel``.
            Default: ``default_dpmp_config``.

    Raises:
        ValueError: If a dropout rate is outside [0, 1.0), if ``hidden_size`` is not a multiple
            of ``num_heads``, if ``num_heads`` is not a multiple of
            ``parallel_config.model_parallel``, or, in (semi) auto parallel mode, if
            ``batch_size`` is not a multiple of ``parallel_config.data_parallel``.
    """

    def __init__(self, batch_size,
                 src_seq_length,
                 tgt_seq_length,
                 hidden_size,
                 num_heads,
                 kv_size=64,
                 hidden_dropout_rate=0.1,
                 attention_dropout_rate=0.1,
                 compute_dtype=mstype.float16,
                 softmax_compute_type=mstype.float32,
                 param_init_type=mstype.float32,
                 use_past=False,
                 has_relative_bias=False,
                 is_decoder=False,
                 is_cross_atten=False,
                 parallel_config=default_dpmp_config):
        super(T5MultiHeadAttention, self).__init__()
        self._is_ascend = context.get_context('device_target') in ["Ascend"]
        self.is_parallel_mode = _get_parallel_mode() in (
            ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.src_seq_length = src_seq_length
        self.tgt_seq_length = tgt_seq_length
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.has_relative_bias = has_relative_bias
        # Argument validation.
        if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
            raise ValueError("For 'T5MultiHeadAttention', the class variable 'hidden_dropout_rate' must be "
                             "in range [0, 1.0), but got the value : {}.".format(hidden_dropout_rate))
        if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
            raise ValueError("For 'T5MultiHeadAttention', the class variable 'attention_dropout_rate' must be "
                             "in range [0, 1.0), but got the value : {}.".format(attention_dropout_rate))
        if hidden_size % num_heads != 0:
            raise ValueError("For 'T5MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
                             "of 'num_heads', but got the hidden_size is {} and the num_heads is {}."
                             .format(hidden_size, num_heads))
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError("For 'T5MultiHeadAttention', the class variable 'num_heads' must be a multiple of "
                             "'parallel_config.model_parallel', but got the num_heads is {} "
                             "and the parallel_config.model_parallel  is {}."
                             .format(num_heads, parallel_config.model_parallel))
        if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
            raise ValueError("For 'T5MultiHeadAttention', the class variable 'batch_size' must be a multiple of "
                             "'parallel_config.data_parallel', but got the batch_size is {} "
                             "and the parallel_config.data_parallel is {}."
                             .format(batch_size, parallel_config.data_parallel))
        # Selects between the two use_past branches in construct() and _attn().
        self.is_first_iteration = True
        # Width of the query/key/value projections; independent of hidden_size.
        self.inner_dim = num_heads * kv_size
        # Output layer
        self.projection = Linear(in_channels=self.inner_dim,
                                 out_channels=hidden_size,
                                 transpose_b=False,
                                 has_bias=False,
                                 compute_dtype=compute_dtype,
                                 param_init_type=param_init_type)
        self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
                              strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
                                               (parallel_config.model_parallel, 1)))
        self.transpose = P.Transpose().shard(
            ((parallel_config.data_parallel, 1, parallel_config.model_parallel, 1),))
        self.merger_head_transpose = P.Transpose().shard(
            ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
        self.reshape = P.Reshape()
        self.n_head = num_heads
        # embedding size per head
        self.size_per_head = kv_size
        self.concat_k = P.Concat(axis=3)
        self.concat_v = P.Concat(axis=2)
        # Value that (1 - attention_mask) is multiplied by before being added to the scores.
        self.multiply_data = Tensor([
            -10000.0,
        ], dtype=softmax_compute_type)
        self.batch_matmul = P.BatchMatMul().shard(
            ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),
             (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
        # NOTE: real_div is created here but none of the methods of this class call it.
        self.real_div = P.RealDiv().shard(
            ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1), ()))
        self.sub = P.Sub().shard(
            ((1,), (parallel_config.data_parallel, 1, 1, 1)))
        self.mul = P.Mul().shard(
            ((parallel_config.data_parallel, 1, 1, 1), (1,)))
        # NOTE: self.add is rebound with a different strategy below when use_past is True.
        self.add = P.Add().shard(
            ((parallel_config.data_parallel, 1, 1, 1),
             (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
        self.use_past = use_past
        # nn.Dropout receives the keep probability, i.e. 1 - rate.
        self.dropout = nn.Dropout(1 - hidden_dropout_rate)
        self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
        self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
        self.prob_dropout.dropout.shard(
            ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
        # Two softmax cells: one sharded for 4-D scores, one for scores reshaped to 3-D.
        self.softmax = nn.Softmax().to_float(softmax_compute_type)
        self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
        self.softmax_3d = nn.Softmax().to_float(softmax_compute_type)
        self.softmax_3d.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
        # NOTE: self.expand_dims is rebound with a different strategy below when use_past is True.
        self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),))

        # Query
        self.dense1 = Linear(hidden_size,
                             self.inner_dim,
                             has_bias=False,
                             compute_dtype=compute_dtype,
                             param_init_type=param_init_type)
        self.dense1.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                          strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                         (parallel_config.model_parallel,)))
        # Key
        self.dense2 = Linear(hidden_size,
                             self.inner_dim,
                             has_bias=False,
                             compute_dtype=compute_dtype,
                             param_init_type=param_init_type)
        self.dense2.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                          strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                         (parallel_config.model_parallel,)))

        # Value
        self.dense3 = Linear(hidden_size,
                             self.inner_dim,
                             has_bias=False,
                             compute_dtype=compute_dtype,
                             param_init_type=param_init_type)
        self.dense3.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                          strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                         (parallel_config.model_parallel,)))
        self.dtype = compute_dtype
        self.softmax_dtype = softmax_compute_type

        self.is_decoder = is_decoder
        self.has_relative_bias = has_relative_bias

        self.is_cross_atten = is_cross_atten
        self.cross_bias = None
        # Only a cell created with has_relative_bias=True owns a bias source: a relative
        # position bias generator for self attention, a zero-initialised parameter for
        # cross attention.
        if self.has_relative_bias:
            if not self.is_cross_atten:
                self.bias_generator = RelaPosEmbeddingsGenerator(depth=num_heads,
                                                                 max_relative_position=32,
                                                                 initializer_range=0.02,
                                                                 is_decoder=self.is_decoder)
            else:
                self.cross_bias = Parameter(initializer("zero", [1, self.src_seq_length, self.tgt_seq_length]),
                                            name='cross_attention_bias', parallel_optimizer=False)

        if self.use_past:
            # operators used for state reuse
            seq_range = np.arange(src_seq_length).reshape(1, 1, -1)
            # Position indices 0..src_seq_length-1 repeated per batch item: (batch_size, 1, src_seq_length).
            self.range = Tensor(np.tile(seq_range, (batch_size, 1, 1)), mstype.int32)
            self.seq_length = src_seq_length
            # Lower-triangular mask of shape (seq_length, seq_length).
            self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32)
            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
            self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
            self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1)))
            self.add = P.Add().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
            self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
            self.sub1 = P.Sub().shard(((1,), ()))
            self.tile = P.Tile().shard(((1, 1, 1, 1),))
            self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
            self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))

    def construct(self, query_tensor, key_tensor, value_tensor, attention_mask, bias=None, key_past=None,
                  value_past=None, batch_valid_length=None):
        """
        Forward function for attention.

        Args:
            query_tensor: Query input of dtype float32 or float16.
            key_tensor: Key input of dtype float32 or float16.
            value_tensor: Value input of dtype float32 or float16.
            attention_mask: Mask of dtype float32 or float16, with 3 or 4 dimensions; its first
                dimension is taken as the batch size.
            bias: Position bias added to the scores. When None and ``has_relative_bias`` is
                True, it is produced by this cell. Default: None.
            key_past: Saved key state, used when ``use_past`` is True. Default: None.
            value_past: Saved value state, used when ``use_past`` is True. Default: None.
            batch_valid_length: Valid length of each batch item, used when ``use_past`` is True.
                Default: None.

        Returns:
            Tuple of (output, layer_present, bias): the attention output reshaped to the shape
            of ``query_tensor`` and cast to its dtype, the tuple ``(key_present, value_present)``,
            and the bias that was added to the scores.
        """
        self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
                           value_past, batch_valid_length)
        query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
                                                                                                   key_tensor,
                                                                                                   value_tensor,
                                                                                                   attention_mask)
        ori_dtype = F.dtype(query_tensor)
        query_tensor = F.cast(query_tensor, self.dtype)
        key_tensor = F.cast(key_tensor, self.dtype)
        value_tensor = F.cast(value_tensor, self.dtype)
        # project the inputs with the query, key and value layers
        query = self.dense1(query_tensor)
        key = self.dense2(key_tensor)
        value = self.dense3(value_tensor)
        # the returned shape is [bs, num_heads, seq_length, size_per_head]
        query = self.transpose(
            F.reshape(
                query,
                (batch_size, -1, self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        # the returned shape is [bs, num_heads, size_per_head, seq_length]
        key = self.transpose(
            F.reshape(
                key, (batch_size, -1, self.n_head, self.size_per_head)),
            (0, 2, 3, 1))
        # the returned shape is [bs, num_heads, seq_length, size_per_head]
        value = self.transpose(
            F.reshape(
                value,
                (batch_size, -1, self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        # support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
        if len(F.shape(attention_mask)) == 3:
            # expand attention mask from [bs, seq, seq] -> [bs, 1, seq, seq]
            attention_mask = self.expand_dims(attention_mask, 1)
        # key and value for current token(s)
        key_present = key
        value_present = value
        if self.use_past:
            # The first graph with the input size of (bs, seq_length)
            if self.is_first_iteration:
                # Get the valid input length without padding
                valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(-1, 1, 1)), self.dtype)
                # Cover the key and value numbers corresponding to the padding position
                key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
                value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
            # The second graph with the input size of (bs, 1)
            # the shape of query is (bs, num_heads, 1, size_per_head)
            # the shape of key is   (bs, num_heads, size_per_head, 1)
            # the shape of value is (bs, num_heads, 1, size_per_head)
            else:
                # Get the current token position index
                valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0),
                                                                               (F.shape(key_tensor)[0], 1, 1,
                                                                                self.src_seq_length),
                                                                               (1, 1, 1, 1)),
                                                                    0), mstype.float32), (1, 2, 3))
                valid_length = F.reshape(valid_length, (-1, 1, 1))
                valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype)
                # Pad the key and value to seq_length with only the position index not zero
                current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
                                        self.expand_dims(valid_length_vector, 2))
                current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
                                          self.expand_dims(valid_length_vector, 3))
                # Concat the previous saved state and current state
                key = self.add(key_past, current_key)
                value = self.add(value_past, current_value)
                # Update key_present and value_present for state update
                key_present = key
                value_present = value
                # This value is replaced inside _attn() on the same (use_past, not first iteration) path.
                attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))

        layer_present = (key_present, value_present)
        # multi head attention considering attention mask
        # the return shape is [bs * seq_length, num_heads * size_per_head]
        attention, bias = self._attn(query, key, value, attention_mask, bias)
        # Output
        output = self.projection(attention)
        output = self.dropout(output)
        output = F.reshape(output, ori_shape)
        output = F.cast(output, ori_dtype)
        return output, layer_present, bias

    def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                      value_past=None, batch_valid_length=None):
        r"""
        Check the dtypes of the tensor inputs and the presence of the state reuse inputs.

        The dtype of query, key, value and attention mask must be float32 or float16. For
        ``key_past``, ``value_past`` and ``batch_valid_length`` the checks are delegated to
        ``_check_past_none_input_none`` together with ``self.use_past``.

        Returns:
            True once all checks have been called.
        """
        _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)

        key_is_tensor = isinstance(key_past, Tensor)
        value_is_tensor = isinstance(value_past, Tensor)
        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
        key_is_default = key_past is None
        value_is_default = value_past is None
        batch_is_default = batch_valid_length is None
        _check_past_none_input_none(self.use_past, "key_past", self.cls_name, None, key_is_tensor,
                                    key_is_default)
        _check_past_none_input_none(self.use_past, "value_past", self.cls_name, None, value_is_tensor,
                                    value_is_default)
        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
                                    batch_valid_length_is_tensor, batch_is_default)
        return True

    def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
        """
        Flatten query, key and value to 2-D tensors, keeping their last dimension.

        Returns:
            Tuple of (query, key, value, batch_size, query_shape), where ``batch_size`` is the
            first dimension of ``attention_mask`` and ``query_shape`` is the shape of the query
            before flattening.
        """
        query_shape = F.shape(query_tensor)
        query_tensor = F.reshape(query_tensor, (-1, query_shape[-1]))
        key_shape = F.shape(key_tensor)
        key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
        value_shape = F.shape(value_tensor)
        value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
        return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape

    def _merge_heads(self, x):
        """
        convert a 4d input to a 2d output

        Inputs:
            x: input tensor

        Output:
            x_merge: the 2d output, whose last dimension is the product of the last two
                dimensions of the transposed input
        """
        x = self.merger_head_transpose(
            x, (0, 2, 1, 3))  # bs, seq_length, head, size_per_head
        x_shape = P.Shape()(x)
        new_shape = (-1, x_shape[-2] * x_shape[-1])
        x_merge = self.reshape(x, new_shape)
        return x_merge

    def _softmax(self, attention_scores):
        """
        For the consideration of the performance, do softmax according to different situations

        The scores are passed to the 4-D softmax cell unless the device target is Ascend and the
        softmax dtype is not float16; in that case they are reshaped to 3-D, passed to the 3-D
        softmax cell and reshaped back.

        :param attention_scores: the attention scores before softmax
        :return: the attention scores.
        """

        # Equivalent to: softmax dtype is float16, or the device target is not Ascend.
        if self._is_ascend and self.softmax_dtype == mstype.float16 or not self._is_ascend:
            attention_probs = self.softmax(attention_scores)
        else:
            shape = F.shape(attention_scores)
            # attention probs
            attention_probs = self.softmax_3d(
                F.reshape(attention_scores,
                          (shape[0], -1, shape[-1])))
            attention_probs = F.reshape(attention_probs, shape)
        return attention_probs

    def _attn(self, query, key, value, attention_mask, bias):
        """
        Get the weighted score along the seq_length

        Inputs:
            query: the query matrix
            key: the key matrix
            value: the value matrix
            attention_mask: the attention mask matrix with shape (batch_size,
            1, seq_length, seq_length)
            bias: the position bias added to the scores, or None to let this cell
            produce it when has_relative_bias is True
        Outputs:
            attention_merge: Tensor, the weighted sum scores with the heads merged
            bias: the bias that was added to the scores
        """
        # Attention score [bs, num_heads, seq_length, seq_length]; no scaling factor is applied
        score = self.batch_matmul(query, key)

        ori_dtype = P.DType()(score)
        score = P.Cast()(score, self.softmax_dtype)

        # for input size of (bs, 1) namely the second graph,
        # the shape of attention_mask matrix should be (bs, 1, 1, seq_length)
        if self.use_past and not self.is_first_iteration:
            # Calculate the current total token
            current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
                                                                            (F.shape(query)[0], 1, 1, self.seq_length),
                                                                            (1, 1, 1, 1)),
                                                                 0), mstype.float32), (1, 2, 3))
            # Get the precise position index
            index = self.sub1(F.cast(current_index, mstype.int32), 1)
            index = F.reshape(index, (-1, 1, 1))
            # Calculate the attention_mask matrix via the position index
            attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
            attention_mask = self.expand_dims(attention_mask, 2)
        if bias is None and self.has_relative_bias:
            if not self.is_cross_atten:
                # NOTE(review): the last dimension of score is passed as both the query
                # length and the key length - confirm this is intended.
                bias = self.bias_generator(F.shape(score)[-1], F.shape(score)[-1])
            else:
                bias = P.ExpandDims()(self.cross_bias, 0)

        # NOTE(review): bias is still None here when the caller passes none and
        # has_relative_bias is False - confirm callers always supply it in that case.
        score = self.add(score, bias)
        # Minus 10000 for the position where masked to exclude them from softmax
        multiplu_out = self.sub(
            P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
            P.Cast()(attention_mask, P.DType()(score)))

        adder = self.mul(multiplu_out, self.multiply_data)
        attention_scores = self.add(adder, score)
        # attention probs
        attention_probs = self._softmax(attention_scores)
        attention_probs = P.Cast()(attention_probs, ori_dtype)

        attention_probs = self.prob_dropout(attention_probs)
        # Weighted sum output [bs, num_heads, seq_length, size_per_head]
        weighted_values = self.batch_matmul(attention_probs, value)
        attention_merge = self._merge_heads(weighted_values)
        return attention_merge, bias


class TransformerEncoderLayer(nn.Cell):
    """
    One T5 encoder block.

    The block applies, in order: layer norm, self-attention, a residual add, a second layer norm,
    a feed-forward network (replaced by an MoE layer when ``moe_config.expert_num > 1``) and a
    final residual add.

    Args:
        batch_size (int): Forwarded to the attention cell; also the leading dimension of the
            key/value cache parameters created when `use_past` is True.
        hidden_size (int): Size of the normalized (last) dimension. Must be divisible by
            `parallel_config.model_parallel`.
        ffn_hidden_size (int): Forwarded to the feed-forward / MoE cell. Must be divisible by
            `parallel_config.model_parallel`.
        num_heads (int): Number of attention heads. Must be divisible by
            `parallel_config.model_parallel`.
        seq_length (int): Used as both the source and target sequence length of the self-attention.
        kv_size (int): Forwarded to the attention cell. Default: 64.
        attention_dropout_rate (float): Forwarded to the attention cell. Default: 0.1.
        hidden_dropout_rate (float): Forwarded to the attention cell and to the feed-forward / MoE
            cell. Default: 0.1.
        post_layernorm_residual (bool): If True the residual adds use the layer norm outputs,
            otherwise they use the un-normalized inputs. Default: False.
        layernorm_compute_type: dtype passed to `to_float` of both layer norms.
        softmax_compute_type: Forwarded to the attention cell.
        param_init_type: Forwarded to the attention and feed-forward / MoE cells.
        layer_norm_epsilon (float): `eps` of both layer norms. Default: 1e-6.
        hidden_act (str): Forwarded to the feed-forward / MoE cell. Default: 'gelu'.
        use_past (bool): If True, key/value cache parameters and the operators used to reset and
            update them are created. Default: False.
        moe_config: MoE configuration; the MoE path is taken when its `expert_num` is above 1.
        has_bias (bool): Forwarded to the attention cell as `has_relative_bias`. Default: False.
        parallel_config: Parallel configuration providing `model_parallel` and `data_parallel`
            (and `dpmp` when MoE is used).

    Raises:
        ValueError: If `num_heads`, `hidden_size` or `ffn_hidden_size` is not divisible by
            `parallel_config.model_parallel`.
    """

    def __init__(self,
                 batch_size,
                 hidden_size,
                 ffn_hidden_size,
                 num_heads,
                 seq_length,
                 kv_size=64,
                 attention_dropout_rate=0.1,
                 hidden_dropout_rate=0.1,
                 post_layernorm_residual=False,
                 layernorm_compute_type=mstype.float32,
                 softmax_compute_type=mstype.float32,
                 param_init_type=mstype.float32,
                 layer_norm_epsilon=1e-6,
                 hidden_act='gelu',
                 use_past=False,
                 moe_config=default_moe_config,
                 has_bias=False,
                 parallel_config=default_dpmp_config):
        """Validate the sizes against the model-parallel degree and build the sub-cells."""
        super(TransformerEncoderLayer, self).__init__()
        # Every dimension that is split across model-parallel devices must divide evenly.
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError(
                "For 'TransformerEncoderLayer', the class variable 'num_heads' must be divisibled by the "
                "'parallel_config.model_parallel', but got the num_heads is {} and "
                "parallel_config.model_parallel is {}.".format(num_heads, parallel_config.model_parallel))
        if hidden_size % parallel_config.model_parallel != 0:
            raise ValueError(
                "For 'TransformerEncoderLayer', the class variable 'hidden_size' must be divisibled by "
                "the 'parallel_config.model_parallel', but got the hidden_size is {} and parallel_config."
                " model_parallel is {}.".format(hidden_size, parallel_config.model_parallel))
        if ffn_hidden_size % parallel_config.model_parallel != 0:
            raise ValueError(
                "For 'TransformerEncoderLayer', the class variable 'ffn_hidden_size' must be divisibled "
                "by the 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
                "and parallel_config. model_parallel is {}.".format(ffn_hidden_size,
                                                                    parallel_config.model_parallel))
        _check_moe_config(moe_config, parallel_config)
        # The MoE path replaces the dense feed-forward network when more than one expert is configured.
        self.use_moe = (moe_config.expert_num > 1)
        self.use_past = use_past
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.layernorm1 = LayerNorm((hidden_size,), eps=layer_norm_epsilon).to_float(layernorm_compute_type)
        self.layernorm2 = LayerNorm((hidden_size,), eps=layer_norm_epsilon).to_float(layernorm_compute_type)

        # Self-attention: source and target lengths are both `seq_length`.
        # With MoE the attention receives the `dpmp` sub-config of `parallel_config`.
        self.attention = T5MultiHeadAttention(batch_size=batch_size,
                                              src_seq_length=seq_length,
                                              tgt_seq_length=seq_length,
                                              hidden_size=hidden_size,
                                              num_heads=num_heads,
                                              kv_size=kv_size,
                                              hidden_dropout_rate=hidden_dropout_rate,
                                              attention_dropout_rate=attention_dropout_rate,
                                              softmax_compute_type=softmax_compute_type,
                                              param_init_type=param_init_type,
                                              use_past=use_past,
                                              is_decoder=False,
                                              has_relative_bias=has_bias,
                                              parallel_config=parallel_config.dpmp if self.use_moe
                                              else parallel_config)
        if self.use_moe:
            self.output = MoE(hidden_size=hidden_size,
                              dropout_rate=hidden_dropout_rate,
                              ffn_hidden_size=ffn_hidden_size,
                              param_init_type=param_init_type,
                              hidden_act=hidden_act,
                              moe_config=moe_config,
                              parallel_config=parallel_config)
        else:
            # Feed Forward Network, FFN
            self.output = T5FeedFoward(hidden_size=hidden_size,
                                       dropout_rate=hidden_dropout_rate,
                                       ffn_hidden_size=ffn_hidden_size,
                                       param_init_type=param_init_type,
                                       hidden_act=hidden_act,
                                       parallel_config=parallel_config)
        self.post_layernorm_residual = post_layernorm_residual
        # Residual adds for 2-D and 3-D inputs, both sharded along the first (data-parallel) axis.
        self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
        self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
        # dtype that layer norm outputs are cast to before attention / feed-forward (hard-coded).
        self.dtype = mstype.float16
        self.key_past = None
        self.value_past = None

        if self.use_past:
            # operator used for state reuse
            # NOTE(review): reducesum, not_equal, slice and tile are created here but are not
            # referenced by any method of this class.
            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
            size_per_head = int(hidden_size / num_heads)
            # The key cache keeps the sequence axis last, the value cache keeps it second to last.
            self.key_shape = (batch_size, num_heads, size_per_head, seq_length)
            self.value_shape = (batch_size, num_heads, seq_length, size_per_head)
            # parameters saving key and value states
            self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
            self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
            self.tile = P.Tile().shard(((1, 1),))
            self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
            self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))

        # Layer norms are sharded explicitly only when the mode is not AUTO_PARALLEL.
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
            pass
        elif _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
            self.layernorm1.shard(((parallel_config.data_parallel, 1),))
            self.layernorm2.shard(((parallel_config.data_parallel, 1),))
        else:
            # NOTE(review): the two conditions above are exhaustive, so this branch is unreachable.
            raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
                               f"semi-auto parallel mode now.")

    def construct(self, x, input_mask, bias, init_reset=True, batch_valid_length=None):
        """
        Forward function of the EncoderLayer.

        Args:
            x: Input hidden states (float32 or float16); flattened to 2-D on its last axis
                before processing and restored to its original shape on output.
            input_mask: Attention mask (float32 or float16) forwarded to the attention cell.
            bias: Attention bias forwarded to the attention cell; may be None.
            init_reset: Only meaningful when `use_past` is True; it is cast to the compute dtype
                and multiplied into the cached key/value. Default: True.
            batch_valid_length: Forwarded to the attention cell. Default: None.

        Returns:
            Tuple `(output, layer_present, aux_loss)` when MoE is used, otherwise
            `(output, layer_present, bias)` where `bias` is the one returned by the attention cell.
        """
        self._check_input(x, input_mask, init_reset, batch_valid_length)
        x_shape = F.shape(x)
        # Work on a 2-D view: all leading axes are merged, the last axis is kept.
        x = F.reshape(x, (-1, x_shape[-1]))
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)

        # indicate whether reset saved states
        key_reset = None
        value_reset = None

        if self.use_past:
            # reset states, init_reset True for reuse and False for reset
            key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
            value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
            # add dependency for desired execution order
            input_x = F.depend(input_x, key_reset)
            input_x = F.depend(input_x, value_reset)

        # Self-attention: query, key and value all come from the normalized input.
        attention, layer_present, bias = self.attention(input_x, input_x, input_x, input_mask, bias,
                                                        self.key_past, self.value_past, batch_valid_length)
        # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
        # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer
        else:
            x = self.add(x, attention)

        output_x = self.layernorm2(x)
        output_x = F.cast(output_x, self.dtype)
        aux_loss = None
        # The MoE cell returns an auxiliary loss in addition to its output.
        if self.use_moe:
            mlp_logit, aux_loss = self.output(output_x)
        else:
            mlp_logit = self.output(output_x)

        value_update = None
        key_update = None
        if self.use_past:
            # current key and value
            key_present, value_present = layer_present
            # update key and value calculated this step
            key_update = self.assign(self.key_past, key_present)
            value_update = self.assign(self.value_past, value_present)
            # add dependency for desired execution order
            key_update = F.depend(key_update, key_reset)
            value_update = F.depend(value_update, value_reset)

        # add dependency for desired execution order
        # (when use_past is False both update handles are None)
        mlp_logit = F.depend(mlp_logit, value_update)
        mlp_logit = F.depend(mlp_logit, key_update)

        # if shape is 3d, we reshape the inputs of the add
        if len(x_shape) == 3:
            output_x = P.Reshape()(output_x, x_shape)
            mlp_logit = P.Reshape()(mlp_logit, x_shape)
            x = P.Reshape()(x, x_shape)

            if self.post_layernorm_residual:
                output = self.add_3d(output_x, mlp_logit)
            else:
                output = self.add_3d(x, mlp_logit)
        else:
            # Otherwise add in the flattened 2-D form and restore the input shape afterwards.
            if self.post_layernorm_residual:
                output = self.add(output_x, mlp_logit)
            else:
                output = self.add(x, mlp_logit)
            output = F.reshape(output, x_shape)

        # The third element differs by path: auxiliary loss for MoE, attention bias otherwise.
        if self.use_moe:
            return output, layer_present, aux_loss
        return output, layer_present, bias

    def _check_input(self, x, input_mask, init_reset, batch_valid_length):
        r"""
        Check the inputs of `construct`.

        `x` and `input_mask` are checked against float32/float16. `init_reset` and
        `batch_valid_length` are checked for consistency with `use_past` through
        `_check_past_none_input_none`, and, when `use_past` is True, against bool and int32
        respectively. Always returns True; failures are presumably reported by the helper
        functions raising (their implementation is not visible here).
        """
        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)

        # Flags describing whether each optional input is a Tensor or was left at its default value.
        init_reset_is_tensor = isinstance(init_reset, Tensor)
        init_reset_is_default = init_reset is True
        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
        batch_is_default = batch_valid_length is None
        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
                                    init_reset_is_default)
        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
                                    batch_valid_length_is_tensor, batch_is_default)

        if self.use_past:
            _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
        return True


class TransformerDecoderLayer(nn.Cell):
    """
    One T5 decoder block.

    The block applies, in order: layer norm, self-attention, a residual add, then (only when an
    encoder output is supplied) a layer norm, cross-attention over the encoder output and a
    residual add, and finally a layer norm, a feed-forward network (replaced by an MoE layer when
    ``moe_config.expert_num > 1``) and a last residual add.

    Args:
        hidden_size (int): Size of the normalized (last) dimension. Must be divisible by
            `parallel_config.model_parallel`.
        ffn_hidden_size (int): Forwarded to the feed-forward / MoE cell. Must be divisible by
            `parallel_config.model_parallel`.
        num_heads (int): Number of attention heads. Must be divisible by
            `parallel_config.model_parallel`.
        batch_size (int): Forwarded to both attention cells.
        src_seq_length (int): Sequence length on the encoder side.
        tgt_seq_length (int): Sequence length on the decoder side.
        kv_size (int): Forwarded to both attention cells. Default: 64.
        attention_dropout_rate (float): Forwarded to both attention cells. Default: 0.1.
        hidden_dropout_rate (float): Forwarded to both attention cells and to the feed-forward /
            MoE cell. Default: 0.1.
        post_layernorm_residual (bool): If True the residual adds use the layer norm outputs,
            otherwise they use the un-normalized inputs. Default: False.
        use_past (bool): Must be False; True is rejected with a ValueError. Default: False.
        layernorm_compute_type: dtype passed to `to_float` of the layer norms.
        softmax_compute_type: Forwarded to both attention cells.
        param_init_type: Forwarded to the attention and feed-forward / MoE cells.
        layer_norm_epsilon (float): `eps` of the layer norms. Default: 1e-6.
        hidden_act (str): Forwarded to the feed-forward / MoE cell. Default: 'gelu'.
        has_bias (bool): Forwarded to both attention cells as `has_relative_bias`. Default: False.
        moe_config: MoE configuration; the MoE path is taken when its `expert_num` is above 1.
        parallel_config: Parallel configuration providing `model_parallel` and `data_parallel`
            (and `dpmp` when MoE is used).

    Raises:
        ValueError: If `num_heads`, `hidden_size` or `ffn_hidden_size` is not divisible by
            `parallel_config.model_parallel`, or if `use_past` is True.
    """

    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 num_heads,
                 batch_size,
                 src_seq_length,
                 tgt_seq_length,
                 kv_size=64,
                 attention_dropout_rate=0.1,
                 hidden_dropout_rate=0.1,
                 post_layernorm_residual=False,
                 use_past=False,
                 layernorm_compute_type=mstype.float32,
                 softmax_compute_type=mstype.float32,
                 param_init_type=mstype.float32,
                 layer_norm_epsilon=1e-6,
                 hidden_act='gelu',
                 has_bias=False,
                 moe_config=default_moe_config,
                 parallel_config=default_dpmp_config):
        """Validate the sizes against the model-parallel degree and build the sub-cells."""
        super(TransformerDecoderLayer, self).__init__()
        # Every dimension that is split across model-parallel devices must divide evenly.
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError("For 'TransformerDecoderLayer', the class variable 'num_heads' must be divisibled by "
                             "'parallel_config.model_parallel', but got the num_heads is {} and "
                             "parallel_config.model_parallel is {}.".format(num_heads,
                                                                            parallel_config.model_parallel))
        if hidden_size % parallel_config.model_parallel != 0:
            raise ValueError(
                "For 'TransformerDecoderLayer', the class variable 'hidden_size' must be divisibled by "
                "'parallel_config.model_parallel', but got the hidden_size is {} and "
                "parallel_config.model_parallel is {}.".format(hidden_size, parallel_config.model_parallel))
        if ffn_hidden_size % parallel_config.model_parallel != 0:
            raise ValueError("For 'TransformerDecoderLayer', the class variable 'ffn_hidden_size' must be "
                             "divisibled by 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
                             "and parallel_config.model_parallel is {}."
                             .format(ffn_hidden_size, parallel_config.model_parallel))
        _check_moe_config(moe_config, parallel_config)
        # The MoE path replaces the dense feed-forward network when more than one expert is configured.
        self.use_moe = (moe_config.expert_num > 1)
        # Incremental decoding with a key/value cache is rejected outright for this layer.
        if use_past:
            raise ValueError(f"The {self.cls_name} does not support use_past=True.")
        self.batch_size = batch_size
        self.use_past = use_past
        self.softmax_compute_type = softmax_compute_type

        self.src_seq_length = src_seq_length
        self.tgt_seq_length = tgt_seq_length
        # NOTE(review): duplicate of the assignment a few lines above.
        self.use_past = use_past
        self.hidden_size = hidden_size

        self.layernorm1 = LayerNorm((hidden_size,), eps=layer_norm_epsilon).to_float(layernorm_compute_type)
        self.layernorm1.shard(((parallel_config.data_parallel, 1),))
        self.layernorm2 = LayerNorm((hidden_size,), eps=layer_norm_epsilon).to_float(layernorm_compute_type)
        self.layernorm2.shard(((parallel_config.data_parallel, 1),))
        # Self-attention over the decoder sequence: both lengths are `tgt_seq_length`.
        self.attention = T5MultiHeadAttention(hidden_size=hidden_size,
                                              num_heads=num_heads,
                                              batch_size=batch_size,
                                              kv_size=kv_size,
                                              src_seq_length=tgt_seq_length,
                                              tgt_seq_length=tgt_seq_length,
                                              hidden_dropout_rate=hidden_dropout_rate,
                                              attention_dropout_rate=attention_dropout_rate,
                                              use_past=use_past,
                                              softmax_compute_type=softmax_compute_type,
                                              param_init_type=param_init_type,
                                              is_decoder=True,
                                              has_relative_bias=has_bias,
                                              parallel_config=parallel_config.dpmp if self.use_moe
                                              else parallel_config)

        # Cross attention with the output of encoder as memory tensor
        # NOTE(review): the attention's `src_seq_length` receives the decoder length and its
        # `tgt_seq_length` the encoder length - presumably `src` is the query side there; confirm
        # against T5MultiHeadAttention.
        self.cross_attention = T5MultiHeadAttention(hidden_size=hidden_size,
                                                    num_heads=num_heads,
                                                    batch_size=batch_size,
                                                    kv_size=kv_size,
                                                    src_seq_length=tgt_seq_length,
                                                    tgt_seq_length=src_seq_length,
                                                    hidden_dropout_rate=hidden_dropout_rate,
                                                    attention_dropout_rate=attention_dropout_rate,
                                                    softmax_compute_type=softmax_compute_type,
                                                    use_past=use_past,
                                                    is_decoder=True,
                                                    is_cross_atten=True,
                                                    has_relative_bias=has_bias,
                                                    param_init_type=param_init_type,
                                                    parallel_config=parallel_config.dpmp
                                                    if self.use_moe else parallel_config)
        self.cross_attention_layernorm = LayerNorm((hidden_size,), eps=layer_norm_epsilon).to_float(
            layernorm_compute_type)
        self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))

        if self.use_moe:
            self.output = MoE(hidden_size=hidden_size,
                              dropout_rate=hidden_dropout_rate,
                              ffn_hidden_size=ffn_hidden_size,
                              param_init_type=param_init_type,
                              hidden_act=hidden_act,
                              moe_config=moe_config,
                              parallel_config=parallel_config)
        else:
            # Feed Forward Network, FFN
            self.output = T5FeedFoward(hidden_size=hidden_size,
                                       dropout_rate=hidden_dropout_rate,
                                       ffn_hidden_size=ffn_hidden_size,
                                       hidden_act=hidden_act,
                                       param_init_type=param_init_type,
                                       parallel_config=parallel_config)
        self.post_layernorm_residual = post_layernorm_residual
        # Residual adds for 2-D and 3-D inputs, both sharded along the first (data-parallel) axis.
        self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
        self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
        # dtype that layer norm outputs are cast to before attention / feed-forward (hard-coded).
        self.dtype = mstype.float16
        self.key_past = None
        self.value_past = None
        # NOTE(review): use_past=True raises above, so this branch cannot be entered in this class.
        if self.use_past:
            # operator used for state reuse
            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
            size_per_head = int(hidden_size / num_heads)
            self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
            self.value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head)
            # parameters saving key and value states
            self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
            self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
            self.tile = P.Tile().shard(((1, 1),))
            self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
            self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))

    def construct(self, hidden_stats,
                  decoder_mask,
                  encoder_output=None,
                  memory_mask=None,
                  self_bias=None,
                  encoder_attention_bias=None,
                  init_reset=True, batch_valid_length=None):
        """
        The forward function of the decoder layer.

        Args:
            hidden_stats: Decoder hidden states (float32 or float16); flattened to 2-D on the last
                axis before processing and restored to the original shape on output.
            decoder_mask: Mask (float32 or float16) forwarded to the self-attention.
            encoder_output: Encoder output used as key and value of the cross-attention. When None
                the cross-attention sub-block is skipped. Default: None.
            memory_mask: Mask forwarded to the cross-attention. Default: None.
            self_bias: Bias forwarded to the self-attention. Default: None.
            encoder_attention_bias: Bias forwarded to the cross-attention. Default: None.
            init_reset: Only used by the `use_past` code path. Default: True.
            batch_valid_length: Forwarded to both attention cells. Default: None.

        Returns:
            Tuple `(output, layer_present, aux_loss)` when MoE is used, otherwise
            `(output, layer_present, self_bias, encoder_attention_bias)` with the biases returned
            by the attention cells. Note the two paths return tuples of different length.
        """
        self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
        # the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size]
        hidden_shape = F.shape(hidden_stats)
        hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1]))
        input_x = self.layernorm1(hidden_stats)
        input_x = F.cast(input_x, self.dtype)

        # indicate whether reset saved states
        key_reset = None
        value_reset = None
        if self.use_past:
            # reset states, init_reset True for reuse and False for reset
            key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
            value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
            # add dependency for desired execution order
            input_x = F.depend(input_x, key_reset)
            input_x = F.depend(input_x, value_reset)

        # Self-attention: query, key and value all come from the normalized decoder input.
        attention, layer_present, self_bias = self.attention(input_x, input_x, input_x, decoder_mask, self_bias,
                                                             self.key_past,
                                                             self.value_past, batch_valid_length)
        # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
        # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer
        else:
            x = self.add(hidden_stats, attention)

        middle_output = None
        # Cross-attention runs only when an encoder output is supplied.
        if encoder_output is not None:
            middle_output = self.cross_attention_layernorm(x)
            middle_output = F.cast(middle_output, self.dtype)
            encoder_output = F.cast(encoder_output, self.dtype)
            # Query comes from the decoder stream, key and value from the encoder output.
            cross_attn_out, cross_layer_present, encoder_attention_bias = self.cross_attention(middle_output,
                                                                                               encoder_output,
                                                                                               encoder_output,
                                                                                               memory_mask,
                                                                                               encoder_attention_bias,
                                                                                               self.key_past,
                                                                                               self.value_past,
                                                                                               batch_valid_length)
            # Extend the self-attention present states with the cross-attention ones
            # (presumably both are tuples - confirm against T5MultiHeadAttention).
            layer_present += cross_layer_present
            if self.post_layernorm_residual:
                x = self.add(middle_output, cross_attn_out)
            else:
                x = self.add(x, cross_attn_out)

        output_x = self.layernorm2(x)
        output_x = F.cast(output_x, self.dtype)
        aux_loss = None
        # The MoE cell returns an auxiliary loss in addition to its output.
        if self.use_moe:
            mlp_logit, aux_loss = self.output(output_x)
        else:
            mlp_logit = self.output(output_x)

        value_update = None
        key_update = None
        if self.use_past:
            # current key and value
            key_present, value_present = layer_present
            # update key and value calculated this step
            key_update = self.assign(self.key_past, key_present)
            value_update = self.assign(self.value_past, value_present)
            # add dependency for desired execution order
            key_update = F.depend(key_update, key_reset)
            value_update = F.depend(value_update, value_reset)

        # add dependency for desired execution order
        # (when use_past is False both update handles are None)
        mlp_logit = F.depend(mlp_logit, value_update)
        mlp_logit = F.depend(mlp_logit, key_update)

        # if shape is 3d, we reshape the inputs of the add
        if len(hidden_shape) == 3:
            output_x = P.Reshape()(output_x, hidden_shape)
            mlp_logit = P.Reshape()(mlp_logit, hidden_shape)
            x = P.Reshape()(x, hidden_shape)

            if self.post_layernorm_residual:
                output = self.add_3d(output_x, mlp_logit)
            else:
                output = self.add_3d(x, mlp_logit)
        else:
            # Otherwise add in the flattened 2-D form and restore the input shape afterwards.
            if self.post_layernorm_residual:
                output = self.add(output_x, mlp_logit)
            else:
                output = self.add(x, mlp_logit)
            output = F.reshape(output, hidden_shape)

        # MoE path: 3-tuple ending with the auxiliary loss; dense path: 4-tuple ending with both biases.
        if self.use_moe:
            return output, layer_present, aux_loss
        return output, layer_present, self_bias, encoder_attention_bias

    def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
        r"""
        Check the inputs of `construct`.

        `hidden_states` and `attention_mask` are always checked against float32/float16;
        `encoder_output` and `memory_mask` are checked the same way only when they are not None.
        `init_reset` and `batch_valid_length` are checked for consistency with `use_past` through
        `_check_past_none_input_none`, and, when `use_past` is True, against bool and int32
        respectively. Always returns True; failures are presumably reported by the helper
        functions raising (their implementation is not visible here).
        """
        _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
        if encoder_output is not None:
            _check_input_dtype(F.dtype(encoder_output), "encoder_output",
                               [mstype.float32, mstype.float16], self.cls_name)
        if memory_mask is not None:
            _check_input_dtype(F.dtype(memory_mask), "memory_mask",
                               [mstype.float32, mstype.float16], self.cls_name)

        # Flags describing whether each optional input is a Tensor or was left at its default value.
        init_reset_is_tensor = isinstance(init_reset, Tensor)
        init_reset_is_default = init_reset is True
        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
        batch_is_default = batch_valid_length is None
        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
                                    init_reset_is_default)
        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
                                    batch_valid_length_is_tensor, batch_is_default)

        if self.use_past:
            _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
        return True


def _get_lambda_func(total_layer=None):
    r"""
    A wrapper function of specifying pipeline stage and gradient aggregation fusion. If the total layer
    is not None, for example, set in the transformer model, the pipeline stage setting function will be
    `(layer_id + 0) // (total_layers / parallel_config.pipeline_stage)` for the encoder and,
    `(layer_id + offset) //
    (total_layers / parallel_config.pipeline_stage)` for the decoder, where `offset` is the layers in the encoder.
    """

    def _set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
        r"""
        Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.

        Args:
            network(Cell) - Represents the transformer block
            layer_id(int) - Means the layer index for the current module, counts from zero.
            offset(int) - Means the layer_index needs an offset, if there are other modules in the net.
            layers(int) - The total layers used for the model.
        """
        # override the layers
        if total_layer:
            layers = total_layer
        # Used for the pipeline's stages setting
        if layers < parallel_config.pipeline_stage:
            raise ValueError(f"layers {layers} must be larger than pipeline stage {parallel_config.pipeline_stage}")

        pp_dis = max(int(layers / parallel_config.pipeline_stage), 1)
        # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
        pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
        network.pipeline_stage = pp_id

        # Used for optimizer's fusion tag
        dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
        # Used for enabling recomputation of the block
        if isinstance(parallel_config.recompute, bool):
            if parallel_config.recompute:
                network.recompute()
        else:
            if parallel_config.recompute.recompute:
                paralel_op_comm_compute = parallel_config.recompute.parallel_optimizer_comm_recompute
                network.recompute(parallel_optimizer_comm_recompute=paralel_op_comm_compute,
                                  mp_comm_recompute=parallel_config.recompute.mp_comm_recompute,
                                  recompute_slice_activation=parallel_config.recompute.recompute_slice_activation)

    return _set_parallel_configure_for_layer


class TransformerEncoder(nn.Cell):
    """
    The TransformerEncoder Cell: a stack of `num_layers` `TransformerEncoderLayer` blocks.

    Only the first block is built with `has_bias=True`; in the non-MoE path the attention bias
    returned by each block is handed to the next block.

    Args:
        batch_size (int): Forwarded to every block.
        num_layers (int): Number of encoder blocks.
        hidden_size (int): Forwarded to every block.
        ffn_hidden_size (int): Forwarded to every block.
        seq_length (int): Forwarded to every block.
        num_heads (int): Forwarded to every block.
        kv_size (int): Forwarded to every block. Default: 64.
        attention_dropout_rate (float): Forwarded to every block. Default: 0.1.
        hidden_dropout_rate (float): Forwarded to every block. Default: 0.1.
        hidden_act (str): Forwarded to every block. Default: 'gelu'.
        layer_norm_epsilon (float): Forwarded to every block. Default: 1e-6.
        post_layernorm_residual (bool): Forwarded to every block. Default: False.
        layernorm_compute_type: Forwarded to every block.
        softmax_compute_type: Forwarded to every block.
        param_init_type: Forwarded to every block.
        lambda_func: Callback invoked as
            `lambda_func(block, layer_id=..., layers=..., offset=..., parallel_config=...)` for each
            block. When falsy, the callback returned by `_get_lambda_func()` is used. Default: None.
        offset (int): Passed to `lambda_func` as the layer offset. Default: 0.
        use_past (bool): Forwarded to every block. Default: False.
        moe_config: MoE configuration; the MoE path is taken when its `expert_num` is above 1.
        parallel_config: Transformer parallel configuration; blocks receive its
            `moe_parallel_config` when MoE is used and its `dp_mp_config` otherwise.
    """

    def __init__(self,
                 batch_size,
                 num_layers,
                 hidden_size,
                 ffn_hidden_size,
                 seq_length,
                 num_heads,
                 kv_size=64,
                 attention_dropout_rate=0.1,
                 hidden_dropout_rate=0.1,
                 hidden_act='gelu',
                 layer_norm_epsilon=1e-6,
                 post_layernorm_residual=False,
                 layernorm_compute_type=mstype.float32,
                 softmax_compute_type=mstype.float32,
                 param_init_type=mstype.float32,
                 lambda_func=None,
                 offset=0,
                 use_past=False,
                 moe_config=default_moe_config,
                 parallel_config=default_transformer_config):
        """Build the encoder blocks and apply the per-layer parallel settings to each of them."""
        super(TransformerEncoder, self).__init__()
        _check_moe_config(moe_config, parallel_config)
        self.use_moe = (moe_config.expert_num > 1)
        self.add = P.Add()
        # Starting value for the accumulated MoE auxiliary loss.
        self.aux_loss = Tensor(0.0, mstype.float32)
        self.num_layers = num_layers
        self.blocks = nn.CellList()

        # If the user doesn't pass the fusion function, use the default one
        if not lambda_func:
            lambda_func = _get_lambda_func()

        for i in range(num_layers):
            block = TransformerEncoderLayer(hidden_size=hidden_size,
                                            batch_size=batch_size,
                                            ffn_hidden_size=ffn_hidden_size,
                                            seq_length=seq_length,
                                            attention_dropout_rate=attention_dropout_rate,
                                            hidden_dropout_rate=hidden_dropout_rate,
                                            layernorm_compute_type=layernorm_compute_type,
                                            softmax_compute_type=softmax_compute_type,
                                            kv_size=kv_size,
                                            num_heads=num_heads,
                                            hidden_act=hidden_act,
                                            # only the first block is built with has_bias=True
                                            has_bias=(i == 0),
                                            layer_norm_epsilon=layer_norm_epsilon,
                                            post_layernorm_residual=post_layernorm_residual,
                                            param_init_type=param_init_type,
                                            use_past=use_past,
                                            moe_config=moe_config,
                                            parallel_config=parallel_config.moe_parallel_config if self.use_moe
                                            else parallel_config.dp_mp_config)
            # Pipeline stage, fusion id and recompute settings for this block.
            lambda_func(block, layer_id=i, layers=num_layers,
                        offset=offset, parallel_config=parallel_config)
            self.blocks.append(block)

        # The original three-way branch had an unreachable `else: raise`; only the warning for
        # modes other than AUTO_PARALLEL could ever take effect.
        if _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
            logger.warning("For parallel mode, sharding propagation is recommended, you can use it by setting "
                           "'set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, "
                           "search_mode=\"sharding_propagation\")' and "
                           "'set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)'")

    def construct(self, hidden_states, attention_mask, init_reset=True, batch_valid_length=None):
        """
        The forward process of the encoder.

        Args:
            hidden_states: Input of the first block; each block's output feeds the next block.
            attention_mask: Mask forwarded unchanged to every block.
            init_reset: Forwarded to every block. Default: True.
            batch_valid_length: Forwarded to every block. Default: None.

        Returns:
            `(hidden_states, present_layer, accum_loss)` when MoE is used, otherwise
            `(hidden_states, present_layer)`. `present_layer` is a tuple with one entry per block.
        """
        present_layer = ()
        attention_bias = None
        if self.use_moe:
            accum_loss = self.aux_loss
            for i in range(self.num_layers):
                # The block signature is (x, input_mask, bias, init_reset, batch_valid_length):
                # the bias slot must be filled explicitly, otherwise `init_reset` is taken as the
                # bias and `batch_valid_length` as `init_reset`.
                # NOTE(review): MoE blocks return the auxiliary loss in place of the bias, so the
                # bias is not carried to later blocks here and `attention_bias` stays None.
                hidden_states, present, aux_loss = self.blocks[i](hidden_states,
                                                                  attention_mask,
                                                                  attention_bias,
                                                                  init_reset,
                                                                  batch_valid_length)
                present_layer = present_layer + (present,)
                accum_loss = self.add(accum_loss, aux_loss)
            return hidden_states, present_layer, accum_loss

        for i in range(self.num_layers):
            # The bias returned by one block is reused by the next one.
            hidden_states, present, attention_bias = self.blocks[i](hidden_states,
                                                                    attention_mask,
                                                                    attention_bias,
                                                                    init_reset,
                                                                    batch_valid_length)
            present_layer = present_layer + (present,)

        return hidden_states, present_layer


class TransformerDecoder(nn.Cell):
    """
    Stack of ``num_layers`` ``TransformerDecoderLayer`` blocks applied one after another.

    Args:
        num_layers (int): Number of decoder blocks to create.
        batch_size: Forwarded to every block.
        hidden_size: Forwarded to every block.
        ffn_hidden_size: Forwarded to every block.
        src_seq_length: Forwarded to every block.
        tgt_seq_length: Forwarded to every block.
        num_heads: Forwarded to every block.
        kv_size: Forwarded to every block. Default: 64.
        attention_dropout_rate: Forwarded to every block. Default: 0.1.
        hidden_dropout_rate: Forwarded to every block. Default: 0.1.
        post_layernorm_residual: Forwarded to every block. Default: False.
        layernorm_compute_type: Forwarded to every block. Default: mstype.float32.
        softmax_compute_type: Forwarded to every block. Default: mstype.float32.
        param_init_type: Forwarded to every block. Default: mstype.float32.
        layer_norm_epsilon: Forwarded to every block. Default: 1e-6.
        hidden_act: Forwarded to every block. Default: 'gelu'.
        lambda_func: Callable invoked once per block as
            ``lambda_func(block, layer_id=i, layers=num_layers, offset=offset, parallel_config=parallel_config)``.
            When falsy, the callable returned by ``_get_lambda_func()`` is used. Default: None.
        use_past: Forwarded to every block. Default: False.
        offset: Forwarded to ``lambda_func``. Default: 0.
        moe_config: MoE configuration; the MoE path is taken when ``moe_config.expert_num > 1``.
        parallel_config: Parallel configuration. Blocks receive its ``moe_parallel_config`` when MoE is
            enabled and its ``dp_mp_config`` otherwise.
    """

    def __init__(self,
                 num_layers,
                 batch_size,
                 hidden_size,
                 ffn_hidden_size,
                 src_seq_length,
                 tgt_seq_length,
                 num_heads,
                 kv_size=64,
                 attention_dropout_rate=0.1,
                 hidden_dropout_rate=0.1,
                 post_layernorm_residual=False,
                 layernorm_compute_type=mstype.float32,
                 softmax_compute_type=mstype.float32,
                 param_init_type=mstype.float32,
                 layer_norm_epsilon=1e-6,
                 hidden_act='gelu',
                 lambda_func=None,
                 use_past=False,
                 offset=0,
                 moe_config=default_moe_config,
                 parallel_config=default_transformer_config):
        super(TransformerDecoder, self).__init__()
        self.add = P.Add()
        self.aux_loss = Tensor(0.0, mstype.float32)
        self.num_layers = num_layers
        self.blocks = nn.CellList()
        _check_moe_config(moe_config, parallel_config)
        self.use_moe = (moe_config.expert_num > 1)

        # If the user doesn't pass the fusion function, use the default one.
        # Resolved once up front because the choice does not depend on the layer index.
        if not lambda_func:
            lambda_func = _get_lambda_func()

        for i in range(num_layers):
            # Only the first block is created with has_bias=True.
            # NOTE(review): presumably later blocks reuse the bias returned by the first one through the
            # self_bias / encoder_decoder_bias values threaded in construct - confirm in the layer class.
            block = TransformerDecoderLayer(hidden_size=hidden_size,
                                            batch_size=batch_size,
                                            ffn_hidden_size=ffn_hidden_size,
                                            src_seq_length=src_seq_length,
                                            tgt_seq_length=tgt_seq_length,
                                            attention_dropout_rate=attention_dropout_rate,
                                            hidden_dropout_rate=hidden_dropout_rate,
                                            num_heads=num_heads,
                                            layernorm_compute_type=layernorm_compute_type,
                                            softmax_compute_type=softmax_compute_type,
                                            hidden_act=hidden_act,
                                            kv_size=kv_size,
                                            use_past=use_past,
                                            has_bias=(i == 0),
                                            param_init_type=param_init_type,
                                            post_layernorm_residual=post_layernorm_residual,
                                            layer_norm_epsilon=layer_norm_epsilon,
                                            moe_config=moe_config,
                                            parallel_config=parallel_config.moe_parallel_config if self.use_moe
                                            else parallel_config.dp_mp_config)

            lambda_func(block, layer_id=i, layers=num_layers,
                        offset=offset, parallel_config=parallel_config)

            self.blocks.append(block)

        # The mode is either AUTO_PARALLEL or it is not, so the former trailing `else: raise RuntimeError`
        # could never execute; only the recommendation for the non-auto modes remains.
        if _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
            logger.warning("For parallel mode, sharding propagation is recommended, you can use it by setting "
                           "'set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, "
                           "search_mode=\"sharding_propagation\")' and "
                           "'set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)'")

    def construct(self, hidden_states, attention_mask, encoder_output=None, memory_mask=None,
                  init_reset=True, batch_valid_length=None):
        """
        Run the input through every decoder block in order.

        Args:
            hidden_states: Input of the first block; each block's first output feeds the next block.
            attention_mask: Forwarded unchanged to every block.
            encoder_output: Forwarded unchanged to every block. Default: None.
            memory_mask: Forwarded unchanged to every block. Default: None.
            init_reset: Forwarded unchanged to every block. Default: True.
            batch_valid_length: Forwarded unchanged to every block. Default: None.

        Returns:
            Tuple ``(hidden_states, present_layer)`` where ``present_layer`` collects the second output of
            every block. When ``self.use_moe`` is true a third element is appended: the sum of every
            block's auxiliary loss, accumulated on top of ``self.aux_loss``.
        """
        present_layer = ()
        # Biases handed from block to block: start as None and are replaced by what each block returns.
        self_bias = None
        encoder_decoder_bias = None
        if self.use_moe:
            accum_loss = self.aux_loss
            # NOTE(review): this branch passes six positional arguments (no self_bias /
            # encoder_decoder_bias) while the dense branch below passes eight - confirm against the
            # block's construct signature.
            for i in range(self.num_layers):
                hidden_states, present, aux_loss = self.blocks[i](hidden_states,
                                                                  attention_mask,
                                                                  encoder_output,
                                                                  memory_mask,
                                                                  init_reset,
                                                                  batch_valid_length)
                present_layer = present_layer + (present,)
                accum_loss = self.add(accum_loss, aux_loss)
            return hidden_states, present_layer, accum_loss

        # Loop through each self-attention layer
        for i in range(self.num_layers):
            hidden_states, present, self_bias, encoder_decoder_bias = self.blocks[i](hidden_states,
                                                                                     attention_mask,
                                                                                     encoder_output,
                                                                                     memory_mask,
                                                                                     self_bias,
                                                                                     encoder_decoder_bias,
                                                                                     init_reset,
                                                                                     batch_valid_length)
            present_layer = present_layer + (present,)

        return hidden_states, present_layer


def position_encoding(length,
                      depth,
                      min_timescale=1,
                      max_timescale=1e4):
    """
    Create an array of sinusoids of different frequencies.

    The first ``depth // 2`` columns hold the sine terms and the next ``depth // 2`` columns hold the
    cosine terms. When ``depth`` is odd, one trailing zero column is appended so that the result always
    has ``depth`` columns.

    Args:
        length (int): Number of rows of the array to create, i.e. number of steps.
        depth (int): Hidden size, i.e. number of columns of the result.
        min_timescale (float): Default: 1.
        max_timescale (float): Default: 10000.

    Returns:
        numpy.ndarray of shape (length, depth)
    """
    num_timescales = depth // 2
    positions = np.arange(length, dtype=np.float32)
    # Clamp the denominator to 1: with a single timescale (depth 2 or 3) `num_timescales - 1` is 0 and
    # the division would turn the whole table into NaN.
    log_timescale_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales, dtype=np.float32) * -log_timescale_increment)
    scaled_time = np.expand_dims(positions, 1) * np.expand_dims(inv_timescales, 0)
    x = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    # sin/cos only fill 2 * (depth // 2) columns; pad odd depths so the documented shape holds.
    if depth % 2:
        x = np.pad(x, ((0, 0), (0, 1)), mode='constant')
    return x


class EmbeddingPostprocessor(nn.Cell):
    """
    Scale word embeddings by ``sqrt(embedding_size)`` and apply dropout.

    A sinusoidal position table is built at construction time and kept on the cell as
    ``position_embedding_table``, but ``construct`` does not add it to the embeddings.

    Args:
        embedding_size (int): The size of each embedding vector.
        max_position_embeddings (int): Number of rows of the stored sinusoidal position table.
            Default: 128.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """

    def __init__(self,
                 embedding_size,
                 max_position_embeddings=128,
                 dropout_prob=0.1):
        super().__init__()
        # Primitive operators (only `multiply` is used by construct).
        self.multiply = ops.Mul()
        self.add = ops.Add()
        self.expand_dims = ops.ExpandDims()
        self.shape = ops.Shape()
        self.slice = ops.StridedSlice().shard(((1, 1),))

        # Constant tensors.
        embedding_scale = math.sqrt(float(embedding_size))
        self.scores_mul = Tensor([embedding_scale], dtype=mstype.float32)
        sinusoid_table = position_encoding(max_position_embeddings, embedding_size)
        self.position_embedding_table = Tensor(sinusoid_table, mstype.float32)

        # The first positional argument is `1 - dropout_prob`.
        # NOTE(review): presumably the keep probability of the targeted nn.Dropout API - confirm.
        self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32)

    def construct(self, word_embeddings):
        """Return the word embeddings multiplied by ``scores_mul`` and passed through dropout."""
        scaled_embeddings = self.multiply(word_embeddings, self.scores_mul)
        return self.dropout(scaled_embeddings)


class CastWrapper(nn.Cell):
    """
    Cell that casts its input to a fixed destination dtype.

    Args:
        dst_type: Target dtype handed to ``ops.Cast`` on every call. Default: mstype.float32.
    """

    def __init__(self, dst_type=mstype.float32):
        super().__init__()
        self.dst_type = dst_type
        self.cast = ops.Cast()

    def construct(self, x):
        """Return ``x`` cast to ``self.dst_type``."""
        converted = self.cast(x, self.dst_type)
        return converted


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    The attention mask is the batched outer product of the input mask with itself.

    Args:
        parallel_config: Parallel configuration; its ``data_parallel`` value is used to shard the batch
            dimension of the batched matmul.
    """

    def __init__(self, parallel_config):
        super().__init__()
        batch_strategy = (parallel_config.data_parallel, 1, 1)
        self.batch_matmul = ops.BatchMatMul().shard((batch_strategy, batch_strategy))
        self.shape = ops.Shape()
        self.reshape = ops.Reshape()

    def construct(self, input_mask):
        """Create attention mask according to input mask."""
        mask_shape = self.shape(input_mask)
        float_mask = F.cast(input_mask, mstype.float32)
        # Column view gets a trailing singleton axis, row view gets a singleton axis in the middle.
        column_mask = self.reshape(float_mask, mask_shape + (1,))
        row_mask = self.reshape(float_mask, (mask_shape[0], 1, mask_shape[1]))
        return self.batch_matmul(column_mask, row_mask)


@constexpr
def convert_np_to_tensor_encoder(seq_length):
    """Return a (seq_length, seq_length) float32 Tensor of ones on and below the diagonal, zeros above."""
    lower_triangle = np.tril(np.ones(shape=(seq_length, seq_length)))
    return Tensor(lower_triangle, dtype=mstype.float32)


class T5Head(nn.Cell):
    """
    Head to get the logits of each token in the vocab.

    Args:
        hidden_size (int): Size of the last dimension of ``state``; ``state`` is flattened to
            ``(-1, hidden_size)`` before the projection.
        compute_dtype: dtype both matmul operands are cast to. Default: mstype.float16.
        parallel_config: Parallel configuration providing ``vocab_emb_dp``, ``data_parallel`` and
            ``model_parallel``. When None, ``default_transformer_config`` is used. Default: None.

    Inputs:
        state: the output of the backbone
        embed: the embedding table of the vocabulary

    Returns:
        logits: Tensor, the logits of the corresponding inputs
    """

    def __init__(self,
                 hidden_size,
                 compute_dtype=mstype.float16,
                 parallel_config=None):
        super(T5Head, self).__init__()
        # The signature advertises None as the default, so it must not be dereferenced directly;
        # fall back to the same default configuration the transformer stacks in this file use.
        if parallel_config is None:
            parallel_config = default_transformer_config
        if parallel_config.vocab_emb_dp:
            # Embedding table is not split: only the state rows are sharded, by data_parallel.
            self.matmul = ops.MatMul(transpose_b=True).shard(((parallel_config.data_parallel, 1), (1, 1)))
        else:
            # Embedding table rows are additionally sharded by model_parallel.
            self.matmul = ops.MatMul(transpose_b=True).shard(((parallel_config.data_parallel, 1), (
                parallel_config.model_parallel, 1)))
        self.hidden_size = hidden_size
        self.dtype = compute_dtype
        self.cast = ops.Cast()

    def construct(self, state, embed):
        """Project ``state`` onto the (transposed) embedding table and return the logits."""
        state = ops.Reshape()(state, (-1, self.hidden_size))
        # output logits over vocabulary [bs*seq_length, vocab_size]
        logits = self.matmul(self.cast(state, self.dtype), self.cast(embed, self.dtype))
        return logits


class T5Model(BaseModel):
    """
    T5Model with encoder and decoder.

    The model shares one vocabulary embedding between the encoder input, the decoder input and the
    output projection.

    Args:
        config (Class): Configuration for T5Model.
    """

    def __init__(self,
                 config):
        super(T5Model, self).__init__(config)

        self.batch_size = config.batch_size
        self.hidden_size = config.hidden_size
        self.max_decode_length = config.max_decode_length
        self.scale_output = config.scale_output
        embedding_config = EmbeddingOpParallelConfig(data_parallel=config.parallel_config.data_parallel,
                                                     model_parallel=config.parallel_config.model_parallel)
        self.tfm_embedding_lookup = VocabEmbedding(vocab_size=config.vocab_size,
                                                   embedding_size=config.hidden_size,
                                                   parallel_config=embedding_config)
        self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
            embedding_size=config.hidden_size,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.embedding_dropout_prob)
        self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
            embedding_size=config.hidden_size,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.embedding_dropout_prob)
        # NOTE(review): neither stack below receives config.parallel_config, so both use their own
        # default parallel configuration - confirm this is intended.
        self.tfm_encoder = TransformerEncoder(
            batch_size=self.batch_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            seq_length=config.seq_length,
            ffn_hidden_size=config.d_ff,
            attention_dropout_rate=config.attention_dropout_rate,
            hidden_dropout_rate=config.hidden_dropout_rate,
            layer_norm_epsilon=config.layer_norm_epsilon,
            kv_size=config.kv_size,
            hidden_act=config.hidden_act,
            param_init_type=config.param_init_type,
            layernorm_compute_type=config.layernorm_compute_type,
            softmax_compute_type=config.softmax_compute_type,
            post_layernorm_residual=config.post_layernorm_residual,
            offset=config.offset,
            use_past=config.use_past,
            moe_config=config.moe_config)

        self.tfm_decoder = TransformerDecoder(
            batch_size=self.batch_size,
            hidden_size=config.hidden_size,
            src_seq_length=config.seq_length,
            tgt_seq_length=config.max_decode_length,
            num_heads=config.num_heads,
            ffn_hidden_size=config.d_ff,
            attention_dropout_rate=config.attention_dropout_rate,
            hidden_dropout_rate=config.hidden_dropout_rate,
            layer_norm_epsilon=config.layer_norm_epsilon,
            kv_size=config.kv_size,
            # Decoder depth falls back to the encoder depth when num_decoder_layers is falsy.
            num_layers=config.num_decoder_layers if config.num_decoder_layers else config.num_layers,
            hidden_act=config.hidden_act,
            param_init_type=config.param_init_type,
            layernorm_compute_type=config.layernorm_compute_type,
            softmax_compute_type=config.softmax_compute_type,
            post_layernorm_residual=config.post_layernorm_residual,
            offset=config.offset,
            use_past=config.use_past,
            moe_config=config.moe_config)

        self.projection = T5Head(config.hidden_size,
                                 compute_dtype=mstype.float16,
                                 parallel_config=config.parallel_config)

        self.cast = ops.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = CastWrapper(dst_type=config.compute_dtype)
        self.expand = ops.ExpandDims()
        self.multiply = ops.Mul()
        self.shape = ops.Shape()
        self.encoder_layernorm = LayerNorm(normalized_shape=(config.hidden_size,),
                                           eps=config.layer_norm_epsilon).to_float(mstype.float32)
        self.decoder_layernorm = LayerNorm(normalized_shape=(config.hidden_size,),
                                           eps=config.layer_norm_epsilon).to_float(mstype.float32)
        self.encoder_layernorm.shard(((config.parallel_config.data_parallel, 1),))
        self.decoder_layernorm.shard(((config.parallel_config.data_parallel, 1),))
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config.parallel_config)
        self.ones_like = P.OnesLike()

    def construct(self,
                  source_ids=None,
                  source_mask=None,
                  target_ids=None,
                  target_mask=None,
                  memory_mask=None,
                  encoder_cache=None):
        """
        T5Model with encoder and decoder.

        Args:
            source_ids: Encoder token ids. When None, ``encoder_cache`` is used as the encoder output.
            source_mask: Encoder mask, either a 2-D padding mask or an already expanded attention mask.
                When None and ``source_ids`` is given, an all-ones mask shaped like ``source_ids`` is used.
            target_ids: Decoder token ids. When None, the encoder output is returned directly.
            target_mask: Decoder mask; a 2-D mask is expanded and combined with a lower-triangular mask,
                any other rank is used as the decoder attention mask as is.
            memory_mask: Mask for the decoder's attention over the encoder output. When None it is built
                from ``source_mask`` and ``target_mask`` by ``create_memory_mask``.
            encoder_cache: Precomputed encoder output, used only when ``source_ids`` is None.

        Returns:
            The encoder output when ``target_ids`` is None, otherwise the output of ``self.projection``
            for the decoder states.
        """
        if source_mask is None and source_ids is not None:
            # Keep the default mask two-dimensional: encoder_forward expands 2-D masks itself, and
            # create_memory_mask below indexes/expands source_mask as a 2-D padding mask. Expanding it
            # to an attention mask here would hand create_memory_mask a tensor of the wrong rank.
            source_mask = self.ones_like(source_ids)
        if source_ids is not None:
            encoder_output = self.encoder_forward(source_ids, source_mask)
        else:
            encoder_output = encoder_cache

        if target_ids is None:
            return encoder_output

        # process target sentence
        tgt_embedding_output, embedding_table = self.tfm_embedding_lookup(target_ids)
        # attention mask [batch_size, seq_length, seq_length]
        tgt_length = self.shape(target_ids)[1]

        if memory_mask is None:
            memory_mask = self.create_memory_mask(source_mask, target_mask)

        if len(ops.shape(target_mask)) == 2:
            # Combine the padding mask with a lower-triangular mask so a position only sees earlier ones.
            future_mask = convert_np_to_tensor_encoder(tgt_length)
            tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
            tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(future_mask, 0))
        else:
            tgt_attention_mask = target_mask

        # transformer decoder
        # NOTE(review): two outputs are unpacked here, while the decoder returns three values when MoE
        # is enabled - confirm MoE is not expected on this path.
        decoder_output, _ = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output),
                                             self.cast_compute_type(tgt_attention_mask),
                                             encoder_output, memory_mask)
        decoder_output = self.decoder_layernorm(decoder_output)

        if self.scale_output:
            decoder_output = decoder_output * (self.hidden_size ** -0.5)
        # calculate logits and log_probs
        log_probs = self.projection(decoder_output, embedding_table)

        return log_probs

    def encoder_forward(self, source_ids, source_mask):
        """
        Execute the encoder forward process.

        Args:
            source_ids: Encoder token ids, looked up in the shared vocabulary embedding.
            source_mask: A 2-D mask is expanded to an attention mask; any other rank is used as is.

        Returns:
            The encoder output after the final encoder layer norm.
        """
        # process source sentence
        src_embedding_output, _ = self.tfm_embedding_lookup(source_ids)
        # attention mask [batch_size, seq_length, seq_length]
        if len(F.shape(source_mask)) == 2:
            enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
        else:
            enc_attention_mask = source_mask
        # transformer encoder
        encoder_output, _ = self.tfm_encoder(self.cast_compute_type(src_embedding_output),
                                             self.cast_compute_type(enc_attention_mask))
        encoder_output = self.encoder_layernorm(encoder_output)

        return encoder_output

    def create_memory_mask(self, source_mask, target_mask):
        """
        Build the mask for the decoder's attention over the encoder output.

        Starts from ones of shape (source dim 0, last target dim, last source dim) and multiplies in
        ``source_mask`` expanded at axis 1 and ``target_mask`` expanded at axis 2.
        Both masks are assumed to be 2-D padding masks - TODO confirm with callers.
        """
        memory_mask = P.Ones()((F.shape(source_mask)[0],
                                F.shape(target_mask)[-1], F.shape(source_mask)[-1]), mstype.float32)
        memory_mask = memory_mask * F.expand_dims(source_mask, 1)
        memory_mask = memory_mask * F.expand_dims(target_mask, 2)
        return memory_mask


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class T5ForConditionalGeneration(BaseModel):
    """
    A T5 model with the loss added.

    Args:
        config(T5Config) : The network of the transformer.

    Examples:
        >>> from mindformers import T5ForConditionalGeneration, T5Tokenizer
        >>> model = T5ForConditionalGeneration.from_pretrained('t5_small')
        >>> tokenizer = T5Tokenizer.from_pretrained('t5_small')
        >>> src_output = tokenizer(["hello world"], padding='max_length', max_length=model.config.seq_length,
        ...                        return_tensors='ms')
        >>> model_input = tokenizer(["So happy to see you!"], padding='max_length',
        ...                         max_length=model.config.max_decode_length,
        ...                         return_tensors='ms')["input_ids"]
        >>> input_ids = src_output['input_ids']
        >>> attention_mask = src_output['attention_mask']
        >>> output = model(input_ids, attention_mask, model_input)
        >>> print(output)
        [5.64458]
    """
    _support_list = MindFormerBook.get_model_support_list()['t5']

    def __init__(self, config: T5Config):
        super(T5ForConditionalGeneration, self).__init__(config)
        parallel_config = config.parallel_config
        self.t5_model = T5Model(config=config)
        self.loss = CrossEntropyLoss(parallel_config=OpParallelConfig(data_parallel=parallel_config.data_parallel,
                                                                      model_parallel=parallel_config.model_parallel))
        self.cast = ops.Cast()
        self.shape = ops.Shape()

        # The value of start and end should get from the tokenizer
        self.start_token = Tensor(np.zeros((1, 1)).astype(np.int32))
        self.eod_token = Tensor(np.zeros((1, 1)).astype(np.int32))
        self.concat = P.Concat(axis=1)
        self.tile = P.Tile()
        # disable the bias
        for param in self.trainable_params():
            if ('bias' in param.name or 'beta' in param.name) and 'relative' not in param.name:
                param.requires_grad = False
        self.set_train(True)
        self.load_checkpoint(config)

    def _add_start_to_inputs(self, target_ids):
        """concat the start id to the decoder inputs"""
        start_token = self.tile(self.start_token, (F.shape(target_ids)[0], 1))
        decoder_inputs = self.concat((start_token, target_ids))
        return decoder_inputs

    def _add_eos_to_inputs(self, target_ids):
        """concat the eos id to the end of the decoder inputs"""
        eod_token = self.tile(self.eod_token, (F.shape(target_ids)[0], 1))
        inputs_with_eos = self.concat((target_ids, eod_token))
        return inputs_with_eos

    def encoder_forward(self, source_ids, source_mask):
        """Execute the encoder forward process"""
        return self.t5_model.encoder_forward(source_ids, source_mask)

    def construct(self,
                  input_ids,
                  attention_mask,
                  labels=None,
                  decoder_input_ids=None,
                  decoder_attention_mask=None,
                  memory_mask=None,
                  encoder_outputs=None,
                  return_loss=False):
        """
        t5_model network with loss.

        Args:
            input_ids: Encoder token ids, forwarded to the wrapped T5Model.
            attention_mask: Encoder mask, forwarded to the wrapped T5Model.
            labels: Target token ids. Used to derive the decoder inputs/mask when those are not given,
                and to compute the loss. Default: None.
            decoder_input_ids: Decoder token ids. When None they are the start token followed by
                ``labels`` without its last position. Default: None.
            decoder_attention_mask: Decoder mask. When None it is ``labels != 0`` cast to float32.
                Default: None.
            memory_mask: Forwarded to the wrapped T5Model. Default: None.
            encoder_outputs: Forwarded to the wrapped T5Model as ``encoder_cache``. Default: None.
            return_loss: When True, return the loss even outside training. Default: False.

        Returns:
            The loss when the cell is in training mode or ``return_loss`` is True (None if ``labels``
            was not given), otherwise the logits.
        """
        # NOTE(review): both defaults below read `labels`, so `labels` must be provided whenever
        # decoder_attention_mask or decoder_input_ids is omitted.
        if decoder_attention_mask is None:
            decoder_attention_mask = F.cast(labels != 0, mstype.float32)

        if decoder_input_ids is None:
            decoder_input_ids = self._add_start_to_inputs(labels[:, :-1])

        logits = self.t5_model(input_ids, attention_mask, decoder_input_ids,
                               decoder_attention_mask, memory_mask, encoder_cache=encoder_outputs)

        total_loss = None
        if labels is not None:
            # Flatten labels and mask so each position lines up with one row of the logits.
            label_ids = ops.Reshape()(labels, (-1,))
            label_weights = ops.Reshape()(decoder_attention_mask, (-1,))
            total_loss = self.loss(logits, label_ids, self.cast(label_weights, mstype.float32))

        if self.training:
            return total_loss

        if return_loss:
            return total_loss

        return logits