mindformers.modules.layers.FixedSparseAttention
class mindformers.modules.layers.FixedSparseAttention(batch_size, num_heads, size_per_head, block_size, seq_length=1024, num_different_global_patterns=4, parallel_config=default_dpmp_config)
Fixed Sparse Attention Layer.
This layer implements the sparse attention primitives used in Sparse Transformers (see the paper Generating Long Sequences with Sparse Transformers).
Specifically, it includes the following:
1. A faster implementation of normal attention (the upper triangle is not computed, and many operations are fused).
2. An implementation of “strided” and “fixed” attention, as in the Sparse Transformers paper.
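To make the “fixed” pattern concrete, here is a simplified NumPy sketch at block granularity. It follows the idea in the paper: each query block attends to the blocks of its own local window plus one “summary” block in every earlier window, and different heads may pick different summary positions (which is what num_different_global_patterns controls). The helper name fixed_block_layout, the window size local_blocks, and the rule for choosing the summary position (counting back from the end of each window) are all assumptions for illustration; this is not the layout MindFormers actually builds.

```python
import numpy as np


def fixed_block_layout(num_heads, num_blocks, local_blocks, num_patterns):
    """Block-level "fixed" sparsity layout (hypothetical helper, paper-style sketch).

    Returns an int32 array of shape [num_heads, num_blocks, num_blocks] where
    layout[h, i, j] == 1 means query block i may attend to key block j.
    """
    if num_patterns > local_blocks:
        raise ValueError("num_patterns must not exceed local_blocks")
    layout = np.zeros((num_heads, num_blocks, num_blocks), dtype=np.int32)
    i = np.arange(num_blocks)[:, None]  # query block index (column vector)
    j = np.arange(num_blocks)[None, :]  # key block index (row vector)
    causal = j <= i                                    # no attention to future blocks
    same_window = (i // local_blocks) == (j // local_blocks)
    for h in range(num_heads):
        # Assumed rule: each head uses a different summary position inside every
        # window, cycling through num_patterns choices counted back from the end.
        g = local_blocks - 1 - (h % num_patterns)
        global_cols = (j % local_blocks) == g
        layout[h] = (causal & (same_window | global_cols)).astype(np.int32)
    return layout
```

With the documented sizes, seq_length=1024 and block_size=64 give a 16 x 16 block grid, so fixed_block_layout(8, 16, 4, 4) yields one such grid per head; heads h and h + 4 share a pattern because only 4 distinct global patterns are used.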
- Parameters:
batch_size (int) – Batch size of the input.
num_heads (int) – Number of attention heads.
size_per_head (int) – Embedding size of each attention head. Only 64 and 128 are supported for now.
block_size (int) – Block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of each block (block_size x block_size). Only 64 is supported for now.
seq_length (int) – Length of the input sequence. Only 1024 is supported for now. Default: 1024.
num_different_global_patterns (int) – Number of different global attention layouts. Global attention is determined by which block(s) act as representatives of each local window; because there are multiple heads, each head can use a different global representative. Only 4 is supported for now. Default: 4.
parallel_config (OpParallelConfig) – The parallel configuration; see OpParallelConfig. Default: default_dpmp_config, an instance of OpParallelConfig with default arguments.
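Because several parameters currently accept only fixed values, it can help to validate them before building the layer. The sketch below mirrors the constraints listed above and derives two dependent sizes. The helper check_fixed_sparse_args is hypothetical (not part of MindFormers), and the relation hidden_size = num_heads * size_per_head is inferred from the example (8 * 64 = 512); the “for now” limits may change in later releases.

```python
def check_fixed_sparse_args(num_heads, size_per_head, block_size,
                            seq_length=1024, num_different_global_patterns=4):
    """Check the documented constraints and derive dependent sizes (hypothetical helper)."""
    if size_per_head not in (64, 128):
        raise ValueError(f"size_per_head must be 64 or 128, got {size_per_head}")
    if block_size != 64:
        raise ValueError(f"block_size must be 64, got {block_size}")
    if seq_length != 1024:
        raise ValueError(f"seq_length must be 1024, got {seq_length}")
    if num_different_global_patterns != 4:
        raise ValueError("num_different_global_patterns must be 4, "
                         f"got {num_different_global_patterns}")
    hidden_size = num_heads * size_per_head  # last dimension of q, k and v
    num_blocks = seq_length // block_size    # blocks per side of the block grid
    return hidden_size, num_blocks
```

For the configuration in the example, check_fixed_sparse_args(8, 64, 64) returns (512, 16).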
- Inputs:
q (Tensor) - The query tensor, of dtype mstype.float16 and shape [batch_size, seq_length, hidden_size], where hidden_size = num_heads * size_per_head: the sequence of queries used to query the context.
k (Tensor) - The key tensor, of dtype mstype.float16 and shape [batch_size, seq_length, hidden_size]: the sequence of keys the queries are matched against.
v (Tensor) - The value tensor, of dtype mstype.float16 and shape [batch_size, seq_length, hidden_size]: the sequence of values aggregated by the attention weights.
attention_mask (Tensor) - Float mask tensor, of dtype mstype.float32 or mstype.float16 and shape [batch_size, seq_length, seq_length]: a lower-triangular matrix that carries the masking information.
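The shape contract of these inputs can be illustrated with a dense NumPy reference: split hidden_size into num_heads heads, apply masked scaled dot-product attention, and merge the heads back. This is only a sketch for checking shapes and masking behaviour. It computes the full score matrix (including the upper triangle, which the layer itself skips), it assumes standard 1/sqrt(size_per_head) scaling, and it assumes the mask convention 1 = keep, 0 = masked. The function name dense_causal_attention is hypothetical.

```python
import numpy as np


def dense_causal_attention(q, k, v, attention_mask, num_heads):
    """Dense reference for the documented shape contract (hypothetical helper).

    q, k, v: [batch_size, seq_length, hidden_size], hidden_size = num_heads * size_per_head
    attention_mask: [batch_size, seq_length, seq_length]; 1 = keep, 0 = masked (assumed)
    Returns: [batch_size, seq_length, hidden_size]
    """
    batch, seq, hidden = q.shape
    size_per_head = hidden // num_heads

    def split_heads(x):
        # [batch, seq, hidden] -> [batch, heads, seq, size_per_head]
        return x.reshape(batch, seq, num_heads, size_per_head).transpose(0, 2, 1, 3)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(size_per_head)  # [b, h, s, s]
    # Broadcast the mask over heads; masked positions get a large negative score.
    mask = attention_mask[:, None, :, :]
    scores = np.where(mask > 0, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ vh  # [b, h, s, size_per_head]
    # Merge heads: [b, h, s, d] -> [b, s, h * d]
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, hidden)
```

With a lower-triangular mask, the first position can only attend to itself, so its output equals its own value vector, and later positions never influence earlier outputs.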
- Outputs:
Tensor, the output of the attention, with shape [batch_size, seq_length, hidden_size].
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> from mindspore import dtype as mstype
>>> from mindformers.modules import FixedSparseAttention
>>> from mindspore import Tensor
>>> model = FixedSparseAttention(batch_size=2,
...                              num_heads=8,
...                              size_per_head=64,
...                              block_size=64)
>>> q = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> k = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> v = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
>>> output = model(q, k, v, attention_mask)
>>> print(output.shape)
(2, 1024, 512)
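The example above passes an all-ones attention_mask, which masks nothing, whereas the parameter description calls for a lower-triangular matrix. A minimal NumPy sketch for building such a mask follows; it assumes the convention 1 = may attend, 0 = masked, and the helper name make_causal_mask is hypothetical. The resulting array can then be wrapped as Tensor(mask, mstype.float32) in place of the all-ones mask.

```python
import numpy as np


def make_causal_mask(batch_size, seq_length, dtype=np.float32):
    """Lower-triangular mask: 1 where query i may attend to key j (j <= i), else 0."""
    tril = np.tril(np.ones((seq_length, seq_length), dtype=dtype))
    # Repeat the same mask for every sample in the batch (copy makes it writable).
    return np.broadcast_to(tril, (batch_size, seq_length, seq_length)).copy()
```

For the documented sizes, make_causal_mask(2, 1024) returns a float32 array of shape (2, 1024, 1024).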