mindformers.modules.transformer.MultiHeadAttention
class mindformers.modules.transformer.MultiHeadAttention(**kwargs) [source]
This is an implementation of multi-head attention as described in the paper "Attention Is All You Need". Given a query vector with the source sequence length, and key and value vectors with the target sequence length, the attention is computed as follows:
\[\mathrm{MultiHeadAttention}(query, key, value) = \mathrm{Concat}(head_1, \dots, head_h)W^O\]where \(head_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)\). Bias is used by default.
If the query, key and value tensors are the same, the layer performs self-attention.
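The formula above can be traced with a small pure-Python sketch. It is not the MindFormers implementation. It assumes the standard scaled dot-product form softmax(QK^T / sqrt(d))V from the paper, omits the bias terms, processes a single unbatched sequence, and treats the per-head matrices W_i^Q, W_i^K, W_i^V as column slices of one full-width projection (an equivalent formulation). The mask convention (1 keeps a score, 0 suppresses it by adding a large negative value) is an assumption consistent with the all-ones masks in the examples. All helper names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of an (n x k) matrix and a (k x m) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]


def attention(q: Matrix, k: Matrix, v: Matrix, mask: Matrix) -> Matrix:
    """Scaled dot-product attention for one head.

    Assumed mask convention: mask[i][j] == 1 keeps score (i, j), 0 adds -10000 to it.
    """
    d = len(q[0])
    scores = matmul(q, [list(col) for col in zip(*k)])  # Q K^T
    probs = []
    for i, row in enumerate(scores):
        biased = [s / math.sqrt(d) + (1.0 - mask[i][j]) * -10000.0 for j, s in enumerate(row)]
        probs.append(softmax(biased))
    return matmul(probs, v)


def multi_head_attention(query: Matrix, key: Matrix, value: Matrix, mask: Matrix,
                         w_q: Matrix, w_k: Matrix, w_v: Matrix, w_o: Matrix,
                         num_heads: int) -> Matrix:
    """Concat(head_1, ..., head_h) W^O with bias omitted (hypothetical sketch)."""
    size_per_head = len(query[0]) // num_heads
    q_all, k_all, v_all = matmul(query, w_q), matmul(key, w_k), matmul(value, w_v)
    heads = []
    for h in range(num_heads):
        cols = slice(h * size_per_head, (h + 1) * size_per_head)  # columns of head h
        heads.append(attention([r[cols] for r in q_all], [r[cols] for r in k_all],
                               [r[cols] for r in v_all], mask))
    # Concatenate the heads back along the hidden dimension, row by row.
    concat = [sum((heads[h][i] for h in range(num_heads)), []) for i in range(len(query))]
    return matmul(concat, w_o)


identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
x = [[1.0, 0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 2.0], [1.0, 1.0, 1.0, 1.0]]
out = multi_head_attention(x, x, x, [[1.0] * 3 for _ in range(3)],
                           identity, identity, identity, identity, num_heads=2)
print(len(out), len(out[0]))  # 3 4
```

Because each row of softmax weights sums to 1, the output for one position is a weighted average of the (projected) value rows, which is why the query, key and value roles are kept separate in the formula.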
- Parameters
batch_size (int) – The batch size of the input tensor during incremental prediction. Should be a positive value. During training or non-incremental prediction, the argument has no effect and None can be passed.
src_seq_length (int) – The sequence length of the query vector.
tgt_seq_length (int) – The sequence length of the key and value vector.
hidden_size (int) – The hidden size of the input.
num_heads (int) – The number of attention heads.
hidden_dropout_rate (float) – The dropout rate of the final output of the layer. Default: 0.1.
attention_dropout_rate (float) – The dropout rate of the attention scores. Default: 0.1.
compute_dtype (dtype.Number) – The computation type of the dense layers. Default: mstype.float16. Should be mstype.float32 or mstype.float16.
softmax_compute_type (dtype.Number) – The computation type of the softmax module. Default: mstype.float32. Should be mstype.float32 or mstype.float16.
param_init_type (dtype.Number) – The parameter initialization type of the module. Default: mstype.float32. Should be mstype.float32 or mstype.float16.
use_past (bool) – Use the past state to compute, for incremental prediction. For example, given two words from which ten more words should be generated, the state of the two words only needs to be computed once, and the following words are then generated one by one. When use_past is True, prediction runs in two steps. In the first step, set is_first_iteration to True with model.add_flags_recursive(is_first_iteration=True) and pass the full inputs. Then set is_first_iteration to False with model.add_flags_recursive(is_first_iteration=False), pass the single-step input tensor, and repeat this step in a loop. Default: False.
parallel_config (OpParallelConfig) – The parallel configuration. Default: default_dpmp_config, an instance of OpParallelConfig with default arguments.
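The two-step use_past procedure can be illustrated with a toy key/value cache. This sketch is hypothetical and much simpler than the real layer: it uses a single head and identity projections (each token vector serves directly as its query, key and value), and it supplies the "generated" tokens up front so the cached result can be compared with recomputing causal attention over the whole sequence. MindFormers instead keeps fixed-size key_past/value_past tensors of length tgt_seq_length; the growing Python lists here are a simplification, and with identity projections the sketch only demonstrates that the outputs match, not the saved projection work.

```python
import math
from typing import List

Vector = List[float]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """One query attending over the given keys/values (single head, scaled dot product)."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / total
            for c in range(len(values[0]))]


def full_causal(seq: List[Vector]) -> List[Vector]:
    """No cache: position i attends over seq[0..i], recomputed from scratch."""
    return [attend(seq[i], seq[: i + 1], seq[: i + 1]) for i in range(len(seq))]


def incremental(prompt: List[Vector], new_tokens: List[Vector]) -> List[Vector]:
    """Hypothetical two-phase scheme mirroring use_past=True."""
    # Phase 1 (like is_first_iteration=True): process the whole prompt once
    # and keep its keys/values in the cache.
    key_cache: List[Vector] = list(prompt)
    value_cache: List[Vector] = list(prompt)
    outputs = [attend(prompt[i], key_cache[: i + 1], value_cache[: i + 1])
               for i in range(len(prompt))]
    # Phase 2 (like is_first_iteration=False): one token per step; only the new
    # token's key/value is appended, earlier entries are reused from the cache.
    for tok in new_tokens:
        key_cache.append(tok)
        value_cache.append(tok)
        outputs.append(attend(tok, key_cache, value_cache))
    return outputs


prompt = [[1.0, 0.0], [0.0, 1.0]]
new = [[1.0, 1.0], [0.5, -0.5], [2.0, 0.0]]
print(len(incremental(prompt, new)))  # 5
```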
- Inputs:
query_tensor (Tensor) - The query vector with shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size) if use_past is False or is_first_iteration=True. Otherwise, it must be (batch_size, 1, hidden_size).
key_tensor (Tensor) - The key vector with shape (batch_size, tgt_seq_length, hidden_size) or (batch_size * tgt_seq_length, hidden_size) if use_past is False or is_first_iteration=True. Otherwise, it must be (batch_size, 1, hidden_size).
value_tensor (Tensor) - The value vector with shape (batch_size, tgt_seq_length, hidden_size) or (batch_size * tgt_seq_length, hidden_size) if use_past is False or is_first_iteration=True. Otherwise, it must be (batch_size, 1, hidden_size).
attention_mask (Tensor) - If use_past is False or is_first_iteration=True, the attention mask matrix should be (batch_size, src_seq_length, tgt_seq_length), or None. None means no mask is applied in the softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length).
key_past (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length), where size_per_head = hidden_size / num_heads. The previously calculated key vector, used for incremental prediction when use_past is True. Default: None.
value_past (Tensor) - Float16 tensor with shape (batch_size, num_heads, tgt_seq_length, size_per_head). The previously calculated value vector, used for incremental prediction when use_past is True. Default: None.
batch_valid_length (Tensor) - Int32 tensor with shape (batch_size,) holding the previously calculated index for each sequence. Used for incremental prediction when use_past is True. Default: None.
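The shape rules in the input list can be collected into one helper for checking inputs before calling the layer. This is a hypothetical utility, not part of MindFormers. It covers only the 3-D input forms (not the flattened (batch_size * seq_length, hidden_size) variants), and the divisibility check is an assumption implied by size_per_head = hidden_size / num_heads in the key_past and value_past shapes.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_input_shapes(batch_size: int, src_seq_length: int, tgt_seq_length: int,
                          hidden_size: int, num_heads: int,
                          incremental_step: bool = False) -> Dict[str, Shape]:
    """Return the documented 3-D input shapes (hypothetical helper).

    incremental_step=False covers use_past=False or is_first_iteration=True;
    incremental_step=True covers the single-token steps with is_first_iteration=False.
    """
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    size_per_head = hidden_size // num_heads
    q_len = 1 if incremental_step else src_seq_length
    kv_len = 1 if incremental_step else tgt_seq_length
    return {
        "query_tensor": (batch_size, q_len, hidden_size),
        "key_tensor": (batch_size, kv_len, hidden_size),
        "value_tensor": (batch_size, kv_len, hidden_size),
        # The mask always spans the full target length, even for one new token.
        "attention_mask": (batch_size, q_len, tgt_seq_length),
        # Key cache is stored transposed relative to the value cache.
        "key_past": (batch_size, num_heads, size_per_head, tgt_seq_length),
        "value_past": (batch_size, num_heads, tgt_seq_length, size_per_head),
        "batch_valid_length": (batch_size,),
    }


print(expected_input_shapes(2, 20, 20, 15, 3)["key_past"])  # (2, 3, 5, 20)
```

With the values from the examples (hidden_size=15, num_heads=3), size_per_head is 5, which matches the (2, 3, 5, 20) and (2, 3, 20, 5) cache shapes used there.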
- Outputs:
Tuple, a tuple containing (output, layer_present).
output (Tensor) - The float output tensor of the layer, with shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size) if use_past is False or is_first_iteration=True. Otherwise, it is (batch_size, 1, hidden_size).
layer_present (Tuple) - A tuple of the projected key and value tensors, with shapes ((batch_size, num_heads, size_per_head, tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head)).
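The two layouts in layer_present can be reproduced from a projected (batch_size, tgt_seq_length, hidden_size) tensor by splitting the hidden dimension into heads and, for the key only, swapping the last two axes. This pure-Python sketch uses hypothetical helper names and nested lists instead of MindSpore tensors. A plausible reason for the transposed key layout (an assumption, not stated by the source) is that the scores can then be computed as a direct product of a (src_seq_length x size_per_head) query block with a (size_per_head x tgt_seq_length) key block.

```python
from typing import List

Tensor3 = List[List[List[float]]]  # (batch, seq, hidden)


def split_heads(x: Tensor3, num_heads: int) -> list:
    """(batch, seq, hidden) -> (batch, num_heads, seq, size_per_head)."""
    d = len(x[0][0]) // num_heads
    return [[[row[h * d:(h + 1) * d] for row in seq] for h in range(num_heads)] for seq in x]


def transpose_last_two(x: list) -> list:
    """(batch, heads, a, b) -> (batch, heads, b, a)."""
    return [[[list(col) for col in zip(*mat)] for mat in heads] for heads in x]


def present_layout(projected_key: Tensor3, projected_value: Tensor3, num_heads: int):
    """Key as (B, H, size_per_head, T), value as (B, H, T, size_per_head)."""
    key_present = transpose_last_two(split_heads(projected_key, num_heads))
    value_present = split_heads(projected_value, num_heads)
    return key_present, value_present


def shape(x) -> tuple:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


# batch_size=2, tgt_seq_length=20, hidden_size=15, num_heads=3 as in the examples.
k = [[[float(t * 15 + c) for c in range(15)] for t in range(20)] for _ in range(2)]
kp, vp = present_layout(k, k, num_heads=3)
print(shape(kp), shape(vp))  # (2, 3, 5, 20) (2, 3, 20, 5)
```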
- Supported Platforms:
Ascend GPU
Examples
>>> import numpy as np
>>> from mindformers.modules.transformer import MultiHeadAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
...                            num_heads=3)
>>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16)
>>> attention_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
>>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask)
>>> print(attn_out.shape)
(2, 20, 15)
>>> print(past[0].shape)
(2, 3, 5, 20)
>>> print(past[1].shape)
(2, 3, 20, 5)
>>> # When use_past=True, incremental prediction takes two steps.
>>> # Step 1: set is_first_iteration=True and input the full-sequence state.
>>> # First, prepare the memory parameters that save the key and value states.
>>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
...                            num_heads=3, use_past=True)
>>> key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), mstype.float16)
>>> value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), mstype.float16)
>>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32)
>>> # Set is_first_iteration=True to generate the full memory states
>>> model.add_flags_recursive(is_first_iteration=True)
>>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
...                        batch_valid_length)
>>> print(attn_out.shape)
(2, 20, 15)
>>> print(past[0].shape)
(2, 3, 5, 20)
>>> print(past[1].shape)
(2, 3, 20, 5)
>>> from_tensor = Tensor(np.ones((2, 1, 15)), mstype.float32)
>>> to_tensor = Tensor(np.ones((2, 1, 15)), mstype.float16)
>>> attention_mask = Tensor(np.ones((2, 1, 20)), mstype.float16)
>>> # Step 2: set is_first_iteration=False and pass a single word to run the prediction
>>> # rather than the full sequence.
>>> model.add_flags_recursive(is_first_iteration=False)
>>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
...                        batch_valid_length)
>>> print(attn_out.shape)
(2, 1, 15)
>>> print(past[0].shape)
(2, 3, 5, 20)
>>> print(past[1].shape)
(2, 3, 20, 5)