mindformers.models.llama.LlamaForCausalLM
class mindformers.models.llama.LlamaForCausalLM(config: mindformers.models.llama.llama_config.LlamaConfig = None)
Provides the LLaMA training loss or logits through the network.
- Parameters:
config (LlamaConfig): The configuration of the LLaMA model.
- Inputs:
input_ids (Tensor): the tokenized inputs with datatype int32, Tensor of shape \((batch, seq\_length)\).
label_ids (Tensor): the tokenized labels with datatype int32, Tensor of shape \((batch, seq\_length)\).
input_position (Tensor): current position, used by model.predict (bool, optional): Default: True.
attention_mask (Tensor): Reserved param, not used.
batch_valid_length (Tensor): Reserved param, not used.
- Returns:
Tensor, the loss or logits of the network.
Examples
>>> from mindformers.models.llama import LlamaConfig, LlamaForCausalLM
>>> config = LlamaConfig(batch_size=2)
>>> network = LlamaForCausalLM(config=config)
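To make the input shapes and the "loss or logits" return value more concrete, here is a minimal pure-Python sketch that does not call mindformers at all. It builds token ids laid out as (batch, seq_length) and computes a next-token cross-entropy from stand-in logits. The sizes, the random logits, the helper name `next_token_loss`, and the shift-by-one label convention are all assumptions made for illustration; they are not taken from the LlamaForCausalLM source.

```python
import math
import random

# Illustrative sizes only; these are not library defaults.
batch, seq_length, vocab_size = 2, 8, 32
random.seed(0)

# Token ids laid out as (batch, seq_length). The real network expects an
# int32 Tensor; nested lists stand in for it here.
input_ids = [[random.randrange(vocab_size) for _ in range(seq_length)]
             for _ in range(batch)]
label_ids = [row[:] for row in input_ids]

# Stand-in logits of shape (batch, seq_length, vocab_size).
logits = [[[random.gauss(0.0, 1.0) for _ in range(vocab_size)]
           for _ in range(seq_length)]
          for _ in range(batch)]


def next_token_loss(logits, labels):
    """Mean cross-entropy where position t predicts token t + 1.

    The shift-by-one convention is an assumption for this sketch.
    """
    total, count = 0.0, 0
    for seq_logits, seq_labels in zip(logits, labels):
        for t in range(len(seq_labels) - 1):
            scores = seq_logits[t]
            # Numerically stable log-sum-exp over the vocabulary.
            peak = max(scores)
            log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
            total += log_norm - scores[seq_labels[t + 1]]
            count += 1
    return total / count


loss = next_token_loss(logits, label_ids)
print(f"loss over {batch} sequences: {loss:.4f}")
```

With uniform logits, every token has probability 1 / vocab_size, so this helper returns log(vocab_size), which is a handy sanity check for a freshly initialized model.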