# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bloom model"""
import copy
import os
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindformers.modules.transformer import VocabEmbedding
from mindformers.modules.layers import LayerNorm, AlibiTensor
from mindformers.core.loss import CrossEntropyLoss
from mindformers.models.base_model import BaseModel
from mindformers.tools.logger import logger
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.mindformer_book import MindFormerBook
from .layers import BloomBlocks, CausalMask
from .bloom_config import BloomConfig
from ..utils import convert_mstype, cell_reuse
def jit_inference_with_condition():
    """Build a decorator that JIT-compiles a function only on request.

    The decision is taken once, when the decorator is applied: if the
    ``JIT_INFERENCE`` environment variable is unset (or holds the sentinel
    string ``"NOT_FOUND"``) the function is returned untouched; otherwise it
    is wrapped with ``mindspore.jit`` using ``JitConfig(jit_level="O2")``.

    Returns:
        A decorator taking a callable and returning either the same callable
        or its jit-wrapped version.
    """

    def decorator(func):
        jit_switch = os.getenv("JIT_INFERENCE", "NOT_FOUND")
        if jit_switch != "NOT_FOUND":
            # Imported here so the jit machinery is only pulled in when it is
            # actually requested through the environment.
            from mindspore import jit, JitConfig
            return jit(jit_config=JitConfig(jit_level="O2"))(func)
        return func

    return decorator
class BloomEmbeddingLayer(nn.Cell):
    """The Embedding Layer of Bloom network.

    Looks up the word embeddings of the input ids, applies a LayerNorm to them
    and casts the result to float16.

    Args:
        config (BloomConfig): The config of network. This layer reads
            `parallel_config`, `vocab_size`, `hidden_size` and
            `embedding_init_type` from it.
    """

    def __init__(self, config=None):
        super(BloomEmbeddingLayer, self).__init__(auto_prefix=False)
        # Work on a private copy so the fallback below never mutates the
        # parallel config object owned by the caller.
        parallel_config = copy.deepcopy(config.parallel_config)
        embedding_parallel_config = parallel_config.embedding_dp_mp_config
        embedding_mp = embedding_parallel_config.model_parallel
        vocab_size = config.vocab_size
        if vocab_size % embedding_mp != 0:
            logger.warning("The vocab size of embedding layer is: %s, it is not divide by model_parallel: %s",
                           vocab_size, embedding_mp)
            logger.warning("Now, model_parallel will be changed: mp = 1")
            embedding_parallel_config.model_parallel = 1
        # Hand the (possibly adjusted) copy to the embedding. Passing the
        # untouched `config.parallel_config` here would silently discard the
        # mp = 1 fallback announced in the warning above.
        self.word_embedding = VocabEmbedding(vocab_size=vocab_size,
                                             embedding_size=config.hidden_size,
                                             param_init=initializer("normal",
                                                                    [vocab_size, config.hidden_size],
                                                                    dtype=config.embedding_init_type),
                                             parallel_config=embedding_parallel_config)
        self.norm = LayerNorm((config.hidden_size,)).shard(((1, 1, 1), (1,), (1,)))

    def construct(self, input_ids):
        """The forward compute of Embedding Layer.

        Args:
            input_ids (Tensor): Token ids to look up.

        Returns:
            Tuple of (embedding, word_table): the layer-normalised embedding
            cast to float16, and the embedding table returned by the
            word-embedding lookup.
        """
        word_embedding, word_table = self.word_embedding(input_ids)
        embedding = self.norm(word_embedding)
        embedding = embedding.astype(mstype.float16)
        return embedding, word_table
def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers,
                                     use_select_recompute=False):
    """
    Configure pipeline stage, communication fusion and recompute for one layer.

    The pipeline stage is `(layer_id + offset) // max(int((layers + 1) / pipeline_stage), 1)`,
    clamped to the last stage index.

    Args:
        network(Cell) - Represents the transformer block
        layer_id(int) - Means the layer index for the current module, counts from zero.
        offset(int) - Means the layer_index needs a offset, if there are other modules in the net.
        parallel_config(dict) - Parallel Config
        layers(int) - The total layers used for the model.
        use_select_recompute(bool) - Indicates whether to use the select recompute mode; when True,
            `network.recompute` is never called here.
    """
    stage_count = parallel_config.pipeline_stage
    global_index = layer_id + offset

    # Number of layers placed on each pipeline stage (never less than one).
    layers_per_stage = max(int((layers + 1) / stage_count), 1)
    network.pipeline_stage = min(global_index // layers_per_stage, stage_count - 1)

    # Used for optimizer's fusion tag
    layers_per_group = max(int((layers + 1) / parallel_config.gradient_aggregation_group), 1)
    fusion_id = 2 if stage_count > 1 else int(global_index / layers_per_group) + 1
    network.set_comm_fusion(fusion_id)

    # `recompute` is either a plain bool or a config object with its own flags.
    recompute_cfg = parallel_config.recompute
    if isinstance(recompute_cfg, bool):
        if recompute_cfg and not use_select_recompute:
            network.recompute()
    elif recompute_cfg.recompute and not use_select_recompute:
        network.recompute(recompute_slice_activation=recompute_cfg.recompute_slice_activation)
class BloomModel(nn.Cell):
    """
    The backbone of Bloom network

    Args:
        config(BloomConfig): The config of network

    Inputs:
        input_ids(Tensor): The tokenized inputs with datatype int32
        input_mask(Tensor): The mask indicating whether each position is a valid input
        init_reset(Tensor): Only read when `config.use_past` is enabled, where it is cast to the
            parameter dtype and multiplied with the first block's `key_past`. Default True.
            NOTE(review): the default `True` has no `.astype`, so callers presumably pass a
            Tensor whenever `use_past` is on - confirm.
        batch_valid_length(Tensor): Forwarded unchanged to every transformer block. Default None.

    Returns:
        output_state(Tensor): The hidden states after the final LayerNorm
        embedding_table(Tensor): The embedding table for the vocabulary
    """

    def __init__(self, config):
        super(BloomModel, self).__init__()
        # The embedding is pinned to the first pipeline stage.
        self.embedding = BloomEmbeddingLayer(config)
        self.embedding.pipeline_stage = 0

        self.make_causal_attention = CausalMask(seq_length=config.seq_length,
                                                parallel_config=config.parallel_config.dp_mp_config)
        self.build_alibi_tensor = AlibiTensor(seq_length=config.seq_length,
                                              num_heads=config.num_heads,
                                              parallel_config=config.parallel_config)

        # Only the `.blocks` container of BloomBlocks is kept, so that construct
        # can index the individual layers directly.
        self.blocks = BloomBlocks(hidden_size=config.hidden_size,
                                  batch_size=config.batch_size,
                                  ffn_hidden_size=config.hidden_size * config.expand_ratio,
                                  seq_length=config.seq_length,
                                  num_layers=config.num_layers,
                                  num_heads=config.num_heads,
                                  attention_dropout_rate=config.attention_dropout_rate,
                                  hidden_dropout_rate=config.hidden_dropout_rate,
                                  hidden_act=config.hidden_act,
                                  lambda_func=set_parallel_configure_for_layer,
                                  param_init_type=config.param_init_type,
                                  layernorm_compute_type=config.layernorm_compute_type,
                                  softmax_compute_type=config.softmax_compute_type,
                                  use_past=config.use_past,
                                  use_seq_parallel=config.use_seq_parallel,
                                  use_select_recompute=config.use_select_recompute,
                                  parallel_config=config.parallel_config).blocks
        self.num_layers = config.num_layers

        # Final LayerNorm, placed on the last pipeline stage.
        self.ln_f = LayerNorm((config.hidden_size,)).to_float(config.layernorm_compute_type)
        if config.parallel_config.pipeline_stage > 1:
            self.ln_f.set_comm_fusion(2)
        else:
            self.ln_f.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
        self.ln_f.shard(((config.parallel_config.data_parallel, 1, 1),))
        self.ln_f.pipeline_stage = config.parallel_config.pipeline_stage - 1

        self.use_past = config.use_past
        self.dtype = convert_mstype(config.param_init_type)
        self.mul_init_reset = P.Mul().shard(
            ((config.parallel_config.data_parallel, config.parallel_config.model_parallel, 1, 1), (1,)))

    def construct(self, input_ids, input_mask, init_reset=True, batch_valid_length=None):
        """Run embedding, all transformer blocks and the final LayerNorm.

        Returns:
            Tuple of (output_state, embedding_table).
        """
        input_embedding, embedding_table = self.embedding(input_ids)
        hidden_states = input_embedding
        # The blocks are fed 2-D states; the original shape is restored after the loop.
        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.reshape((-1, hidden_states_shape[-1]))

        causal_mask = self.make_causal_attention(input_mask)
        alibi_tensor = self.build_alibi_tensor(input_mask, hidden_states.dtype)

        if self.use_past:
            init_reset = self.mul_init_reset(self.blocks[0].key_past, init_reset.astype(self.dtype))

        for i in range(self.num_layers):
            hidden_states, _ = self.blocks[i](hidden_states, alibi_tensor, causal_mask, init_reset, batch_valid_length)

        hidden_states = hidden_states.reshape(hidden_states_shape)
        output_state = self.ln_f(hidden_states)
        return output_state, embedding_table
class BloomHead(nn.Cell):
    """Project hidden states onto the vocabulary through the embedding table.

    Args:
        hidden_size (int): Size of the last dimension of the incoming states.
        vocab_size (int): Vocabulary size, used to check divisibility by model_parallel.
        compute_type (str): Dtype name used for the matmul. Default "float16".
        parallel_config: Parallel config; a deep copy is adjusted locally.
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 compute_type="float16",
                 parallel_config=None):
        super(BloomHead, self).__init__()
        # All adjustments happen on a private copy of the caller's config.
        local_config = copy.deepcopy(parallel_config)
        requested_mp = local_config.model_parallel
        if vocab_size % requested_mp != 0:
            logger.warning("The vocab size of BloomHead MatMul is: %s, it is not divide by model_parallel: %s",
                           vocab_size, requested_mp)
            logger.warning("Now, the model_parallel num of BloomHead MatMul will be changed: mp = 1")
            local_config.model_parallel = 1

        if local_config.pipeline_stage > 1:
            local_config.vocab_emb_dp = False

        # With vocab_emb_dp the table is not split; otherwise its first axis
        # is split across model_parallel.
        table_split = 1 if local_config.vocab_emb_dp else local_config.model_parallel
        self.matmul = P.MatMul(transpose_b=True).shard(
            ((local_config.data_parallel, 1), (table_split, 1)))
        self.hidden_size = hidden_size
        self.dtype = convert_mstype(compute_type)

    def construct(self, state, embedding_table):
        """Return the logits, cast back to the dtype of the incoming `state`."""
        input_dtype = state.dtype
        flat_state = state.reshape((-1, self.hidden_size)).astype(self.dtype)
        table = embedding_table.astype(self.dtype)
        return self.matmul(flat_state, table).astype(input_dtype)
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class BloomLMHeadModel(BaseModel):
    """
    Provide bloom training loss or logits through network.

    Args:
        config (BloomConfig): The config of BloomModel.

    Returns:
        Tensor, the loss or logits of the network.
    """
    _support_list = MindFormerBook.get_model_support_list()['bloom']

    @cell_reuse()
    def __init__(self, config=None):
        config = config if config is not None else BloomConfig()
        super(BloomLMHeadModel, self).__init__(config, auto_prefix=False)
        self.use_past = self.config.use_past
        self.is_sample_acceleration = self.config.is_sample_acceleration

        # Constant tensors that are only read on the corresponding code paths.
        if self.use_past:
            self.input_mask_all_ones = Tensor(
                np.ones((self.config.batch_size, self.config.seq_length), np.float32), mstype.float32)
        if self.is_sample_acceleration:
            self.p_all_ones = Tensor(np.ones((self.config.batch_size, 1), np.float32), mstype.float32)

        self.eos_token_id = self.config.eos_token_id
        parallel_config = self.config.parallel_config
        self.stridedslice = P.StridedSlice().shard(((parallel_config.data_parallel, 1),))
        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
        self.gt = P.Greater().shard(((parallel_config.data_parallel, 1), ()))
        self.mul = P.Mul().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
        self.abs = P.Abs().shard(((parallel_config.data_parallel, 1),))

        self.transformer = BloomModel(self.config)
        self.head = BloomHead(hidden_size=config.hidden_size,
                              vocab_size=config.vocab_size,
                              parallel_config=self.config.parallel_config)
        if parallel_config.pipeline_stage > 1:
            # The head runs on the last stage and reuses the embedding table,
            # so the table is registered on that stage as well.
            self.head.pipeline_stage = parallel_config.pipeline_stage - 1
            self.transformer.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)

        mp = config.parallel_config.model_parallel
        vocab_size = config.vocab_size
        # The loss gets its own copy so the mp = 1 fallback stays local to it.
        loss_parallel_config = copy.deepcopy(parallel_config)
        if vocab_size % mp != 0:
            logger.warning("The vocab size of Bloom Loss is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of Bloom Loss will be changed: mp = 1")
            loss_parallel_config.model_parallel = 1
        self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config)
        self.load_checkpoint(config)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Wrap `input_ids` into an int32 Tensor under the key "input_ids"."""
        return {
            "input_ids": Tensor(input_ids, mstype.int32)
        }

    # pylint: disable=W0613
    @jit_inference_with_condition()
    def construct(self, input_ids, input_position=None, position_ids=None, attention_mask=None,
                  input_embeds=None, labels=None, init_reset=True, batch_valid_length=None):
        """
        construct function for Language Modeling

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_position(Tensor): current position, used by model.predict. Default None.
            position_ids(Tensor): Reserved param, not used.
            attention_mask(Tensor): Reserved param, not used.
            input_embeds(Tensor): Reserved param, not used.
            labels(Tensor): Reserved param, not used; training labels are sliced from `input_ids`.
            init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and
                past value parameter used in the incremental prediction. Default True.
            batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
                prediction. Tensor of shape :math:`(batch_size,)`. Default None.

        Returns:
            In training mode, the loss computed by the cross-entropy loss cell.
            Otherwise, the result of `get_top_token_id` when `is_sample_acceleration` is set,
            else the tuple `(logits, tokens, input_mask)`.
        """
        batch_size, seq_length = input_ids.shape

        # In training the last position is dropped from the inputs; the labels
        # further down are the same ids shifted by one position.
        if self.training:
            tokens = self.stridedslice(input_ids, (0, 0), (batch_size, seq_length - 1), (1, 1))
        else:
            tokens = input_ids

        if self.use_past:
            input_mask = self.input_mask_all_ones
        else:
            input_mask = self.not_equal(tokens, self.eos_token_id).astype(mstype.float32)
        loss_mask = self.mul(input_mask, self.gt(tokens, 0).astype(mstype.float32))
        tokens = self.abs(tokens)

        output_states, embedding_table = self.transformer(tokens, input_mask, init_reset, batch_valid_length)
        # The head flattens the states to 2-D before projecting them onto the embedding table.
        logits = self.head(output_states, embedding_table)

        if not self.training:
            if self.is_sample_acceleration:
                return self.get_top_token_id(logits, current_index=input_position)
            return logits, tokens, input_mask

        labels = self.stridedslice(input_ids, (0, 1), (batch_size, seq_length), (1, 1))
        labels = labels.reshape((-1,))
        loss_mask = loss_mask.reshape((-1,))
        loss = self.loss(logits, labels, loss_mask)
        return loss

    def get_top_token_id(self, logits, current_index=None):
        """Return the argmax token id for each selected row of `logits`.

        Args:
            logits (Tensor): Logits; reshaped to 2-D with the last dimension kept.
            current_index (Tensor): Row indices to gather from the flattened logits.
                Ignored when `use_past` is on and this is not the first iteration.
                Default None.

        Returns:
            Tuple of (`self.p_all_ones`, top_token_id), where top_token_id has
            shape (-1, 1).
        """
        logits = logits.reshape(-1, logits.shape[-1])
        # During incremental steps (use_past and not the first iteration) the
        # logits are used as they are; otherwise the rows named by
        # `current_index` are gathered.
        incremental_step = self.use_past and not self.is_first_iteration
        if not incremental_step and current_index is not None:
            index = current_index.view(-1,)
            logits = P.Gather()(logits, index, 0)
        probabilities = P.Softmax(-1)(logits)
        top_token_id = P.Argmax(-1)(probabilities)
        top_token_id = top_token_id.view(-1, 1)
        return self.p_all_ones, top_token_id