mindformers.modules.transformer.AttentionMask

class mindformers.modules.transformer.AttentionMask(**kwargs)

Get the lower triangular attention mask from the input mask. The input mask is a 2D tensor of shape (batch_size, seq_length) containing 1s and 0s, where 1 indicates that the position holds a valid token and 0 indicates that it does not.

Parameters
  • seq_length (int) – The sequence length of the input tensor.

  • parallel_config (OpParallelConfig) – The parallel configuration. Default: default_dpmp_config, an instance of OpParallelConfig with default args.

Inputs:
  • input_mask (Tensor) - The mask indicating whether each position is a valid input, with shape (batch_size, seq_length).

Outputs:

Tensor. The attention mask matrix with shape (batch_size, seq_length, seq_length).

Raises
  • TypeError – seq_length is not an integer.

  • ValueError – seq_length is not a positive value.

  • TypeError – parallel_config is not a subclass of OpParallelConfig.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> from mindformers.modules.transformer import AttentionMask
>>> from mindspore import Tensor
>>> mask = AttentionMask(seq_length=4)
>>> mask_array = np.array([[1, 1, 1, 0]], np.float32)
>>> inputs = Tensor(mask_array)
>>> res = mask(inputs)
>>> print(res)
[[[1. 0. 0. 0.]
  [1. 1. 0. 0.]
  [1. 1. 1. 0.]
  [0. 0. 0. 0.]]]
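
The relationship between the input mask and the output matrix can be sketched in plain NumPy. This is an illustrative reference under stated assumptions, not the library's implementation: the function name attention_mask_sketch is hypothetical, and the sketch assumes the output is the pairwise product of the validity mask with itself, multiplied by a lower triangular matrix, which is consistent with the description and the example above.

```python
import numpy as np


def attention_mask_sketch(input_mask: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference, not the MindFormers implementation.

    Maps a (batch_size, seq_length) mask of 1s and 0s to a
    (batch_size, seq_length, seq_length) attention mask.
    """
    # Treat every non-zero entry as a valid token.
    valid = (input_mask != 0).astype(np.float32)
    seq_length = valid.shape[1]
    # pairwise[b, i, j] is 1 only when positions i and j are both valid.
    pairwise = valid[:, :, None] * valid[:, None, :]
    # Lower triangular matrix: position i may attend only to positions j <= i.
    lower_triangle = np.tril(np.ones((seq_length, seq_length), dtype=np.float32))
    return pairwise * lower_triangle[None, :, :]


result = attention_mask_sketch(np.array([[1, 1, 1, 0]], np.float32))
print(result)
```

For the input [[1, 1, 1, 0]] the sketch yields the same matrix as the example above: the last row and the last column are all zeros because the fourth position is not a valid token.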