mindformers.modules.layers.FixedSparseAttention

class mindformers.modules.layers.FixedSparseAttention(batch_size, num_heads, size_per_head, block_size, seq_length=1024, num_different_global_patterns=4, parallel_config=<mindformers.modules.transformer.op_parallel_config.OpParallelConfig object>)

Fixed Sparse Attention Layer.

This layer implements the sparse attention primitives used in Sparse Transformers (see the paper "Generating Long Sequences with Sparse Transformers").

Specifically, it includes the following:

  1. A faster implementation of normal attention (the upper triangle is not computed, and many operations are fused).

  2. An implementation of “strided” and “fixed” attention, as in the Sparse Transformers paper.
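The two sparsity patterns named above can be sketched as boolean masks in NumPy. This is only an illustration of the "strided" and "fixed" patterns as described in the Sparse Transformers paper, not the fused kernel this layer runs. The function names and the `stride` and `c` parameters are assumptions introduced for the example.

```python
import numpy as np


def strided_mask(seq_len: int, stride: int) -> np.ndarray:
    """Causal 'strided' pattern: a local window plus every stride-th earlier position."""
    i = np.arange(seq_len)[:, None]  # query positions (rows)
    j = np.arange(seq_len)[None, :]  # key positions (columns)
    causal = j <= i                  # the upper triangle is never attended
    local = (i - j) < stride         # the most recent `stride` positions
    strided = (i - j) % stride == 0  # positions a multiple of `stride` back
    return causal & (local | strided)


def fixed_mask(seq_len: int, stride: int, c: int) -> np.ndarray:
    """Causal 'fixed' pattern: own block plus the last `c` columns of each block."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    causal = j <= i
    same_block = (i // stride) == (j // stride)  # keys in the query's own block
    summary = (j % stride) >= (stride - c)       # "summary" columns of every block
    return causal & (same_block | summary)


# Small toy sizes so the patterns can be printed and inspected by eye.
fixed = fixed_mask(seq_len=8, stride=4, c=1)
strided = strided_mask(seq_len=8, stride=4)
print(fixed.astype(int))
print(strided.astype(int))
```

Both masks keep far fewer entries than the full lower triangle, which is where the savings over dense causal attention come from.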

Args:

  • batch_size (int) - Batch size of the input.

  • num_heads (int) - Number of attention heads.

  • size_per_head (int) - Embedding size of each attention head. Only 64 and 128 are supported for now.

  • block_size (int) - Block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of each block (block_size x block_size). Only 64 is supported for now.

  • seq_length (int) - Length of the input sequence. Only 1024 is supported for now. Default: 1024.

  • num_different_global_patterns (int) - Number of different global attention layouts. Global attention is determined by which block(s) represent each local window; because there are multiple heads, each head can use a different global representative. Only 4 is supported for now. Default: 4.

  • parallel_config (OpParallelConfig) - The parallel configuration, see OpParallelConfig. Default: default_dpmp_config, an instance of OpParallelConfig with default arguments.
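The constraints listed above can be checked before constructing the layer. The helper below is a hypothetical sketch, not part of mindformers. It encodes only the supported values documented here and derives two sizes: `hidden_size = num_heads * size_per_head` (as in the example at the end of this page) and the number of blocks along the sequence.

```python
def check_fixed_sparse_args(num_heads, size_per_head, block_size,
                            seq_length=1024, num_different_global_patterns=4):
    """Hypothetical helper: validate the documented limits and return derived sizes."""
    if size_per_head not in (64, 128):
        raise ValueError("size_per_head only supports 64 or 128")
    if block_size != 64:
        raise ValueError("block_size only supports 64")
    if seq_length != 1024:
        raise ValueError("seq_length only supports 1024")
    if num_different_global_patterns != 4:
        raise ValueError("num_different_global_patterns only supports 4")
    return {
        # Last dimension of q, k and v.
        "hidden_size": num_heads * size_per_head,
        # Number of block_size-wide blocks along the sequence dimension.
        "num_blocks": seq_length // block_size,
    }


sizes = check_fixed_sparse_args(num_heads=8, size_per_head=64, block_size=64)
print(sizes)
```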

Inputs:
  • q (Tensor) - Query tensor (mstype.fp16, shape [batch_size, seq_length, hidden_size]): the sequence of queries used to query the context.

  • k (Tensor) - Key tensor (mstype.fp16, shape [batch_size, seq_length, hidden_size]): the sequence of keys that the queries are matched against.

  • v (Tensor) - Value tensor (mstype.fp16, shape [batch_size, seq_length, hidden_size]): the sequence of values that are combined according to the attention weights.

  • attention_mask (Tensor) - Float mask tensor (mstype.fp32 or mstype.fp16, shape [batch_size, seq_length, seq_length]): a lower triangular matrix that conveys the masking information.
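A lower triangular mask of the required shape can be built with NumPy as in the sketch below. It assumes the common convention that 1 marks a visible position and 0 a masked one, and `causal_mask` is a name introduced only for this example. The resulting array can then be wrapped in a mindspore Tensor in the same way as the inputs in the example at the end of this page.

```python
import numpy as np


def causal_mask(batch_size: int, seq_length: int, dtype=np.float32) -> np.ndarray:
    """Return a [batch_size, seq_length, seq_length] lower triangular mask."""
    # Ones on and below the diagonal, zeros above it.
    tril = np.tril(np.ones((seq_length, seq_length), dtype=dtype))
    # Repeat the same mask for every sample in the batch.
    return np.broadcast_to(tril, (batch_size, seq_length, seq_length)).copy()


mask = causal_mask(batch_size=2, seq_length=1024)
print(mask.shape, mask.dtype)
```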

Outputs:

A Tensor. The attention output, with shape [batch_size, seq_length, hidden_size].

Supported Platforms:

Ascend

Examples:
>>> import numpy as np
>>> from mindspore import dtype as mstype
>>> from mindformers.modules import FixedSparseAttention
>>> from mindspore import Tensor
>>> model = FixedSparseAttention(batch_size=2,
...                              num_heads=8,
...                              size_per_head=64,
...                              block_size=64)
>>> q = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> k = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> v = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
>>> output = model(q, k, v, attention_mask)
>>> print(output.shape)
(2, 1024, 512)