mindformers.modules.layers.FixedSparseAttention
- class mindformers.modules.layers.FixedSparseAttention(batch_size, num_heads, size_per_head, block_size, seq_length=1024, num_different_global_patterns=4, parallel_config=default_dpmp_config)
Fixed Sparse Attention Layer.
This layer implements the sparse attention primitives used in Sparse Transformers (see the paper Generating Long Sequences with Sparse Transformers).
Specifically, it includes the following:
- A faster implementation of normal attention (the upper triangle is not computed, and many operations are fused).
- An implementation of "strided" and "fixed" attention, as in the Sparse Transformers paper.
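The "fixed" pattern can be pictured at block granularity. The sketch below is illustrative only and is not taken from the mindformers kernel: `fixed_block_layout` and its `local_window_blocks` parameter are hypothetical names, and the choice of the last block in each window as the "summary" block is an assumption based on the description in the Sparse Transformers paper.

```python
import numpy as np

def fixed_block_layout(num_blocks, local_window_blocks=4, causal=True):
    """Return a (num_blocks, num_blocks) 0/1 matrix marking which key blocks
    each query block attends to.

    Hypothetical helper, not part of mindformers.
    """
    layout = np.zeros((num_blocks, num_blocks), dtype=np.int32)
    for i in range(num_blocks):
        for j in range(num_blocks):
            # Local attention: query and key blocks share a window.
            same_window = (i // local_window_blocks) == (j // local_window_blocks)
            # Assumption: the last block of each window is its summary block.
            is_summary = (j % local_window_blocks) == local_window_blocks - 1
            if same_window or is_summary:
                layout[i, j] = 1
    if causal:
        # Drop the upper triangle so no block attends to future blocks.
        layout = np.tril(layout)
    return layout

# seq_length=1024 with block_size=64 gives 16 blocks per side.
layout = fixed_block_layout(num_blocks=1024 // 64)
print(layout.shape)  # (16, 16)
print(int(layout.sum()), "of", 16 * 17 // 2, "causal block pairs are computed")
```

Only the block pairs marked 1 would need to be computed, which is where the savings over dense causal attention come from.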
- Args:
batch_size (int): Batch size of the input.
num_heads (int): Number of attention heads.
size_per_head (int): Embedding size of each attention head. Only 64 and 128 are supported for now.
block_size (int): Block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of those blocks (block_size x block_size). Only 64 is supported for now.
seq_length (int): Length of the input sequence. Only 1024 is supported for now. Default: 1024.
num_different_global_patterns (int): Number of different global attention layouts. Global attention is fixed by which block or blocks represent each local window; because there are multiple heads, each head can use a different global representative. Only 4 is supported for now. Default: 4.
parallel_config (OpParallelConfig): The parallel configuration; see OpParallelConfig. Default: default_dpmp_config, an instance of OpParallelConfig with default args.
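The constraints listed above can be checked before the layer is constructed. The following is a minimal sketch: `check_fixed_sparse_args` is a hypothetical helper, not a mindformers function, and it only mirrors the limits stated in this section.

```python
def check_fixed_sparse_args(size_per_head, block_size, seq_length=1024,
                            num_different_global_patterns=4):
    """Hypothetical pre-flight check mirroring the documented constraints.

    Returns the number of blocks per sequence if all values are supported.
    """
    if size_per_head not in (64, 128):
        raise ValueError(f"size_per_head must be 64 or 128, got {size_per_head}")
    if block_size != 64:
        raise ValueError(f"block_size must be 64, got {block_size}")
    if seq_length != 1024:
        raise ValueError(f"seq_length must be 1024, got {seq_length}")
    if num_different_global_patterns != 4:
        raise ValueError(
            "num_different_global_patterns must be 4, "
            f"got {num_different_global_patterns}"
        )
    return seq_length // block_size

num_blocks = check_fixed_sparse_args(size_per_head=64, block_size=64)
print(num_blocks)  # 16
```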
- Inputs:
q (Tensor) - Query tensor of dtype mstype.float16 and shape [batch_size, seq_length, hidden_size]: the sequence of queries used to query the context. In the example below, hidden_size is num_heads * size_per_head.
k (Tensor) - Key tensor of dtype mstype.float16 and shape [batch_size, seq_length, hidden_size]: the sequence of keys that the queries are matched against.
v (Tensor) - Value tensor of dtype mstype.float16 and shape [batch_size, seq_length, hidden_size]: the sequence of values combined according to the attention weights.
attention_mask (Tensor) - Float mask tensor of dtype mstype.float32 or mstype.float16 and shape [batch_size, seq_length, seq_length]: a lower triangular matrix that carries the masking information.
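As an illustration of the shapes above, the inputs can be prepared with NumPy before being wrapped in mindspore.Tensor. This is a sketch under assumptions: the variable names are illustrative, and treating 1 as "may attend" and 0 as "masked" is an assumed convention rather than something stated in this section.

```python
import numpy as np

batch_size, seq_length = 2, 1024
num_heads, size_per_head = 8, 64
hidden_size = num_heads * size_per_head  # 512

# Query, key and value share the same shape and use float16.
q = np.ones((batch_size, seq_length, hidden_size), dtype=np.float16)
k = np.ones((batch_size, seq_length, hidden_size), dtype=np.float16)
v = np.ones((batch_size, seq_length, hidden_size), dtype=np.float16)

# Lower triangular mask: position i may see positions 0..i
# (assumed convention: 1 = visible, 0 = masked).
causal = np.tril(np.ones((seq_length, seq_length), dtype=np.float32))
attention_mask = np.broadcast_to(
    causal, (batch_size, seq_length, seq_length)
).copy()

print(q.shape)               # (2, 1024, 512)
print(attention_mask.shape)  # (2, 1024, 1024)
```

Each array can then be passed to mindspore.Tensor with the matching dtype (mstype.float16 for q, k, v and mstype.float32 for the mask).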
- Outputs:
Tensor, the attention output with shape [batch_size, seq_length, hidden_size].
- Supported Platforms:
Ascend
- Examples:
>>> import numpy as np
>>> from mindspore import dtype as mstype
>>> from mindformers.modules import FixedSparseAttention
>>> from mindspore import Tensor
>>> model = FixedSparseAttention(batch_size=2,
...                              num_heads=8,
...                              size_per_head=64,
...                              block_size=64)
>>> q = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> k = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> v = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
>>> output = model(q, k, v, attention_mask)
>>> print(output.shape)
(2, 1024, 512)