mindformers.models.glm.glm 源代码

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GLM model."""
import os
import numpy as np

import mindspore as ms
from mindspore import dtype as mstype
from mindspore import nn, ops, Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter

from mindformers.mindformer_book import MindFormerBook
from mindformers.modules.transformer import VocabEmbedding, EmbeddingOpParallelConfig, OpParallelConfig
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.core.loss import CrossEntropyLoss
from mindformers.modules.layers import LayerNorm
from mindformers.tools.utils import is_version_ge
from mindformers.pet.tuners.pet_adapter import PetAdapter
from mindformers.pet.tuners.lora_adapter import LoraAdapter

from .glm_config import GLMConfig
from .layers import DeepNormWithGLULayer
from ..base_model import BaseModel

# Select the JIT level from the MindSpore backend flag:
# MS_ENABLE_GE == '1' selects "O3" (GE); any other value, or unset, selects "O1" (VM).
is_ge = os.getenv('MS_ENABLE_GE')
jit_level = "O3" if is_ge == '1' else "O1"

# Module-level parallel configurations used as the default arguments of GLMModel.__init__.
# NOTE(review): GLMModel.__init__ assigns data_parallel / model_parallel / vocab_emb_dp on the
# embedding config it receives, so the default instance below is shared mutable state across
# models built with the defaults -- confirm this is intended.
default_dpmp_config = OpParallelConfig()
default_embedding_parallel_config = EmbeddingOpParallelConfig()

# Names exported by `from mindformers.models.glm.glm import *`.
__all__ = ['GLMForPreTraining', 'GLMChatModel', 'GLMForPreTrainingWithLora', 'GLMChatModelWithLora']


def topk_fun(logits, topk=5):
    """
    Return the `topk` largest entries of every row together with their column indices.

    Args:
        logits (numpy.ndarray): 2-D array; each row is ranked independently.
        topk (int): Number of entries kept per row. Default: 5.

    Returns:
        Tuple of two numpy arrays `(values, indices)`, both with one row per input row,
        ordered from the largest value to the smallest. Equal values keep their original
        column order because the sort is stable.
    """
    values_per_row, indices_per_row = [], []
    for row_id in range(logits.shape[0]):
        # Rank (column, value) pairs by value, largest first, and keep the leading `topk`.
        ranked = sorted(enumerate(logits[row_id].tolist()), key=lambda pair: pair[1], reverse=True)
        columns, scores = zip(*ranked[:topk])
        values_per_row.append(scores)
        indices_per_row.append(columns)
    return np.array(values_per_row), np.array(indices_per_row)


def batch_select(data, index):
    """
    Row-wise equivalent of `data[:, :index]` where `index` holds one cut-off per row.

    Args:
        data (numpy.ndarray): 2-D array to slice.
        index: Sequence with, for every row `i`, the number of leading columns to keep.

    Returns:
        numpy.ndarray built by concatenating the truncated rows along axis 0.
        `np.concatenate` requires the truncated rows to share one length.
    """
    truncated_rows = [data[row_id, :index[row_id]].reshape(1, -1) for row_id in range(data.shape[0])]
    return np.concatenate(truncated_rows, 0)


def sampler(log_probs_revised, top_p, top_k, use_pynative=False):
    """
    Turn log-probabilities into a truncated, renormalised sampling distribution.

    Args:
        log_probs_revised: Log-probabilities, one row per sample.
        top_p (float): When below 1.0, top-p truncation is applied; otherwise top-k is used.
        top_k (int): Number of candidates kept by the top-k path.
        use_pynative (bool): Run exp / top-k / cumsum with MindSpore operators instead of numpy.
            Default: False.

    Returns:
        Tuple `(p, p_args)` of numpy arrays: the renormalised probabilities of the kept
        candidates and their vocabulary indices.
    """
    # exp(log_prob) recovers the probabilities.
    if use_pynative:
        logits = P.Pow()(np.e, Tensor(log_probs_revised, mstype.float32))
    else:
        logits = np.power(np.e, np.array(log_probs_revised, np.float32))

    if top_p < 1.0:
        # Top-p path: only the 5000 largest entries are considered to reduce computation.
        if use_pynative:
            sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
            cumsum_logits = P.CumSum()(sorted_logits, 1).asnumpy()
            index = index.asnumpy()
            sorted_logits = sorted_logits.asnumpy()
        else:
            sorted_logits, index = topk_fun(logits, 5000)
            cumsum_logits = np.cumsum(sorted_logits, 1)
        # Keep every candidate whose running sum is still below top_p, plus one more.
        top_p_num = np.sum(cumsum_logits < top_p, axis=-1) + 1
        probs = batch_select(sorted_logits, top_p_num)
        p_args = batch_select(index, top_p_num)
        return probs / np.sum(probs, -1, keepdims=True), p_args

    # Top-k path (top_p is 1.0 or larger).
    if use_pynative:
        probs, p_args = P.TopK(sorted=True)(logits, top_k)
        probs = probs.asnumpy()
        p_args = p_args.asnumpy()
    else:
        probs, p_args = topk_fun(logits, top_k)
    # Guard against rounding error: an all-zero row becomes a uniform distribution.
    for row_id in range(probs.shape[0]):
        if np.sum(probs[row_id]) == 0:
            probs[row_id] = np.ones(top_k) / top_k
    return probs / np.sum(probs, -1, keepdims=True), p_args


def precision_correct(p, top_p, top_k, batch_size):
    """
    Repair and renormalise top-k probabilities that were hit by rounding error.

    Args:
        p (numpy.ndarray): Probabilities, one row of `top_k` entries per sample.
        top_p (float): The correction is applied only when this equals 1.
        top_k (int): Number of entries per row.
        batch_size (int): Number of rows to inspect.

    Returns:
        `p` itself when `top_p != 1`; otherwise a row-normalised array. In that case all-zero
        rows of `p` are first overwritten in place with a uniform distribution.
    """
    if top_p != 1:
        return p
    for row_id in range(batch_size):
        if np.sum(p[row_id]) == 0:
            p[row_id] = np.ones(top_k) / top_k
    return p / np.sum(p, -1, keepdims=True)


class ProcessLogits(nn.Cell):
    r"""
    Turn raw logits into a probability distribution over the vocabulary.

    Args:
        use_past (bool): Whether incremental inference is enabled. Default: False.
    """

    def __init__(self, use_past=False):
        super().__init__()
        self.e = ms.Tensor(np.e)
        self.gather = P.Gather()
        self.logsoftmax = P.LogSoftmax()
        self.reshape = P.Reshape()
        self.use_past = use_past

    def construct(self, logits, current_index=None, is_first_iteration=False):
        """
        Flatten the logits to 2-D, optionally pick the rows at `current_index`, and return
        `e ** log_softmax(logits)`.

        Rows are gathered only when `current_index` is given and the call is not an
        incremental step (`use_past` set and `is_first_iteration` false).
        """
        logits = logits.reshape(-1, logits.shape[-1])
        incremental_step = self.use_past and not is_first_iteration
        if not incremental_step and current_index is not None:
            flat_index = current_index.view(-1,)
            logits = self.gather(logits, flat_index, 0)
        log_probs = self.logsoftmax(logits)
        return F.tensor_pow(self.e, log_probs)


class GLMModel(nn.Cell):
    """
    The backbone of GLM network: word embedding, embedding dropout, a stack of
    `DeepNormWithGLULayer` layers and an optional final layer norm.

    Args:
        config (GLMConfig): The config of network.
        op_parallel_config (optional): Operator parallel strategy. Default: `OpParallelConfig()`.
            Replaced by `config.parallel_config` when that value is truthy.
        embed_parallel_config (optional): Operator parallel strategy. Default: `EmbeddingOpParallelConfig()`.
            NOTE(review): this object is modified in place by `__init__`; with the default
            argument that is the shared module-level instance.
    """
    def __init__(self,
                 config,
                 op_parallel_config=default_dpmp_config,
                 embed_parallel_config=default_embedding_parallel_config):
        super(GLMModel, self).__init__()
        # recording parameters
        self.num_layers = config.num_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.seq_length = config.seq_length
        self.use_past = config.use_past
        layernorm = LayerNorm
        # The parallel config carried by `config` takes precedence over the argument.
        if config.parallel_config:
            op_parallel_config = config.parallel_config

        # create embedding parameters
        # nn.Dropout is built with `p` (drop probability) from MindSpore 1.11.0 on and with
        # `keep_prob` on older versions.
        if is_version_ge(ms.__version__, '1.11.0'):
            self.embedding_dropout = nn.Dropout(p=config.embedding_dropout_prob)
        else:
            self.embedding_dropout = nn.Dropout(keep_prob=1 - config.embedding_dropout_prob)

        # Align the embedding parallel strategy with the operator parallel strategy
        # (in-place update of the object that was passed in).
        embed_parallel_config.data_parallel = op_parallel_config.data_parallel
        embed_parallel_config.model_parallel = op_parallel_config.model_parallel
        embed_parallel_config.vocab_emb_dp = False
        self.word_embeddings = VocabEmbedding(vocab_size=config.vocab_size, embedding_size=config.hidden_size,
                                              parallel_config=embed_parallel_config)

        # NOTE(review): `matmul` and `transpose` are not used by `construct` in this class.
        self.matmul = ops.MatMul().shard(((1, 1), (1, embed_parallel_config.model_parallel)))
        self.transpose = ops.Transpose().shard(((embed_parallel_config.model_parallel, 1),))

        def get_layer(layer_id):
            # Build the transformer layer with index `layer_id` from the shared config values.
            return DeepNormWithGLULayer(
                self.num_layers,
                self.hidden_size,
                self.num_heads,
                config.batch_size,
                config.attention_dropout_rate,
                config.hidden_dropout_rate,
                config.layernorm_epsilon,
                layer_id,
                max_seq_len=self.seq_length,
                inner_hidden_size=config.inner_hidden_size,
                hidden_size_per_attention_head=config.hidden_size_per_attention_head,
                layernorm_order=config.layernorm_order,
                layernorm=layernorm,
                use_bias=True,
                activation_func=config.activation_func,
                position_encoding_2d=config.position_encoding_2d,
                params_dtype=config.param_init_type,
                layernorm_dtype=config.layernorm_compute_type,
                softmax_dtype=config.softmax_compute_type,
                compute_dtype=config.compute_dtype,
                use_past=self.use_past,
                parallel_config=op_parallel_config,
            )

        self.layers = nn.CellList(
            [get_layer(layer_id) for layer_id in range(config.num_layers)])

        # Final layer norm before output.
        self.use_final_layernorm = config.use_final_layernorm
        if config.use_final_layernorm:
            self.final_layernorm = layernorm(config.hidden_size, eps=config.layernorm_epsilon)
            self.final_layernorm.shard(((op_parallel_config.data_parallel, 1, 1),))

    def construct(self, input_ids, position_ids, attention_mask, init_reset=True, batch_valid_length=None):
        """
        Run the embedding, every transformer layer and the optional final layer norm.

        Inputs:
            input_ids (Tensor): The tokenized inputs with dtype int32.
            position_ids (Tensor): Used to identify each token's position in the list of tokens.
            attention_mask (Tensor): Used when batching sequences together. When None, a
                (1, 1) int32 tensor of ones is used instead.
            init_reset (bool, optional): Forwarded to every layer. Default: True.
            batch_valid_length (Tensor, optional): Forwarded to every layer. Default: None.

        Returns:
            logits (Tensor): The output of the last layer, after the final layer norm when
                `use_final_layernorm` is set.
            table (Tensor): The embedding table returned by the word embedding.
        """
        if attention_mask is None:
            attention_mask = ops.ones((1, 1), mstype.int32)

        hidden_states, table = self.word_embeddings(input_ids)

        hidden_states = self.embedding_dropout(hidden_states)

        for i in range(self.num_layers):
            layer_ret = self.layers[i](hidden_states, attention_mask, position_ids, init_reset, batch_valid_length)

            # A layer may return a tuple; only its first element is passed on as hidden states.
            if isinstance(layer_ret, tuple):
                layer_ret = layer_ret[0]
            hidden_states = layer_ret

        # Final layer norm.
        if self.use_final_layernorm:
            logits = self.final_layernorm(hidden_states)
        else:
            logits = hidden_states

        return logits, table


class GLMHead(nn.Cell):
    r"""
    Output head of GLM: projects hidden states onto the vocabulary.

    Args:
        hidden_size (int): Size of the incoming hidden states.
        vocab_size (int): Number of rows of the projection weight.
        param_init_type: Stored on the cell. Default: mstype.float32.
        compute_dtype: Dtype of the weight and of the matmul inputs. Default: mstype.float16.
        embed_parallel_config: Object providing `data_parallel` and `model_parallel` for the
            shard strategies.
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 param_init_type=mstype.float32,
                 compute_dtype=mstype.float16,
                 embed_parallel_config=None):
        super().__init__()
        self.param_init_type = param_init_type
        self.compute_dtype = compute_dtype
        weight_shape = [vocab_size, hidden_size]
        self.weight = Parameter(initializer("normal", weight_shape, compute_dtype), name="weight")
        model_parallel = embed_parallel_config.model_parallel
        self.transpose = ops.Transpose().shard(((model_parallel, 1),))
        self.matmul = ops.MatMul(transpose_b=True).shard(
            ((embed_parallel_config.data_parallel, 1), (embed_parallel_config.model_parallel, 1)))

    def construct(self, state, embedding_table=None):
        """
        Multiply the flattened hidden states with the transposed embedding table.

        `embedding_table` defaults to the head's own weight when it is None.
        """
        table = self.weight if embedding_table is None else embedding_table
        flat_state = F.reshape(state, (-1, F.shape(state)[-1]))
        flat_state = ops.cast(flat_state, self.compute_dtype)
        table = self.cast(table, self.compute_dtype)
        return self.matmul(flat_state, table)


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class GLMForPreTraining(BaseModel):
    r"""
    Provide glm training loss or logits through network.

    Args:
        config (GLMConfig): The config of GLMModel.
    """
    _support_list = MindFormerBook.get_model_support_list()['glm']

    def __init__(self, config: GLMConfig):
        super(GLMForPreTraining, self).__init__(config)
        self.config = config
        self.position_encoding_2d = config.position_encoding_2d
        self.transformer = GLMModel(config)
        self.lm_head = GLMHead(
            hidden_size=config.hidden_size,
            vocab_size=config.vocab_size,
            param_init_type=config.param_init_type,
            compute_dtype=config.compute_dtype,
            embed_parallel_config=config.parallel_config)
        self.stridedslice = ops.StridedSlice().shard(((1, 1),))
        self.loss = CrossEntropyLoss(parallel_config=config.parallel_config)
        self.gmask = config.gmask_token_id
        self.bos_token_id = config.bos_token_id
        self.ones = P.Ones()
        self.load_checkpoint(config)

    def get_masks_np(self, input_ids):
        """
        Build the attention mask for `input_ids` with numpy.

        Every position may attend to the whole context (the tokens before the first
        `bos_token_id` of its row) and, causally, to earlier positions.

        Args:
            input_ids (numpy.ndarray): Token ids of shape (batch_size, seq_length).

        Returns:
            numpy.ndarray of bool with shape (batch_size, 1, seq_length, seq_length);
            True marks positions that are masked out.

        Raises:
            ValueError: If a row does not contain `bos_token_id`.
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [list(seq).index(self.config.bos_token_id) for seq in input_ids]
        attention_mask = np.tril(np.ones((batch_size, seq_length, seq_length)))
        for i, context_length in enumerate(context_lengths):
            # The context part is visible from every position.
            attention_mask[i, :, :context_length] = 1
        attention_mask = np.expand_dims(attention_mask, axis=1)
        # Invert: 1 (visible) -> False, 0 (hidden) -> True.
        attention_mask = np.array(attention_mask < 0.5, np.bool_)
        return attention_mask

    def get_position_ids_np(self, input_ids, mask_positions, use_gmasks=None):
        """
        Get position ids from input_ids and mask_positions with numpy.

        Args:
            input_ids (numpy.ndarray): Token ids of shape (batch_size, seq_length).
            mask_positions (list[int]): Index of the mask token in every row.
            use_gmasks (list[bool], optional): Per-row flag; only read by the 1D encoding,
                where rows flagged True keep their plain `arange` positions.
                Default: None, which is treated as all False.

        Returns:
            numpy.ndarray: shape (batch_size, 2, seq_length) when
            `config.position_encoding_2d` is set (position ids stacked with block position
            ids), otherwise shape (batch_size, seq_length).

        Raises:
            ValueError: If a row does not contain `bos_token_id`.
        """
        batch_size, seq_length = input_ids.shape
        if use_gmasks is None:
            use_gmasks = [False] * batch_size
        context_lengths = [list(seq).index(self.config.bos_token_id) for seq in input_ids]
        position_ids = np.repeat(np.expand_dims(np.arange(seq_length), 0), batch_size, axis=0)
        if self.config.position_encoding_2d:
            for i, context_length in enumerate(context_lengths):
                # Every generated token shares the position of the mask token.
                position_ids[i, context_length:] = mask_positions[i]
            # Block positions: 0 inside the context, then 1, 2, ... for the generated part.
            block_position_ids = [np.concatenate((
                np.zeros(context_length, np.int32),
                np.arange(seq_length - context_length, dtype=np.int32) + 1
            )) for context_length in context_lengths]
            block_position_ids = np.stack(block_position_ids, axis=0)
            position_ids = np.stack((position_ids, block_position_ids), axis=1)
        else:
            for i, context_length in enumerate(context_lengths):
                if not use_gmasks[i]:
                    # Fix: index row `i`. The previous `position_ids[context_length:]`
                    # overwrote whole rows of the batch instead of the tail of this row.
                    position_ids[i, context_length:] = mask_positions[i]
        return position_ids

    def create_position_ids_np(self, input_ids):
        """
        Get position ids from input_ids with numpy.

        For every row the gMASK token is preferred over the MASK token when present; its
        index is used as the mask position.

        Args:
            input_ids (numpy.ndarray): Token ids of shape (batch_size, seq_length).

        Returns:
            numpy.ndarray: The result of `get_position_ids_np`.

        Raises:
            ValueError: If a row contains neither mask token, or no `bos_token_id`.
        """
        mask, gmask = self.config.mask_token_id, self.config.gmask_token_id
        seqs = list(input_ids)

        mask_positions, use_gmasks = [], []
        for seq in seqs:
            mask_token = gmask if gmask in seq else mask
            use_gmask = mask_token == gmask
            mask_positions.append(list(seq).index(mask_token))
            use_gmasks.append(use_gmask)
        # Fix: forward the flags computed above; they were previously discarded (`None`).
        position_ids = self.get_position_ids_np(input_ids, mask_positions, use_gmasks=use_gmasks)
        return position_ids

    def _incremental_infer(self, input_ids, current_index, valid_length_each_example,
                           position_ids=None, attention_mask=None):
        """
        Run one forward step of incremental (use_past) inference.

        On the first iteration the full arrays are fed with `init_reset=False`; afterwards
        only the slice at `current_index[0]` is fed with `init_reset=True`.

        Args:
            input_ids (numpy.ndarray): Padded token ids.
            current_index (Tensor): Flat index of the latest valid token of every row.
            valid_length_each_example (numpy.ndarray): Current valid length of every row.
            position_ids (numpy.ndarray): Position ids matching `input_ids`.
            attention_mask (numpy.ndarray): Attention mask matching `input_ids`.

        Returns:
            The output of calling the network.
        """
        # Claim the first graph
        if self.is_first_iteration:
            self.add_flags_recursive(is_first_iteration=True)
            res = self(
                input_ids=Tensor(input_ids, mstype.int32),
                position_ids=Tensor(position_ids, mstype.int32),
                attention_mask=Tensor(attention_mask, mstype.float32),
                input_position=current_index,
                init_reset=Tensor([False], mstype.bool_),
                batch_valid_length=Tensor([valid_length_each_example], mstype.int32)
            )
            # first iter done, go to other iters
            self.is_first_iteration = False
        else:
            self.add_flags_recursive(is_first_iteration=False)
            current_index_tmp = int(current_index[0])
            # use numpy to slice array to avoid compiling the Ascend slice op
            inputs_tmp = input_ids[:, current_index_tmp:current_index_tmp + 1]
            position_ids_tmp = position_ids[..., current_index_tmp:current_index_tmp + 1]
            attention_mask_tmp = attention_mask[:, :, current_index_tmp:current_index_tmp + 1, :]
            res = self(
                input_ids=Tensor(inputs_tmp, mstype.int32),
                position_ids=Tensor(position_ids_tmp, mstype.int32),
                attention_mask=Tensor(attention_mask_tmp, mstype.float32),
                input_position=current_index,
                init_reset=Tensor([True], mstype.bool_),
                batch_valid_length=Tensor([valid_length_each_example], mstype.int32)
            )

        return res

    def _forward(self,
                 origin_inputs,
                 top_k,
                 top_p,
                 repetition_penalty,
                 max_length,
                 eos_token_id,
                 streamer=None,
                 pad_token_id=None):
        """
        Text generation given the origin inputs.

        Inputs:
            origin_inputs (numpy.ndarray): The prompt ids, one row per sample.
            top_k (int): Number of candidates kept by top-k sampling.
            top_p (float): Top-p threshold; values below 1.0 enable top-p sampling.
            repetition_penalty (float): Penalty applied to already generated tokens; 1 disables it.
            max_length (int): The maximum length of prompt plus generated tokens. The effective
                limit is the smaller of this value and `config.seq_length`.
            eos_token_id (int): Generation of a row stops once this token is sampled.
            streamer: Streamer object that will be used to stream the generated sequences.
            pad_token_id (int): Padding token id. Default: None, treated as 0.

        Returns:
            outputs (list[numpy.ndarray]): The int32 ids (prompt plus generated text) of every row.

        Raises:
            ValueError: If a prompt is longer than `max_length`.
        """
        if pad_token_id is None:
            pad_token_id = 0
        # Get configurations for inference
        use_pynative = True

        if streamer is not None:
            streamer.put(origin_inputs[0])

        batch_size = origin_inputs.shape[0]
        is_npu_acceleration = self.config.is_npu_acceleration
        valid_length_each_example = []
        for i in range(batch_size):
            # argwhere returns indices and we need a length, hence the + 1
            valid_length_each_example.append(np.max(np.argwhere(origin_inputs[i] != pad_token_id)) + 1)
        valid_length_each_example = np.array(valid_length_each_example)
        if np.max(valid_length_each_example) > max_length:
            raise ValueError("The max_length set is smaller than the length in the input_ids. You should set "
                             f"max_length to {np.max(valid_length_each_example)}")
        target_length = min(max_length, self.config.seq_length)
        # Frequency of each generated token; created lazily once the vocab size is known
        frequency_list = None
        input_ids = self._pad_inputs_using_max_length(origin_inputs=origin_inputs, pad_token_id=pad_token_id)

        # for GLM `attention_mask` and `position_ids` generation
        attention_mask = self.get_masks_np(input_ids)
        position_ids = self.create_position_ids_np(input_ids)

        input_mask = np.zeros_like(input_ids)
        for i in range(valid_length_each_example.shape[0]):
            input_mask[i, :valid_length_each_example[i]] = 1

        # A single loop generates one token, loop until reaching the target length or generating the eos token
        is_finished = [False] * batch_size

        # setup is_first_iteration flag for incremental infer
        if self.config.use_past:
            self.is_first_iteration = True
        is_first_iteration = False
        while np.sum(is_finished) != batch_size:
            # for GLM generation
            # model basic setting
            self.top_p = top_p
            self.top_k = top_k
            self.repetition_penalty = repetition_penalty
            seq_length = input_ids.shape[1]
            # Flat index (over the flattened batch) of the latest valid token of every row
            current_index = [valid_length_each_example[i] - 1 + i * seq_length for i in range(batch_size)]
            current_index = Tensor(current_index, mstype.int32)
            if self.config.use_past:
                is_first_iteration = self.is_first_iteration
                res = self._incremental_infer(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    current_index=current_index,
                    valid_length_each_example=valid_length_each_example
                )
            else:
                res = self(
                    input_ids=Tensor(input_ids, mstype.int32),
                    position_ids=Tensor(position_ids, mstype.int32),
                    attention_mask=Tensor(attention_mask, mstype.float32)
                )
            if is_npu_acceleration:
                # The network already returned sampled candidates
                p, p_args = res
                p = p.asnumpy()
                p_args = p_args.asnumpy()
                # Avoid rounding error
                p = precision_correct(p, top_p, top_k, batch_size)
            else:
                log_probs = self.process_logits(res, current_index, is_first_iteration, self.config.use_past)
                # Sample
                log_probs = log_probs.asnumpy()
                vocab_size = log_probs.shape[-1]
                if repetition_penalty != 1 and frequency_list is None:
                    frequency_list = np.array([[0 for _ in range(vocab_size)]])
                log_probs_revised = log_probs.reshape(batch_size, vocab_size)
                if repetition_penalty != 1:
                    # Fix: penalise the reshaped array; the reshape used to be discarded here.
                    log_probs_revised = log_probs_revised - frequency_list * repetition_penalty - \
                                        (frequency_list > 0) * repetition_penalty
                p, p_args = sampler(log_probs_revised, top_p, top_k, use_pynative)

            # Random select a token as final output for this round
            for i in range(batch_size):
                if is_finished[i]:
                    continue
                target_index = np.random.choice(len(p[i]), p=p[i])

                # update frequency list
                target = p_args[i][target_index]
                # Fix: the list only exists on the host-side sampling path; with
                # is_npu_acceleration it stays None and indexing it raised TypeError.
                if frequency_list is not None:
                    frequency_list[0][target] = frequency_list[0][target] + 1
                input_ids[i, valid_length_each_example[i]] = p_args[i, target_index]

                if streamer is not None:
                    streamer.put(np.asarray([target]))

                valid_length_each_example[i] += int(1)
                input_mask[i][valid_length_each_example[i] - 1] = 1

                # Stop judgment
                if p_args[i][target_index] == eos_token_id or valid_length_each_example[i] == target_length:
                    is_finished[i] = True
                    continue

        # Return valid outputs out of padded outputs
        output_ids = []
        for i in range(batch_size):
            output_ids.append(input_ids[i, : int(valid_length_each_example[i])].astype(np.int32))
        if streamer is not None:
            streamer.end()
        return output_ids

    # pylint: disable=W0613
    def construct(self, input_ids, label=None, position_ids=None, attention_mask=None,
                  input_position=None, init_reset=True, batch_valid_length=None):
        """
        Extract logits and calculate loss

        Inputs:
            input_ids (Tensor): The tokenized inputs with dtype int32.
            label (Tensor): The indices of input sequence tokens in the vocabulary.
            position_ids (Tensor): Used to identify each token's position in the list of tokens.
            attention_mask (Tensor): Used when batching sequences together.
            input_position (Tensor, optional): Not used by this method. Default: None.
            init_reset (bool, optional): Default: True.
            batch_valid_length(Tensor, optional): Default: None.

        Returns:
            Training phase:
                loss: Training loss.
            Other phase:
                logits (Tensor): The output logit of backbone.
        """
        batch_size, seq_length = input_ids.shape

        if self.phase == "train":
            tokens = self.stridedslice(input_ids, (0, 0), (batch_size, seq_length), (1, 1))
        else:
            tokens = input_ids

        output_states, _ = self.transformer(tokens, position_ids, attention_mask, init_reset, batch_valid_length)
        logits = self.lm_head(output_states)

        if self.phase != 'train':
            return logits

        logits_shape = logits.shape
        label = label.reshape((-1,))
        logits = logits.reshape((-1, logits_shape[-1]))
        # Every token contributes to the loss: the mask is all ones.
        input_mask = self.ones(tokens.shape, logits.dtype)
        input_mask = input_mask.reshape((-1,))
        loss = self.loss(logits, label, input_mask)
        return loss
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class GLMChatModel(GLMForPreTraining):
    r"""
    GLM network with chat (generation) capability.

    Args:
        config (GLMConfig): The config of GLMModel.

    Returns:
        Tensor: the logits, or the sampled probabilities together with their vocabulary
        indices when `config.is_npu_acceleration` is set.
    """
    _support_list = MindFormerBook.get_model_support_list()['glm']

    def __init__(self, config: GLMConfig):
        super().__init__(config)
        self.e = ms.Tensor(np.e, dtype=mstype.float32)
        self.pow = P.Pow()
        self.topk = P.TopK(sorted=True)
        self.cumsum = P.CumSum()
        # `ops.sum` is used from MindSpore 1.11.0 on, `ReduceSum` on older versions.
        new_api = is_version_ge(ms.__version__, '1.11.0')
        self.sum = ops.sum if new_api else P.ReduceSum(keep_dims=False)
        self.vocab_size = config.vocab_size
        self.batch_size = config.batch_size
        self.frequency_list = ms.Tensor([[0] * self.vocab_size])
        self.post_logits = ProcessLogits(use_past=config.use_past)
        # top_p seems not supported yet, so it is pinned to 1 instead of config.top_p.
        self.top_p = 1
        self.top_k = config.top_k
        self.repetition_penalty = config.repetition_penalty
        self.is_first_iteration = False
        self.is_npu_acceleration = config.is_npu_acceleration

    def sample(self, log_probs):
        """
        Convert log-probabilities into sampling candidates inside the graph.

        Returns:
            Tuple `(p, p_args)`: candidate probabilities and their vocabulary indices.
        """
        if self.repetition_penalty != 1:
            count_term = self.frequency_list * self.repetition_penalty
            seen_term = (self.frequency_list > 0) * self.repetition_penalty
            log_probs = log_probs - count_term - seen_term

        # Process sample in graph to accelerate generate
        logits = self.pow(self.e, log_probs)

        # Top-p sampling when top_p is below 1.0 (seems not supported yet).
        if self.top_p < 1.0:
            sorted_logits, index = self.topk(logits, 5000)
            cumsum_logits = self.cumsum(sorted_logits, 1)
            below_threshold = (cumsum_logits < self.top_p).astype(mstype.int32)
            top_p_num = int(self.sum(below_threshold, -1) + 1)
            # Get the corresponding probs and indices
            kept_probs = sorted_logits[:, :top_p_num]
            kept_args = index[:, :top_p_num]
            return kept_probs / self.sum(kept_probs, -1, keepdim=True), kept_args

        # Otherwise use top-k sampling; probabilities are returned without renormalisation.
        topk_probs, topk_args = self.topk(logits, self.top_k)
        return topk_probs, topk_args

    # pylint:disable=arguments-differ
    def construct(self, input_ids, position_ids=None, attention_mask=None,
                  input_position=None, init_reset=True, batch_valid_length=None):
        """Get probs and p_args, or the raw logits when NPU acceleration is disabled."""
        # model forward
        hidden, _ = self.transformer(input_ids, position_ids, attention_mask, init_reset, batch_valid_length)
        logits = self.lm_head(hidden)

        if not self.is_npu_acceleration:
            return logits

        # logit post process, then sort and sample
        log_probs = self.post_logits(logits, input_position, self.is_first_iteration)
        probs, p_args = self.sample(log_probs)
        return probs, p_args
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class GLMForPreTrainingWithLora(GLMForPreTraining):
    """
    GLM pretraining model whose transformer is wrapped with LoRA.

    Args:
        config (GLMConfig): The config of network.
        pet: Tuning config object providing `pet_config` and `pet_type`.
        kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(self, config: GLMConfig = None, pet=None, **kwargs):
        _ = kwargs
        super().__init__(config)
        # Wrap the transformer with the LoRA adapter; the rule targets query_key_value.
        self.pet = pet
        lora_config = self.pet.pet_config
        lora_config.reg_rules = r'.*query_key_value*'
        self.transformer = LoraAdapter.get_pet_model(self.transformer, lora_config)
        # freeze pretrained model
        PetAdapter.freeze_pretrained_model(self, self.pet.pet_type)
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class GLMChatModelWithLora(GLMChatModel):
    """
    GLM chat model whose transformer is wrapped with LoRA.

    Args:
        config (GLMConfig): The config of network.
        pet: Tuning config object providing `pet_config`.
        kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(self, config: GLMConfig = None, pet=None, **kwargs):
        _ = kwargs
        # Hide the checkpoint path while the base model is built, so the checkpoint is
        # loaded only after the LoRA adapter has been applied.
        saved_checkpoint = config.checkpoint_name_or_path
        config.checkpoint_name_or_path = None
        super().__init__(config)
        # Wrap the transformer with the LoRA adapter; the rule targets query_key_value.
        self.pet = pet
        lora_config = self.pet.pet_config
        lora_config.reg_rules = r'.*query_key_value*'
        self.transformer = LoraAdapter.get_pet_model(self.transformer, lora_config)
        config.checkpoint_name_or_path = saved_checkpoint
        self.load_checkpoint(config)