From a603d7b70c236d3567d2a2adf9c15d7983b32259 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 08:49:15 -0700 Subject: [PATCH 01/56] add MHA building blocks in torchtext --- torchtext/models/__init__.py | 6 + torchtext/models/functional.py | 183 +++++++++++++++++++++++++ torchtext/models/multiheadattention.py | 129 +++++++++++++++++ 3 files changed, 318 insertions(+) create mode 100644 torchtext/models/__init__.py create mode 100644 torchtext/models/functional.py create mode 100644 torchtext/models/multiheadattention.py diff --git a/torchtext/models/__init__.py b/torchtext/models/__init__.py new file mode 100644 index 0000000000..65aada6847 --- /dev/null +++ b/torchtext/models/__init__.py @@ -0,0 +1,6 @@ +from .multiheadattention import MultiheadAttentionInProjection, \ + ScaledDotProduct, MultiheadAttentionOutProjection + +__all__ = ['MultiheadAttentionInProjection', + 'ScaledDotProduct', + 'MultiheadAttentionOutProjection'] diff --git a/torchtext/models/functional.py b/torchtext/models/functional.py new file mode 100644 index 0000000000..c9b84b72f7 --- /dev/null +++ b/torchtext/models/functional.py @@ -0,0 +1,183 @@ +import torch +from torch._overrides import has_torch_function, handle_torch_function +import torch.nn.functional as F +from torch._jit_internal import Optional, Tuple + + +Tensor = torch.Tensor + + +def multi_head_attention_in_projection(seq, num_heads, in_proj_weight, in_proj_bias=None): + # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor + r"""Projects an input sequence using parallel attention heads. + Args: + seq (Tensor): sequence to be projected + num_heads (int): number of parallel heads used. + in_proj_weight (Tensor): weight used for projection + in_proj_bias (Tensor, optional): bias used for projection. + Shape: + - seq: :math:`(S, N, E)` + - in_proj_weight: :math:`(P, E)` + - in_proj_bias: :math:`(P)` + - Output: :math:`(N * H, S, P / H)` + where S is the sequence length, H is the number of attention heads, N is the + batch size, P is the projection dimension, and E is the embedding + dimension. + """ + if not torch.jit.is_scripting(): + tens_ops = (seq, in_proj_weight) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_in_projection, tens_ops, + seq, num_heads, in_proj_weight, in_proj_bias=in_proj_bias) + seq_len, bsz, _ = seq.size() + proj_dim = in_proj_weight.size(0) + assert proj_dim % num_heads == 0, "projection dimension must be divisible by num_heads" + head_dim = proj_dim // num_heads + + q = F.linear(seq, in_proj_weight, in_proj_bias) + # Shape of q: (S, N, P) + q = q.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1) + return q + + +def scaled_dot_product_attention(q, # type: Tensor + k, # type: Tensor + v, # type: Tensor + num_heads, # type: int + add_zero_attn, # type: bool + dropout_p, # type: float + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + attn_mask=None, # type: Optional[Tensor] + ): + # type: (...) -> Tuple[Tensor, Tensor] + r"""Uses a scaled dot product with the projected key-value pair to update + the projected query. + Args: + q (Tensor): Projected query + k (Tensor): Projected key + v (Tensor): Projected value + num_heads (int): Number of parallel attention heads. + add_zero_attn (bool): Add a new batch of zeros to the projected key and + value sequences at dimension 1. + dropout_p (float): Probability of an element will be zeroed. 
+ training (bool): Apply dropout if ``training=True`` + key_padding_mask (Tensor, optional): Specified padding elements in the + key will be ignored by the attention. This is a binary mask. When + the value is True, the corresponding value on the attention layer + will be set to :math:`-\inf`. + attn_mask (Tensor, optional): 2D or 3D mask that prevents attention to + certain positions. This is an additive mask (i.e. the values will + be added to the attention layer). A 2D mask will be broadcasted for + all the batches while a 3D mask allows to specify a different mask + for the entries of each batch. + Shape: + - q: :math:`(N * H, L, P / H)` + - k: :math:`(N * H, S, P / H)` + - v: :math:`(N * H, S, P / H)` + - key_padding_mask: :math:`(N, S)` + - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` + - Output: :math:`(N * H, L, P / H)`, :math:`(N * H, L, S)` + where L is the target length, S is the source length, H is the number + of attention heads, N is the batch size, and P is the projection + dimension. + """ + if not torch.jit.is_scripting(): + tens_ops = (q, k, v) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + scaled_dot_product_attention, tens_ops, + q, k, v, num_heads, add_zero_attn, dropout_p, + training=training, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + batch_heads, tgt_len, head_dim = q.size() + assert q.size(0) == k.size(0) == v.size(0), "Dimension 0 of q, k, v must be equal." + assert batch_heads % num_heads == 0, "Dimension 0 of q, k, v must be divisible by num_heads" + bsz = batch_heads // num_heads + assert k.size() == v.size(), "Shape of k, v must match" + assert q.size(-1) == k.size(-1), "The head dimension of query must be equal to that of key" + + src_len = k.size(1) + + # Scale q + q = q * (float(head_dim) ** -0.5) + if attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, tgt_len, src_len]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. 
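+        # Convert a boolean attn_mask into an additive float mask: True positions become -inf (excluded from attention), False positions become 0.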
+ if attn_mask.dtype == torch.bool: + attn_mask = torch.where( + attn_mask, torch.tensor(float('-inf')), torch.tensor(0.)).to(dtype=q.dtype, device=q.device) + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # Dot product of q, k + attn_output_weights = torch.matmul(q, k.transpose(-2, -1)) + assert list(attn_output_weights.size()) == [batch_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.reshape(batch_heads, tgt_len, src_len) + + attn_output_weights = F.softmax(attn_output_weights, dim=-1) + + attn_output = torch.matmul(F.dropout(attn_output_weights, p=dropout_p, training=training), v) + return attn_output, attn_output_weights + + +def multi_head_attention_out_projection(attn_output, num_heads, out_proj_weight, out_proj_bias=None): + # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor + r"""Projects an output sequence using parallel attention heads. + Args: + attn_output (Tensor): Projection to be decoded to an embedding. + num_heads (int): Number of parallel attention heads + out_proj_weight (Tensor): weight used to decode projection. + out_proj_bias (Tensor, optional): bias used to decode projection. + Shape: + - attn_output: :math:`(N * H, S, P / H)` + - out_proj_weight: :math:`(E, P)` + - out_proj_bias: :math:`(E)` + - Output: :math:`(S, N, E)` + where S is the sequence length, H is the number of attention heads, N is the + batch size, P is the projection dimension, and E is the embedding + dimension. + """ + if not torch.jit.is_scripting(): + tens_ops = (attn_output, out_proj_weight) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_out_projection, tens_ops, + attn_output, num_heads, out_proj_weight, out_proj_bias=out_proj_bias) + batch_heads, seq_len, head_dim = attn_output.size() + # embed_dim = out_proj_weight.size(0) + assert batch_heads % num_heads == 0, "dimension 0 of attn_output must be divisible by num_heads" + bsz = batch_heads // num_heads + attn_output = attn_output.transpose(0, 1).reshape(seq_len, bsz, head_dim * num_heads) + return F.linear(attn_output, out_proj_weight, out_proj_bias) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py new file mode 100644 index 0000000000..04e3dd99ce --- /dev/null +++ b/torchtext/models/multiheadattention.py @@ -0,0 +1,129 @@ +import torch +import torchtext.model.functional as F +from torch.nn.init import kaiming_uniform_ +from math import sqrt + + +class MultiheadAttentionInProjection(torch.nn.Module): + r"""Process input using multi-head attention. + Args: + embed_dim (int): Input embedding dimension + num_heads (int): Number of parallel attention heads. 
+ head_dim (int, optional): Dimension of embedding for each attention + head. If not provided, then it is set to ``embed_dim / num_heads``. + Shape: + - seq: :math:`(S, N, E)` + - Output: :math:`(N * H, S, D)` + where S is the sequence length, N is the batch size, H is the number of + attention heads, E is the embedding dimension, and D is the head + dimension. + Attributes: + weight: The learnable weights of the module of shape + :math:`(\text{head\_dim} * \text{num\_heads}, \text{embed\_dim})`. + Examples:: + >>> # S = 21; N = 64; E = 10; D = 3; H = 4; + >>> MHA_in = nn.MultiheadAttentionInProjection(10, 4, 3) + >>> seq = torch.randn(21, 64, 10) + >>> s = MHA_in(seq) + >>> print(s.shape) + torch.Size([256, 21, 3]) + """ + __constants__ = ['embed_dim', 'num_heads', 'head_dim'] + + def __init__(self, embed_dim, num_heads, head_dim=None): + super(MultiheadAttentionInProjection, self).__init__() + if head_dim is None: + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" + head_dim = embed_dim // num_heads + self.head_dim = head_dim + self.embed_dim = embed_dim + self.num_heads = num_heads + self.weight = torch.nn.Parameter(torch.Tensor(head_dim * num_heads, embed_dim)) + kaiming_uniform_(self.weight, a=sqrt(5)) + + def forward(self, seq): + return F.multi_head_attention_in_projection(seq, self.num_heads, self.weight, in_proj_bias=None) + + +class ScaledDotProduct(torch.nn.Module): + r"""Processes a projected query and key-value pair to apply attention + in each parallel attention head. + Args: + num_heads (int): Number of parallel attention heads. + add_zero_attn (bool): Whether to add a batch of zeros to the key and + value sequences. + dropout_p (float): probability of dropping an attention weight. + Shape: + - query: :math:`(N * H, L, D)` + - key: :math:`(N * H, S, D)` + - value: :math:`(N * H, S, D)` + - key_padding_mask: :math:`(N, S)` + - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` + - Output: :math:`(N * H, L, D)`, :math:`(N * H, L, S)` + where L is the target sequence length, S is the source sequence + length, H is the number of attention heads, N is the batch size, + and D is the head dimension. + Examples:: + >>> # S = L = 21; N = 64; E = 10; D = 3; H = 4; + >>> SDP = nn.ScaledDotProduct(4, False, 0.1) + >>> q = torch.randn(256, 21, 3) + >>> k = v = torch.randn(256, 21, 3) + >>> attn_output, attn_weights = SDP(q, k, v) + >>> print(attn_output.shape, attn_weights.shape) + torch.Size([256, 21, 3]) torch.Size([256, 21, 21]) + """ + __constants__ = ['num_heads', 'add_zero_attn', 'dropout_p'] + + def __init__(self, num_heads, add_zero_attn=False, dropout_p=0.0): + super(ScaledDotProduct, self).__init__() + self.dropout_p = dropout_p + self.add_zero_attn = add_zero_attn + self.num_heads = num_heads + + def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): + attn_output, attn_output_weights = F.scaled_dot_product_attention( + query, key, value, + self.num_heads, self.add_zero_attn, self.dropout_p, self.training, key_padding_mask, attn_mask) + return attn_output, attn_output_weights + + +class MultiheadAttentionOutProjection(torch.nn.Module): + r"""Process attention output using multi-head attention. + Args: + embed_dim (int): Input projection dimension. + num_heads (int): Number of parallel attention heads. + head_dim (int, optional): Dimension of embedding for each attention + head. If not provided, then it is set to ``embed_dim / num_heads``. 
+ Shape: + - attn_output: :math:`(N * H, S, D)` + - Output: :math:`(S, N, E)` + where S is the sequence length, N is the batch size, H is the number of + attention heads, E is the embedding dimension, and D is the head + dimension. + Attributes: + weight: The learnable weights of the module of shape + :math:`(\text{embed\_dim}, \text{head\_dim} * \text{num\_heads})`. + Examples:: + >>> # S = 21; N = 64; E = 10; D = 3; H = 4; + >>> MHA_out = nn.MultiheadAttentionOutProjection(10, 4, 3) + >>> attn_seq = torch.randn(256, 21, 3) + >>> a = MHA_out(attn_seq) + >>> print(a.shape) + torch.Size([21, 64, 10]) + """ + __constants__ = ['embed_dim', 'num_heads', 'head_dim'] + + def __init__(self, embed_dim, num_heads, head_dim=None): + super(MultiheadAttentionOutProjection, self).__init__() + self.embed_dim = embed_dim + if head_dim is None: + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" + head_dim = embed_dim // num_heads + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.weight = torch.nn.Parameter(torch.Tensor(embed_dim, head_dim * num_heads)) + kaiming_uniform_(self.weight, a=sqrt(5)) + + def forward(self, attn_output): + return F.multi_head_attention_out_projection(attn_output, self.num_heads, self.weight, out_proj_bias=None) From 5d06447830a8d83c45d0ffd8b595213460748409 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 09:17:59 -0700 Subject: [PATCH 02/56] add docs --- docs/source/models.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 docs/source/models.rst diff --git a/docs/source/models.rst b/docs/source/models.rst new file mode 100644 index 0000000000..a2a7bcd253 --- /dev/null +++ b/docs/source/models.rst @@ -0,0 +1,23 @@ +.. role:: hidden + :class: hidden-section + +torchtext.models.multiheadattention +================================== + +.. automodule:: torchtext.models.multiheadattention +.. currentmodule:: torchtext.models.multiheadattention + +:hidden:`MultiheadAttentionInProjection` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: MultiheadAttentionInProjection + +:hidden:`ScaledDotProduct` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: ScaledDotProduct + +:hidden:`MultiheadAttentionOutProjection` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: MultiheadAttentionOutProjection From e9e18cbd0b8d3a72e09735a2ab7f57d0fd0c2c66 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 09:51:45 -0700 Subject: [PATCH 03/56] combine forward function with functional --- torchtext/models/functional.py | 183 ------------------------- torchtext/models/multiheadattention.py | 144 +++++++++++++++++-- 2 files changed, 129 insertions(+), 198 deletions(-) delete mode 100644 torchtext/models/functional.py diff --git a/torchtext/models/functional.py b/torchtext/models/functional.py deleted file mode 100644 index c9b84b72f7..0000000000 --- a/torchtext/models/functional.py +++ /dev/null @@ -1,183 +0,0 @@ -import torch -from torch._overrides import has_torch_function, handle_torch_function -import torch.nn.functional as F -from torch._jit_internal import Optional, Tuple - - -Tensor = torch.Tensor - - -def multi_head_attention_in_projection(seq, num_heads, in_proj_weight, in_proj_bias=None): - # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor - r"""Projects an input sequence using parallel attention heads. - Args: - seq (Tensor): sequence to be projected - num_heads (int): number of parallel heads used. 
- in_proj_weight (Tensor): weight used for projection - in_proj_bias (Tensor, optional): bias used for projection. - Shape: - - seq: :math:`(S, N, E)` - - in_proj_weight: :math:`(P, E)` - - in_proj_bias: :math:`(P)` - - Output: :math:`(N * H, S, P / H)` - where S is the sequence length, H is the number of attention heads, N is the - batch size, P is the projection dimension, and E is the embedding - dimension. - """ - if not torch.jit.is_scripting(): - tens_ops = (seq, in_proj_weight) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - multi_head_attention_in_projection, tens_ops, - seq, num_heads, in_proj_weight, in_proj_bias=in_proj_bias) - seq_len, bsz, _ = seq.size() - proj_dim = in_proj_weight.size(0) - assert proj_dim % num_heads == 0, "projection dimension must be divisible by num_heads" - head_dim = proj_dim // num_heads - - q = F.linear(seq, in_proj_weight, in_proj_bias) - # Shape of q: (S, N, P) - q = q.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1) - return q - - -def scaled_dot_product_attention(q, # type: Tensor - k, # type: Tensor - v, # type: Tensor - num_heads, # type: int - add_zero_attn, # type: bool - dropout_p, # type: float - training=True, # type: bool - key_padding_mask=None, # type: Optional[Tensor] - attn_mask=None, # type: Optional[Tensor] - ): - # type: (...) -> Tuple[Tensor, Tensor] - r"""Uses a scaled dot product with the projected key-value pair to update - the projected query. - Args: - q (Tensor): Projected query - k (Tensor): Projected key - v (Tensor): Projected value - num_heads (int): Number of parallel attention heads. - add_zero_attn (bool): Add a new batch of zeros to the projected key and - value sequences at dimension 1. - dropout_p (float): Probability of an element will be zeroed. - training (bool): Apply dropout if ``training=True`` - key_padding_mask (Tensor, optional): Specified padding elements in the - key will be ignored by the attention. This is a binary mask. When - the value is True, the corresponding value on the attention layer - will be set to :math:`-\inf`. - attn_mask (Tensor, optional): 2D or 3D mask that prevents attention to - certain positions. This is an additive mask (i.e. the values will - be added to the attention layer). A 2D mask will be broadcasted for - all the batches while a 3D mask allows to specify a different mask - for the entries of each batch. - Shape: - - q: :math:`(N * H, L, P / H)` - - k: :math:`(N * H, S, P / H)` - - v: :math:`(N * H, S, P / H)` - - key_padding_mask: :math:`(N, S)` - - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` - - Output: :math:`(N * H, L, P / H)`, :math:`(N * H, L, S)` - where L is the target length, S is the source length, H is the number - of attention heads, N is the batch size, and P is the projection - dimension. - """ - if not torch.jit.is_scripting(): - tens_ops = (q, k, v) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - scaled_dot_product_attention, tens_ops, - q, k, v, num_heads, add_zero_attn, dropout_p, - training=training, key_padding_mask=key_padding_mask, attn_mask=attn_mask) - batch_heads, tgt_len, head_dim = q.size() - assert q.size(0) == k.size(0) == v.size(0), "Dimension 0 of q, k, v must be equal." 
- assert batch_heads % num_heads == 0, "Dimension 0 of q, k, v must be divisible by num_heads" - bsz = batch_heads // num_heads - assert k.size() == v.size(), "Shape of k, v must match" - assert q.size(-1) == k.size(-1), "The head dimension of query must be equal to that of key" - - src_len = k.size(1) - - # Scale q - q = q * (float(head_dim) ** -0.5) - if attn_mask is not None: - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, tgt_len, src_len]: - raise RuntimeError('The size of the 2D attn_mask is not correct.') - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]: - raise RuntimeError('The size of the 3D attn_mask is not correct.') - else: - raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) - # attn_mask's dim is 3 now. - if attn_mask.dtype == torch.bool: - attn_mask = torch.where( - attn_mask, torch.tensor(float('-inf')), torch.tensor(0.)).to(dtype=q.dtype, device=q.device) - - src_len = k.size(1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz - assert key_padding_mask.size(1) == src_len - - if add_zero_attn: - src_len += 1 - k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) - v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) - if attn_mask is not None: - attn_mask = F.pad(attn_mask, (0, 1)) - if key_padding_mask is not None: - key_padding_mask = F.pad(key_padding_mask, (0, 1)) - - # Dot product of q, k - attn_output_weights = torch.matmul(q, k.transpose(-2, -1)) - assert list(attn_output_weights.size()) == [batch_heads, tgt_len, src_len] - - if attn_mask is not None: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float('-inf'), - ) - attn_output_weights = attn_output_weights.reshape(batch_heads, tgt_len, src_len) - - attn_output_weights = F.softmax(attn_output_weights, dim=-1) - - attn_output = torch.matmul(F.dropout(attn_output_weights, p=dropout_p, training=training), v) - return attn_output, attn_output_weights - - -def multi_head_attention_out_projection(attn_output, num_heads, out_proj_weight, out_proj_bias=None): - # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor - r"""Projects an output sequence using parallel attention heads. - Args: - attn_output (Tensor): Projection to be decoded to an embedding. - num_heads (int): Number of parallel attention heads - out_proj_weight (Tensor): weight used to decode projection. - out_proj_bias (Tensor, optional): bias used to decode projection. - Shape: - - attn_output: :math:`(N * H, S, P / H)` - - out_proj_weight: :math:`(E, P)` - - out_proj_bias: :math:`(E)` - - Output: :math:`(S, N, E)` - where S is the sequence length, H is the number of attention heads, N is the - batch size, P is the projection dimension, and E is the embedding - dimension. 
- """ - if not torch.jit.is_scripting(): - tens_ops = (attn_output, out_proj_weight) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - multi_head_attention_out_projection, tens_ops, - attn_output, num_heads, out_proj_weight, out_proj_bias=out_proj_bias) - batch_heads, seq_len, head_dim = attn_output.size() - # embed_dim = out_proj_weight.size(0) - assert batch_heads % num_heads == 0, "dimension 0 of attn_output must be divisible by num_heads" - bsz = batch_heads // num_heads - attn_output = attn_output.transpose(0, 1).reshape(seq_len, bsz, head_dim * num_heads) - return F.linear(attn_output, out_proj_weight, out_proj_bias) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 04e3dd99ce..e28567de8c 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -1,7 +1,8 @@ import torch -import torchtext.model.functional as F -from torch.nn.init import kaiming_uniform_ -from math import sqrt +from torch._jit_internal import Optional, Tuple + + +Tensor = torch.Tensor class MultiheadAttentionInProjection(torch.nn.Module): @@ -38,11 +39,32 @@ def __init__(self, embed_dim, num_heads, head_dim=None): self.head_dim = head_dim self.embed_dim = embed_dim self.num_heads = num_heads - self.weight = torch.nn.Parameter(torch.Tensor(head_dim * num_heads, embed_dim)) - kaiming_uniform_(self.weight, a=sqrt(5)) + self.linear = torch.nn.Linear(embed_dim, head_dim) def forward(self, seq): - return F.multi_head_attention_in_projection(seq, self.num_heads, self.weight, in_proj_bias=None) + # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor + r"""Projects an input sequence using parallel attention heads. + Args: + seq (Tensor): sequence to be projected + num_heads (int): number of parallel heads used. + in_proj_weight (Tensor): weight used for projection + in_proj_bias (Tensor, optional): bias used for projection. + Shape: + - seq: :math:`(S, N, E)` + - in_proj_weight: :math:`(P, E)` + - in_proj_bias: :math:`(P)` + - Output: :math:`(N * H, S, P / H)` + where S is the sequence length, H is the number of attention heads, N is the + batch size, P is the projection dimension, and E is the embedding + dimension. + """ + seq_len, bsz, proj_dim = seq.size() + assert proj_dim % self.num_heads == 0, "projection dimension must be divisible by num_heads" + head_dim = proj_dim // self.num_heads + q = self.linear(seq) + # Shape of q: (S, N, P) + q = q.reshape(seq_len, bsz * self.num_heads, head_dim).transpose(0, 1) + return q class ScaledDotProduct(torch.nn.Module): @@ -74,16 +96,94 @@ class ScaledDotProduct(torch.nn.Module): """ __constants__ = ['num_heads', 'add_zero_attn', 'dropout_p'] - def __init__(self, num_heads, add_zero_attn=False, dropout_p=0.0): + def __init__(self, num_heads, dropout=0.0): super(ScaledDotProduct, self).__init__() - self.dropout_p = dropout_p - self.add_zero_attn = add_zero_attn self.num_heads = num_heads + self.dropout = dropout def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): - attn_output, attn_output_weights = F.scaled_dot_product_attention( - query, key, value, - self.num_heads, self.add_zero_attn, self.dropout_p, self.training, key_padding_mask, attn_mask) + # type: (...) -> Tuple[Tensor, Tensor] + r"""Uses a scaled dot product with the projected key-value pair to update + the projected query. 
+ Args: + q (Tensor): Projected query + k (Tensor): Projected key + v (Tensor): Projected value + num_heads (int): Number of parallel attention heads. + add_zero_attn (bool): Add a new batch of zeros to the projected key and + value sequences at dimension 1. + dropout_p (float): Probability of an element will be zeroed. + training (bool): Apply dropout if ``training=True`` + key_padding_mask (Tensor, optional): Specified padding elements in the + key will be ignored by the attention. This is a binary mask. When + the value is True, the corresponding value on the attention layer + will be set to :math:`-\inf`. + attn_mask (Tensor, optional): 2D or 3D mask that prevents attention to + certain positions. This is an additive mask (i.e. the values will + be added to the attention layer). A 2D mask will be broadcasted for + all the batches while a 3D mask allows to specify a different mask + for the entries of each batch. + Shape: + - query: :math:`(N * H, L, P / H)` + - key: :math:`(N * H, S, P / H)` + - value: :math:`(N * H, S, P / H)` + - key_padding_mask: :math:`(N, S)` + - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` + - Output: :math:`(N * H, L, P / H)`, :math:`(N * H, L, S)` + where L is the target length, S is the source length, H is the number + of attention heads, N is the batch size, and P is the projection + dimension. + """ + batch_heads, tgt_len, head_dim = query.size() + assert query.size(0) == key.size(0) == value.size(0), "Dimension 0 of query, key, value must be equal." + assert batch_heads % self.num_heads == 0, "Dimension 0 of query, key, value must be divisible by num_heads" + bsz = batch_heads // self.num_heads + assert key.size() == value.size(), "Shape of key, value must match" + assert query.size(-1) == key.size(-1), "The head dimension of query must be equal to that of key" + + src_len = key.size(1) + + # Scale query + query = query * (float(head_dim) ** -0.5) + if attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, tgt_len, src_len]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. 
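+            # Convert a boolean attn_mask into an additive float mask (True -> -inf, False -> 0) before adding it to the attention weights.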
+ if attn_mask.dtype == torch.bool: + attn_mask = torch.where( + attn_mask, torch.tensor(float('-inf')), torch.tensor(0.)).to(dtype=query.dtype, device=query.device) + + src_len = key.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + # Dot product of q, k + attn_output_weights = torch.matmul(query, key.transpose(-2, -1)) + assert list(attn_output_weights.size()) == [batch_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.reshape(batch_heads, tgt_len, src_len) + + attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) + + attn_output = torch.matmul(self.dropout(attn_output_weights), value) return attn_output, attn_output_weights @@ -122,8 +222,22 @@ def __init__(self, embed_dim, num_heads, head_dim=None): self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = head_dim - self.weight = torch.nn.Parameter(torch.Tensor(embed_dim, head_dim * num_heads)) - kaiming_uniform_(self.weight, a=sqrt(5)) + self.linear = torch.nn.Linear(head_dim, embed_dim) def forward(self, attn_output): - return F.multi_head_attention_out_projection(attn_output, self.num_heads, self.weight, out_proj_bias=None) + # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor + r"""Projects an output sequence using parallel attention heads. + Args: + attn_output (Tensor): Projection to be decoded to an embedding. + Shape: + - attn_output: :math:`(N * H, S, P / H)` + where S is the sequence length, H is the number of attention heads, N is the + batch size, P is the projection dimension, and E is the embedding + dimension. + """ + batch_heads, seq_len, head_dim = attn_output.size() + # embed_dim = out_proj_weight.size(0) + assert batch_heads % self.num_heads == 0, "dimension 0 of attn_output must be divisible by num_heads" + bsz = batch_heads // self.num_heads + attn_output = attn_output.transpose(0, 1).reshape(seq_len, bsz, head_dim * self.num_heads) + return self.linear(attn_output) From 36c876a3fe912eb6729d9cd364c8243e09a30bf9 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 10:36:02 -0700 Subject: [PATCH 04/56] add models to init --- torchtext/__init__.py | 2 ++ torchtext/models/multiheadattention.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtext/__init__.py b/torchtext/__init__.py index 1f31d001c5..5ec6d82420 100644 --- a/torchtext/__init__.py +++ b/torchtext/__init__.py @@ -1,4 +1,5 @@ from . import data +from . import models from . import datasets from . import utils from . import vocab @@ -7,6 +8,7 @@ __version__ = '0.5.1' __all__ = ['data', + 'models', 'datasets', 'utils', 'vocab', diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index e28567de8c..6fba17d608 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -23,7 +23,7 @@ class MultiheadAttentionInProjection(torch.nn.Module): :math:`(\text{head\_dim} * \text{num\_heads}, \text{embed\_dim})`. 
Examples:: >>> # S = 21; N = 64; E = 10; D = 3; H = 4; - >>> MHA_in = nn.MultiheadAttentionInProjection(10, 4, 3) + >>> MHA_in = torchtext.models.MultiheadAttentionInProjection(10, 5) >>> seq = torch.randn(21, 64, 10) >>> s = MHA_in(seq) >>> print(s.shape) @@ -182,8 +182,8 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): attn_output_weights = attn_output_weights.reshape(batch_heads, tgt_len, src_len) attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) - - attn_output = torch.matmul(self.dropout(attn_output_weights), value) + attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_output_weights, value) return attn_output, attn_output_weights From bddc782adca6f228cf7e497fbe8a3e5fe7a0606d Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 10:53:09 -0700 Subject: [PATCH 05/56] minor revisions --- torchtext/models/multiheadattention.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 6fba17d608..9fad59c582 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -27,7 +27,7 @@ class MultiheadAttentionInProjection(torch.nn.Module): >>> seq = torch.randn(21, 64, 10) >>> s = MHA_in(seq) >>> print(s.shape) - torch.Size([256, 21, 3]) + torch.Size([320, 21, 2]) """ __constants__ = ['embed_dim', 'num_heads', 'head_dim'] @@ -35,11 +35,12 @@ def __init__(self, embed_dim, num_heads, head_dim=None): super(MultiheadAttentionInProjection, self).__init__() if head_dim is None: assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" - head_dim = embed_dim // num_heads - self.head_dim = head_dim + self.head_dim = embed_dim // num_heads + else: + self.head_dim = head_dim self.embed_dim = embed_dim self.num_heads = num_heads - self.linear = torch.nn.Linear(embed_dim, head_dim) + self.linear = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim) def forward(self, seq): # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor @@ -72,9 +73,7 @@ class ScaledDotProduct(torch.nn.Module): in each parallel attention head. Args: num_heads (int): Number of parallel attention heads. - add_zero_attn (bool): Whether to add a batch of zeros to the key and - value sequences. - dropout_p (float): probability of dropping an attention weight. + dropout (float): probability of dropping an attention weight. Shape: - query: :math:`(N * H, L, D)` - key: :math:`(N * H, S, D)` @@ -87,7 +86,7 @@ class ScaledDotProduct(torch.nn.Module): and D is the head dimension. 
Examples:: >>> # S = L = 21; N = 64; E = 10; D = 3; H = 4; - >>> SDP = nn.ScaledDotProduct(4, False, 0.1) + >>> SDP = torchtext.models.ScaledDotProduct(4, 0.1) >>> q = torch.randn(256, 21, 3) >>> k = v = torch.randn(256, 21, 3) >>> attn_output, attn_weights = SDP(q, k, v) @@ -218,11 +217,12 @@ def __init__(self, embed_dim, num_heads, head_dim=None): self.embed_dim = embed_dim if head_dim is None: assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" - head_dim = embed_dim // num_heads + self.head_dim = embed_dim // num_heads + else: + self.head_dim = head_dim self.embed_dim = embed_dim self.num_heads = num_heads - self.head_dim = head_dim - self.linear = torch.nn.Linear(head_dim, embed_dim) + self.linear = torch.nn.Linear(self.num_heads * self.head_dim, embed_dim) def forward(self, attn_output): # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor From e665f38431959af4efe49fc853da7e89ea622663 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 11:01:27 -0700 Subject: [PATCH 06/56] minor change --- torchtext/models/multiheadattention.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 9fad59c582..374f08b6cf 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -204,25 +204,24 @@ class MultiheadAttentionOutProjection(torch.nn.Module): :math:`(\text{embed\_dim}, \text{head\_dim} * \text{num\_heads})`. Examples:: >>> # S = 21; N = 64; E = 10; D = 3; H = 4; - >>> MHA_out = nn.MultiheadAttentionOutProjection(10, 4, 3) - >>> attn_seq = torch.randn(256, 21, 3) + >>> MHA_out = torchtext.models.MultiheadAttentionOutProjection(2, 5) + >>> attn_seq = torch.randn(320, 21, 2) >>> a = MHA_out(attn_seq) >>> print(a.shape) torch.Size([21, 64, 10]) """ __constants__ = ['embed_dim', 'num_heads', 'head_dim'] - def __init__(self, embed_dim, num_heads, head_dim=None): + def __init__(self, head_dim, num_heads, embed_dim=None): super(MultiheadAttentionOutProjection, self).__init__() - self.embed_dim = embed_dim - if head_dim is None: + self.head_dim = head_dim + self.num_heads = num_heads + if embed_dim: assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" - self.head_dim = embed_dim // num_heads + self.embed_dim = embed_dim else: - self.head_dim = head_dim - self.embed_dim = embed_dim - self.num_heads = num_heads - self.linear = torch.nn.Linear(self.num_heads * self.head_dim, embed_dim) + self.embed_dim = head_dim * num_heads + self.linear = torch.nn.Linear(self.num_heads * self.head_dim, self.embed_dim) def forward(self, attn_output): # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor From 4a44337dd7792d3f3ad5440c8b70eb73c11141a3 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 12:54:31 -0700 Subject: [PATCH 07/56] revision --- torchtext/models/multiheadattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 374f08b6cf..78ed975e9b 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -40,7 +40,7 @@ def __init__(self, embed_dim, num_heads, head_dim=None): self.head_dim = head_dim self.embed_dim = embed_dim self.num_heads = num_heads - self.linear = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim) + self.linear = torch.nn.Linear(embed_dim, 
self.num_heads * self.head_dim, bias=False) def forward(self, seq): # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor From fa36f85abe6dbd7b0265570297857b4eb8437f40 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 13:11:56 -0700 Subject: [PATCH 08/56] Add unit test --- test/data/test_models.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 test/data/test_models.py diff --git a/test/data/test_models.py b/test/data/test_models.py new file mode 100644 index 0000000000..78d24eaa06 --- /dev/null +++ b/test/data/test_models.py @@ -0,0 +1,39 @@ +import torch +from torchtext.models import MultiheadAttentionInProjection, \ + ScaledDotProduct, MultiheadAttentionOutProjection +from torch.nn.functional import multi_head_attention_forward as mha_forward +from torch.testing import assert_allclose +from ..common.torchtext_test_case import TorchtextTestCase + + +class TestUtils(TorchtextTestCase): + + def test_multiheadattention(self): + embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 + # Build torchtext MultiheadAttention models + q_in = MultiheadAttentionInProjection(embed_dim, nhead) + k_in = MultiheadAttentionInProjection(embed_dim, nhead) + v_in = MultiheadAttentionInProjection(embed_dim, nhead) + MHA_out = MultiheadAttentionOutProjection(embed_dim // nhead, nhead) + SDP = ScaledDotProduct(nhead) + + query = torch.randn(tgt_len, bsz, embed_dim) + key = value = torch.randn(src_len, bsz, embed_dim) + + # MultiheadAttention with building blocks + q = q_in(query) + k = k_in(key) + v = v_in(value) + attn_output, attn_weights = SDP(q, k, v) + mha_output = MHA_out(attn_output) + + # Use torch.nn.functional.multi_head_attention_forward + in_proj_weight = torch.cat([q_in.linear.weight, k_in.linear.weight, v_in.linear.weight]) + torch_mha_output, torch_mha_weights = mha_forward(query, key, value, + embed_dim, nhead, + in_proj_weight, None, + None, None, False, 0.0, + MHA_out.linear.weight, MHA_out.linear.bias) + assert_allclose(mha_output, torch_mha_output) + attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len).sum(dim=1) / nhead + assert_allclose(attn_weights, torch_mha_weights) From 45fed348e3fa32bd841ee3c4f190e9dd4dbf6029 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 14:19:56 -0700 Subject: [PATCH 09/56] update docs --- torchtext/models/multiheadattention.py | 179 +++++++++---------------- 1 file changed, 67 insertions(+), 112 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 78ed975e9b..88922fd81f 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -6,38 +6,24 @@ class MultiheadAttentionInProjection(torch.nn.Module): - r"""Process input using multi-head attention. - Args: - embed_dim (int): Input embedding dimension - num_heads (int): Number of parallel attention heads. - head_dim (int, optional): Dimension of embedding for each attention - head. If not provided, then it is set to ``embed_dim / num_heads``. - Shape: - - seq: :math:`(S, N, E)` - - Output: :math:`(N * H, S, D)` - where S is the sequence length, N is the batch size, H is the number of - attention heads, E is the embedding dimension, and D is the head - dimension. - Attributes: - weight: The learnable weights of the module of shape - :math:`(\text{head\_dim} * \text{num\_heads}, \text{embed\_dim})`. 
- Examples:: - >>> # S = 21; N = 64; E = 10; D = 3; H = 4; - >>> MHA_in = torchtext.models.MultiheadAttentionInProjection(10, 5) - >>> seq = torch.randn(21, 64, 10) - >>> s = MHA_in(seq) - >>> print(s.shape) - torch.Size([320, 21, 2]) - """ - __constants__ = ['embed_dim', 'num_heads', 'head_dim'] - - def __init__(self, embed_dim, num_heads, head_dim=None): + __constants__ = ['embed_dim', 'num_heads'] + + def __init__(self, embed_dim, num_heads): + r"""Process input using multi-head attention. + Args: + embed_dim (int): Input embedding dimension + num_heads (int): Number of parallel attention heads. + + Examples:: + >>> MHA_in = torchtext.models.MultiheadAttentionInProjection(10, 5) + >>> seq = torch.randn(21, 64, 10) + >>> s = MHA_in(seq) + >>> print(s.shape) + torch.Size([320, 21, 2]) + """ super(MultiheadAttentionInProjection, self).__init__() - if head_dim is None: - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" - self.head_dim = embed_dim // num_heads - else: - self.head_dim = head_dim + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" + self.head_dim = embed_dim // num_heads self.embed_dim = embed_dim self.num_heads = num_heads self.linear = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) @@ -47,17 +33,12 @@ def forward(self, seq): r"""Projects an input sequence using parallel attention heads. Args: seq (Tensor): sequence to be projected - num_heads (int): number of parallel heads used. - in_proj_weight (Tensor): weight used for projection - in_proj_bias (Tensor, optional): bias used for projection. + Shape: - seq: :math:`(S, N, E)` - - in_proj_weight: :math:`(P, E)` - - in_proj_bias: :math:`(P)` - - Output: :math:`(N * H, S, P / H)` + - Output: :math:`(N * H, S, E / H)` where S is the sequence length, H is the number of attention heads, N is the - batch size, P is the projection dimension, and E is the embedding - dimension. + batch size, and E is the embedding dimension. """ seq_len, bsz, proj_dim = seq.size() assert proj_dim % self.num_heads == 0, "projection dimension must be divisible by num_heads" @@ -69,33 +50,24 @@ def forward(self, seq): class ScaledDotProduct(torch.nn.Module): - r"""Processes a projected query and key-value pair to apply attention - in each parallel attention head. - Args: - num_heads (int): Number of parallel attention heads. - dropout (float): probability of dropping an attention weight. - Shape: - - query: :math:`(N * H, L, D)` - - key: :math:`(N * H, S, D)` - - value: :math:`(N * H, S, D)` - - key_padding_mask: :math:`(N, S)` - - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` - - Output: :math:`(N * H, L, D)`, :math:`(N * H, L, S)` - where L is the target sequence length, S is the source sequence - length, H is the number of attention heads, N is the batch size, - and D is the head dimension. - Examples:: - >>> # S = L = 21; N = 64; E = 10; D = 3; H = 4; - >>> SDP = torchtext.models.ScaledDotProduct(4, 0.1) - >>> q = torch.randn(256, 21, 3) - >>> k = v = torch.randn(256, 21, 3) - >>> attn_output, attn_weights = SDP(q, k, v) - >>> print(attn_output.shape, attn_weights.shape) - torch.Size([256, 21, 3]) torch.Size([256, 21, 21]) - """ - __constants__ = ['num_heads', 'add_zero_attn', 'dropout_p'] + __constants__ = ['num_heads', 'dropout'] def __init__(self, num_heads, dropout=0.0): + r"""Processes a projected query and key-value pair to apply + scaled dot product attention. 
+ + Args: + num_heads (int): Number of parallel attention heads. + dropout (float): probability of dropping an attention weight. + + Examples:: + >>> SDP = torchtext.models.ScaledDotProduct(4, 0.1) + >>> q = torch.randn(256, 21, 3) + >>> k = v = torch.randn(256, 21, 3) + >>> attn_output, attn_weights = SDP(q, k, v) + >>> print(attn_output.shape, attn_weights.shape) + torch.Size([256, 21, 3]) torch.Size([256, 21, 21]) + """ super(ScaledDotProduct, self).__init__() self.num_heads = num_heads self.dropout = dropout @@ -104,15 +76,11 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): # type: (...) -> Tuple[Tensor, Tensor] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. + Args: - q (Tensor): Projected query - k (Tensor): Projected key - v (Tensor): Projected value - num_heads (int): Number of parallel attention heads. - add_zero_attn (bool): Add a new batch of zeros to the projected key and - value sequences at dimension 1. - dropout_p (float): Probability of an element will be zeroed. - training (bool): Apply dropout if ``training=True`` + query (Tensor): Projected query + key (Tensor): Projected key + value (Tensor): Projected value key_padding_mask (Tensor, optional): Specified padding elements in the key will be ignored by the attention. This is a binary mask. When the value is True, the corresponding value on the attention layer @@ -123,15 +91,14 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): all the batches while a 3D mask allows to specify a different mask for the entries of each batch. Shape: - - query: :math:`(N * H, L, P / H)` - - key: :math:`(N * H, S, P / H)` - - value: :math:`(N * H, S, P / H)` + - query: :math:`(N * H, L, E / H)` + - key: :math:`(N * H, S, E / H)` + - value: :math:`(N * H, S, E / H)` - key_padding_mask: :math:`(N, S)` - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` - - Output: :math:`(N * H, L, P / H)`, :math:`(N * H, L, S)` + - Output: :math:`(N * H, L, E / H)`, :math:`(N * H, L, S)` where L is the target length, S is the source length, H is the number - of attention heads, N is the batch size, and P is the projection - dimension. + of attention heads, N is the batch size, and E is the embedding dimension. """ batch_heads, tgt_len, head_dim = query.size() assert query.size(0) == key.size(0) == value.size(0), "Dimension 0 of query, key, value must be equal." @@ -187,52 +154,40 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): class MultiheadAttentionOutProjection(torch.nn.Module): - r"""Process attention output using multi-head attention. - Args: - embed_dim (int): Input projection dimension. - num_heads (int): Number of parallel attention heads. - head_dim (int, optional): Dimension of embedding for each attention - head. If not provided, then it is set to ``embed_dim / num_heads``. - Shape: - - attn_output: :math:`(N * H, S, D)` - - Output: :math:`(S, N, E)` - where S is the sequence length, N is the batch size, H is the number of - attention heads, E is the embedding dimension, and D is the head - dimension. - Attributes: - weight: The learnable weights of the module of shape - :math:`(\text{embed\_dim}, \text{head\_dim} * \text{num\_heads})`. 
- Examples:: - >>> # S = 21; N = 64; E = 10; D = 3; H = 4; - >>> MHA_out = torchtext.models.MultiheadAttentionOutProjection(2, 5) - >>> attn_seq = torch.randn(320, 21, 2) - >>> a = MHA_out(attn_seq) - >>> print(a.shape) - torch.Size([21, 64, 10]) - """ - __constants__ = ['embed_dim', 'num_heads', 'head_dim'] - - def __init__(self, head_dim, num_heads, embed_dim=None): + __constants__ = ['head_dim', 'num_heads'] + + def __init__(self, head_dim, num_heads): + r"""Process attention output using multi-head attention. + + Args: + head_dim (int): Dimension of embedding for each attention head. + num_heads (int): Number of parallel attention heads. + + Examples:: + >>> MHA_out = torchtext.models.MultiheadAttentionOutProjection(2, 5) + >>> attn_seq = torch.randn(320, 21, 2) + >>> a = MHA_out(attn_seq) + >>> print(a.shape) + torch.Size([21, 64, 10]) + """ super(MultiheadAttentionOutProjection, self).__init__() self.head_dim = head_dim self.num_heads = num_heads - if embed_dim: - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" - self.embed_dim = embed_dim - else: - self.embed_dim = head_dim * num_heads + self.embed_dim = head_dim * num_heads self.linear = torch.nn.Linear(self.num_heads * self.head_dim, self.embed_dim) def forward(self, attn_output): # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor r"""Projects an output sequence using parallel attention heads. + Args: attn_output (Tensor): Projection to be decoded to an embedding. + Shape: - - attn_output: :math:`(N * H, S, P / H)` + - attn_output: :math:`(N * H, S, E / H)` + - Output: :math:`(S, N, E)` where S is the sequence length, H is the number of attention heads, N is the - batch size, P is the projection dimension, and E is the embedding - dimension. + batch size, and E is the embedding dimension. """ batch_heads, seq_len, head_dim = attn_output.size() # embed_dim = out_proj_weight.size(0) From a6a2d942f3ffc753a4ae05282728bc5c1e665190 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 2 Apr 2020 14:25:32 -0700 Subject: [PATCH 10/56] flake8 --- torchtext/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/__init__.py b/torchtext/__init__.py index 5ec6d82420..2b4a8cb2bf 100644 --- a/torchtext/__init__.py +++ b/torchtext/__init__.py @@ -1,5 +1,5 @@ from . import data -from . import models +from . import models from . import datasets from . import utils from . 
import vocab From b741e1f694219114bd61cecb7654fbd0e2d836a4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 8 Apr 2020 07:41:49 -0700 Subject: [PATCH 11/56] add MultiheadAttentionContainer --- torchtext/models/multiheadattention.py | 30 +++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 88922fd81f..b5150f8a13 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -5,6 +5,34 @@ Tensor = torch.Tensor +class MultiheadAttentionContainer(torch.nn.Module): + def __init__(self, embed_dim, num_heads, attention_layer=None, dropout=0.0): + super(MultiheadAttentionContainer, self).__init__() + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" + self.head_dim = embed_dim // num_heads + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_in_proj = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) + self.key_in_proj = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) + self.value_in_proj = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) + if attention_layer: + self.attention_layer = attention_layer + else: + self.attention_layer = ScaledDotProduct(num_heads, dropout=dropout) + self.out_proj = torch.nn.Linear(num_heads * self.head_dim, embed_dim) + + def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): + seq_len, bsz, proj_dim = query.size() + tgt_len = key.size(0) + q = self.query_in_proj(query).reshape(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + k = self.key_in_proj(key).reshape(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + v = self.value_in_proj(value).reshape(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask, + key_padding_mask=key_padding_mask) + attn_output = self.out_proj(attn_output.transpose(0, 1).reshape(seq_len, bsz, self.head_dim * self.num_heads)) + return attn_output, attn_output_weights + + class MultiheadAttentionInProjection(torch.nn.Module): __constants__ = ['embed_dim', 'num_heads'] @@ -72,7 +100,7 @@ def __init__(self, num_heads, dropout=0.0): self.num_heads = num_heads self.dropout = dropout - def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): + def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): # type: (...) -> Tuple[Tensor, Tensor] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. 
From ba8cd3abcfc9035f2c6620d43506c8ff2b07cd23 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 8 Apr 2020 09:15:27 -0700 Subject: [PATCH 12/56] update models init file --- torchtext/models/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchtext/models/__init__.py b/torchtext/models/__init__.py index 65aada6847..70ba44939b 100644 --- a/torchtext/models/__init__.py +++ b/torchtext/models/__init__.py @@ -1,6 +1,4 @@ -from .multiheadattention import MultiheadAttentionInProjection, \ - ScaledDotProduct, MultiheadAttentionOutProjection +from .multiheadattention import MultiheadAttentionContainer, ScaledDotProduct -__all__ = ['MultiheadAttentionInProjection', - 'ScaledDotProduct', - 'MultiheadAttentionOutProjection'] +__all__ = ['MultiheadAttentionContainer', + 'ScaledDotProduct'] From 1c35a05c695c7ad8cf4cd463f6f301ad8046065d Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 8 Apr 2020 10:40:34 -0700 Subject: [PATCH 13/56] update docs of container --- torchtext/models/multiheadattention.py | 45 ++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index b5150f8a13..7e49849414 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -7,6 +7,22 @@ class MultiheadAttentionContainer(torch.nn.Module): def __init__(self, embed_dim, num_heads, attention_layer=None, dropout=0.0): + r"""Process input using multi-head attention. + Args: + embed_dim (int): Input embedding dimension + num_heads (int): Number of parallel attention heads. + attention_layer: The attention layer. The default is None and scaled dot product + attention will be used. + dropout: the dropout value (default=0.1). + + Examples:: + >>> MHA = torchtext.models.MultiheadAttentionContainer(10, 5) + >>> query = torch.rand((21, 64, 10)) + >>> key = value = torch.rand((16, 64, 10)) + >>> attn_output, attn_weights = MHA(query, key, value) + >>> print(attn_output.shape) + >>> torch.Size([21, 64, 10]) + """ super(MultiheadAttentionContainer, self).__init__() assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" self.head_dim = embed_dim // num_heads @@ -22,6 +38,35 @@ def __init__(self, embed_dim, num_heads, attention_layer=None, dropout=0.0): self.out_proj = torch.nn.Linear(num_heads * self.head_dim, embed_dim) def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): + r"""Uses a scaled dot product with the projected key-value pair to update + the projected query. + + Args: + query, key, value (Tensor): map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask (Tensor, optional): if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask (Tensor, optional): 2D or 3D mask that prevents attention to certain positions. + This is an additive mask (i.e. the values will be added to the attention layer). A 2D mask + will be broadcasted for all the batches while a 3D mask allows to specify a different mask + for the entries of each batch. 
+ + Shape: + - Inputs: + - query: :math:`(L, N, E)` + - key: :math:`(S, N, E)` + - value: :math:`(S, N, E)` + - attn_mask: 3D mask :math:`(N*num_heads, L, S)` + - key_padding_mask: :math:`(N, S)` + + - Outputs: + - attn_output: :math:`(L, N, E)` + - attn_output_weights: :math:`(N*num_heads, L, S)` + + where where L is the target length, S is the sequence length, H is the number of attention heads, + N is the batch size, and E is the embedding dimension. + """ seq_len, bsz, proj_dim = query.size() tgt_len = key.size(0) q = self.query_in_proj(query).reshape(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) From 541556826df11ba848fec4358da952b4680b3d34 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 8 Apr 2020 10:45:21 -0700 Subject: [PATCH 14/56] update MHA test --- test/data/test_models.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/test/data/test_models.py b/test/data/test_models.py index 78d24eaa06..210414f200 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -1,6 +1,5 @@ import torch -from torchtext.models import MultiheadAttentionInProjection, \ - ScaledDotProduct, MultiheadAttentionOutProjection +from torchtext.models import MultiheadAttentionContainer from torch.nn.functional import multi_head_attention_forward as mha_forward from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase @@ -11,29 +10,20 @@ class TestUtils(TorchtextTestCase): def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention models - q_in = MultiheadAttentionInProjection(embed_dim, nhead) - k_in = MultiheadAttentionInProjection(embed_dim, nhead) - v_in = MultiheadAttentionInProjection(embed_dim, nhead) - MHA_out = MultiheadAttentionOutProjection(embed_dim // nhead, nhead) - SDP = ScaledDotProduct(nhead) + MHA = MultiheadAttentionContainer(embed_dim, nhead) - query = torch.randn(tgt_len, bsz, embed_dim) - key = value = torch.randn(src_len, bsz, embed_dim) - - # MultiheadAttention with building blocks - q = q_in(query) - k = k_in(key) - v = v_in(value) - attn_output, attn_weights = SDP(q, k, v) - mha_output = MHA_out(attn_output) + query = torch.rand((tgt_len, bsz, embed_dim)) + key = value = torch.rand((src_len, bsz, embed_dim)) + mha_output, attn_weights = MHA(query, key, value) # Use torch.nn.functional.multi_head_attention_forward - in_proj_weight = torch.cat([q_in.linear.weight, k_in.linear.weight, v_in.linear.weight]) + in_proj_weight = torch.cat([MHA.query_in_proj.weight, MHA.key_in_proj.weight, MHA.value_in_proj.weight]) torch_mha_output, torch_mha_weights = mha_forward(query, key, value, embed_dim, nhead, in_proj_weight, None, None, None, False, 0.0, - MHA_out.linear.weight, MHA_out.linear.bias) + MHA.out_proj.weight, MHA.out_proj.bias) + assert_allclose(mha_output, torch_mha_output) attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len).sum(dim=1) / nhead assert_allclose(attn_weights, torch_mha_weights) From 2055c1634d281a052eead14ed8ea0ae2309d8394 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 13 Apr 2020 13:21:27 -0700 Subject: [PATCH 15/56] remove in/out projection --- docs/source/models.rst | 9 +-- torchtext/models/multiheadattention.py | 88 -------------------------- 2 files changed, 2 insertions(+), 95 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index a2a7bcd253..0af2fa2e5e 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -7,17 
+7,12 @@ torchtext.models.multiheadattention .. automodule:: torchtext.models.multiheadattention .. currentmodule:: torchtext.models.multiheadattention -:hidden:`MultiheadAttentionInProjection` +:hidden:`MultiheadAttentionContainer` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: MultiheadAttentionInProjection +.. autofunction:: MultiheadAttentionContainer :hidden:`ScaledDotProduct` ~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: ScaledDotProduct - -:hidden:`MultiheadAttentionOutProjection` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: MultiheadAttentionOutProjection diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 7e49849414..f3f0a6b3f6 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -78,50 +78,6 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): return attn_output, attn_output_weights -class MultiheadAttentionInProjection(torch.nn.Module): - __constants__ = ['embed_dim', 'num_heads'] - - def __init__(self, embed_dim, num_heads): - r"""Process input using multi-head attention. - Args: - embed_dim (int): Input embedding dimension - num_heads (int): Number of parallel attention heads. - - Examples:: - >>> MHA_in = torchtext.models.MultiheadAttentionInProjection(10, 5) - >>> seq = torch.randn(21, 64, 10) - >>> s = MHA_in(seq) - >>> print(s.shape) - torch.Size([320, 21, 2]) - """ - super(MultiheadAttentionInProjection, self).__init__() - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" - self.head_dim = embed_dim // num_heads - self.embed_dim = embed_dim - self.num_heads = num_heads - self.linear = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) - - def forward(self, seq): - # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor - r"""Projects an input sequence using parallel attention heads. - Args: - seq (Tensor): sequence to be projected - - Shape: - - seq: :math:`(S, N, E)` - - Output: :math:`(N * H, S, E / H)` - where S is the sequence length, H is the number of attention heads, N is the - batch size, and E is the embedding dimension. - """ - seq_len, bsz, proj_dim = seq.size() - assert proj_dim % self.num_heads == 0, "projection dimension must be divisible by num_heads" - head_dim = proj_dim // self.num_heads - q = self.linear(seq) - # Shape of q: (S, N, P) - q = q.reshape(seq_len, bsz * self.num_heads, head_dim).transpose(0, 1) - return q - - class ScaledDotProduct(torch.nn.Module): __constants__ = ['num_heads', 'dropout'] @@ -224,47 +180,3 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_output_weights, value) return attn_output, attn_output_weights - - -class MultiheadAttentionOutProjection(torch.nn.Module): - __constants__ = ['head_dim', 'num_heads'] - - def __init__(self, head_dim, num_heads): - r"""Process attention output using multi-head attention. - - Args: - head_dim (int): Dimension of embedding for each attention head. - num_heads (int): Number of parallel attention heads. 
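For reference, the head-splitting reshape that the removed in-projection block performed (and that MultiheadAttentionContainer now does inline) boils down to the following sketch with assumed sizes.

import torch

S, N, E, H = 21, 64, 10, 5                 # sequence length, batch size, embedding dim, heads
head_dim = E // H
seq = torch.rand(S, N, E)                  # projected sequence, (S, N, E)
heads = seq.reshape(S, N * H, head_dim).transpose(0, 1)
print(heads.shape)                         # torch.Size([320, 21, 2]), i.e. (N * H, S, E / H)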
- - Examples:: - >>> MHA_out = torchtext.models.MultiheadAttentionOutProjection(2, 5) - >>> attn_seq = torch.randn(320, 21, 2) - >>> a = MHA_out(attn_seq) - >>> print(a.shape) - torch.Size([21, 64, 10]) - """ - super(MultiheadAttentionOutProjection, self).__init__() - self.head_dim = head_dim - self.num_heads = num_heads - self.embed_dim = head_dim * num_heads - self.linear = torch.nn.Linear(self.num_heads * self.head_dim, self.embed_dim) - - def forward(self, attn_output): - # type: (Tensor, int, Tensor, Optional[Tensor]) -> Tensor - r"""Projects an output sequence using parallel attention heads. - - Args: - attn_output (Tensor): Projection to be decoded to an embedding. - - Shape: - - attn_output: :math:`(N * H, S, E / H)` - - Output: :math:`(S, N, E)` - where S is the sequence length, H is the number of attention heads, N is the - batch size, and E is the embedding dimension. - """ - batch_heads, seq_len, head_dim = attn_output.size() - # embed_dim = out_proj_weight.size(0) - assert batch_heads % self.num_heads == 0, "dimension 0 of attn_output must be divisible by num_heads" - bsz = batch_heads // self.num_heads - attn_output = attn_output.transpose(0, 1).reshape(seq_len, bsz, head_dim * self.num_heads) - return self.linear(attn_output) From 9adc723b70a23958b12f663e952fb15e781f5b3b Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 15 Apr 2020 12:11:33 -0700 Subject: [PATCH 16/56] Switch MultiheadAttentionContainer to accept ScaledDotProduct, MultiheadInProject, MultiheadOutProject --- test/data/test_models.py | 16 +++-- torchtext/models/__init__.py | 7 +- torchtext/models/multiheadattention.py | 97 ++++++++++++++++---------- 3 files changed, 78 insertions(+), 42 deletions(-) diff --git a/test/data/test_models.py b/test/data/test_models.py index 210414f200..8091309add 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -1,5 +1,6 @@ import torch -from torchtext.models import MultiheadAttentionContainer +from torchtext.models import MultiheadAttentionContainer, \ + ScaledDotProduct, MultiheadInProject, MultiheadOutProject from torch.nn.functional import multi_head_attention_forward as mha_forward from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase @@ -10,19 +11,26 @@ class TestUtils(TorchtextTestCase): def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention models - MHA = MultiheadAttentionContainer(embed_dim, nhead) + MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead), + MultiheadInProject(embed_dim, nhead), + MultiheadInProject(embed_dim, nhead)), + ScaledDotProduct(nhead), + MultiheadOutProject(embed_dim // nhead, nhead)) query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) mha_output, attn_weights = MHA(query, key, value) # Use torch.nn.functional.multi_head_attention_forward - in_proj_weight = torch.cat([MHA.query_in_proj.weight, MHA.key_in_proj.weight, MHA.value_in_proj.weight]) + in_proj_weight = torch.cat([MHA.query_in_proj.proj_layer.weight, + MHA.key_in_proj.proj_layer.weight, + MHA.value_in_proj.proj_layer.weight]) torch_mha_output, torch_mha_weights = mha_forward(query, key, value, embed_dim, nhead, in_proj_weight, None, None, None, False, 0.0, - MHA.out_proj.weight, MHA.out_proj.bias) + MHA.out_proj.proj_layer.weight, + MHA.out_proj.proj_layer.bias) assert_allclose(mha_output, torch_mha_output) attn_weights = attn_weights.view(bsz, nhead, tgt_len, 
src_len).sum(dim=1) / nhead diff --git a/torchtext/models/__init__.py b/torchtext/models/__init__.py index 70ba44939b..f821fb9a8f 100644 --- a/torchtext/models/__init__.py +++ b/torchtext/models/__init__.py @@ -1,4 +1,7 @@ -from .multiheadattention import MultiheadAttentionContainer, ScaledDotProduct +from .multiheadattention import MultiheadInProject, MultiheadOutProject, \ + MultiheadAttentionContainer, ScaledDotProduct -__all__ = ['MultiheadAttentionContainer', +__all__ = ['MultiheadInProject', + 'MultiheadOutProject', + 'MultiheadAttentionContainer', 'ScaledDotProduct'] diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index f3f0a6b3f6..5b3933c268 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -1,41 +1,36 @@ import torch -from torch._jit_internal import Optional, Tuple +from torch._jit_internal import Tuple Tensor = torch.Tensor class MultiheadAttentionContainer(torch.nn.Module): - def __init__(self, embed_dim, num_heads, attention_layer=None, dropout=0.0): + def __init__(self, in_proj, attention_layer, out_proj): r"""Process input using multi-head attention. Args: - embed_dim (int): Input embedding dimension - num_heads (int): Number of parallel attention heads. attention_layer: The attention layer. The default is None and scaled dot product attention will be used. - dropout: the dropout value (default=0.1). Examples:: - >>> MHA = torchtext.models.MultiheadAttentionContainer(10, 5) - >>> query = torch.rand((21, 64, 10)) - >>> key = value = torch.rand((16, 64, 10)) + >>> embed_dim, num_heads, bsz = 10, 5, 64 + >>> MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, num_heads), + MultiheadInProject(embed_dim, num_heads), + MultiheadInProject(embed_dim, num_heads)), + ScaledDotProduct(num_heads), + MultiheadOutProject(embed_dim // num_heads, num_heads)) + >>> query = torch.rand((21, bsz, embed_dim)) + >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) >>> print(attn_output.shape) >>> torch.Size([21, 64, 10]) """ super(MultiheadAttentionContainer, self).__init__() - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads when head_dim=None" - self.head_dim = embed_dim // num_heads - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_in_proj = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) - self.key_in_proj = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) - self.value_in_proj = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) - if attention_layer: - self.attention_layer = attention_layer - else: - self.attention_layer = ScaledDotProduct(num_heads, dropout=dropout) - self.out_proj = torch.nn.Linear(num_heads * self.head_dim, embed_dim) + self.query_in_proj = in_proj[0] + self.key_in_proj = in_proj[1] + self.value_in_proj = in_proj[2] + self.attention_layer = attention_layer + self.out_proj = out_proj def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): r"""Uses a scaled dot product with the projected key-value pair to update @@ -67,17 +62,48 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): where where L is the target length, S is the sequence length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. 
""" - seq_len, bsz, proj_dim = query.size() - tgt_len = key.size(0) - q = self.query_in_proj(query).reshape(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) - k = self.key_in_proj(key).reshape(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) - v = self.value_in_proj(value).reshape(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + q = self.query_in_proj(query) + k = self.key_in_proj(key) + v = self.value_in_proj(value) attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask) - attn_output = self.out_proj(attn_output.transpose(0, 1).reshape(seq_len, bsz, self.head_dim * self.num_heads)) + attn_output = self.out_proj(attn_output) return attn_output, attn_output_weights +class MultiheadInProject(torch.nn.Module): + def __init__(self, embed_dim, num_heads): + super(MultiheadInProject, self).__init__() + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = embed_dim // num_heads + self.embed_dim = embed_dim + self.num_heads = num_heads + self.proj_layer = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) + + def forward(self, seq): + seq_len, bsz, proj_dim = seq.size() + seq = self.proj_layer(seq) + seq = seq.reshape(seq_len, bsz * self.num_heads, self.head_dim) + return seq + + +class MultiheadOutProject(torch.nn.Module): + def __init__(self, head_dim, num_heads): + super(MultiheadOutProject, self).__init__() + self.head_dim = head_dim + self.num_heads = num_heads + self.proj_layer = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False) + + def forward(self, seq): + seq_len, bsz_num_head, head_dim = seq.size() + assert bsz_num_head % self.num_heads == 0, \ + "Dimension -2 of MultiheadOutProject input must be divisible by num_heads" + bsz = bsz_num_head // self.num_heads + seq = seq.reshape(seq_len, bsz, self.num_heads * self.head_dim) + seq = self.proj_layer(seq) + return seq + + class ScaledDotProduct(torch.nn.Module): __constants__ = ['num_heads', 'dropout'] @@ -120,25 +146,26 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): all the batches while a 3D mask allows to specify a different mask for the entries of each batch. Shape: - - query: :math:`(N * H, L, E / H)` - - key: :math:`(N * H, S, E / H)` - - value: :math:`(N * H, S, E / H)` + - query: :math:`(L, N * H, E / H)` + - key: :math:`(S, N * H, E / H)` + - value: :math:`(S, N * H, E / H)` - key_padding_mask: :math:`(N, S)` - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` - - Output: :math:`(N * H, L, E / H)`, :math:`(N * H, L, S)` + - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)` where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. """ - batch_heads, tgt_len, head_dim = query.size() - assert query.size(0) == key.size(0) == value.size(0), "Dimension 0 of query, key, value must be equal." + tgt_len, batch_heads, head_dim = query.size() + assert query.size(1) == key.size(1) == value.size(1), "Dimension 0 of query, key, value must be equal." 
assert batch_heads % self.num_heads == 0, "Dimension 0 of query, key, value must be divisible by num_heads" bsz = batch_heads // self.num_heads assert key.size() == value.size(), "Shape of key, value must match" assert query.size(-1) == key.size(-1), "The head dimension of query must be equal to that of key" - src_len = key.size(1) + src_len = key.size(0) # Scale query + query, key, value = query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1) query = query * (float(head_dim) ** -0.5) if attn_mask is not None: if attn_mask.dim() == 2: @@ -155,8 +182,6 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): attn_mask = torch.where( attn_mask, torch.tensor(float('-inf')), torch.tensor(0.)).to(dtype=query.dtype, device=query.device) - src_len = key.size(1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len @@ -179,4 +204,4 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_output_weights, value) - return attn_output, attn_output_weights + return attn_output.transpose(0, 1), attn_output_weights From f94506a1e158ba9f5ebcb5b17218280dcbcf3d34 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 15 Apr 2020 13:33:52 -0700 Subject: [PATCH 17/56] add JIT support for MHA blocks --- test/data/test_jit.py | 26 ++++++++++++++++++++++++++ test/data/test_models.py | 2 +- torchtext/models/multiheadattention.py | 7 +++++-- 3 files changed, 32 insertions(+), 3 deletions(-) create mode 100644 test/data/test_jit.py diff --git a/test/data/test_jit.py b/test/data/test_jit.py new file mode 100644 index 0000000000..c74b7d6cb1 --- /dev/null +++ b/test/data/test_jit.py @@ -0,0 +1,26 @@ +import torch +from torchtext.models import MultiheadAttentionContainer, \ + ScaledDotProduct, MultiheadInProject, MultiheadOutProject +from torch.testing import assert_allclose +from ..common.torchtext_test_case import TorchtextTestCase + + +class TestJit(TorchtextTestCase): + + def test_torchscript_multiheadattention(self): + embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 + # Build torchtext MultiheadAttention models + MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead), + MultiheadInProject(embed_dim, nhead), + MultiheadInProject(embed_dim, nhead)), + ScaledDotProduct(nhead), + MultiheadOutProject(embed_dim // nhead, nhead)) + + query = torch.rand((tgt_len, bsz, embed_dim)) + key = value = torch.rand((src_len, bsz, embed_dim)) + mha_output, attn_weights = MHA(query, key, value) + + ts_MHA = torch.jit.script(MHA) + ts_mha_output, ts_attn_weights = ts_MHA(query, key, value) + assert_allclose(mha_output, ts_mha_output) + assert_allclose(attn_weights, ts_mha_output) diff --git a/test/data/test_models.py b/test/data/test_models.py index 8091309add..493e86c251 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -6,7 +6,7 @@ from ..common.torchtext_test_case import TorchtextTestCase -class TestUtils(TorchtextTestCase): +class TestModels(TorchtextTestCase): def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 5b3933c268..ac5693efb8 100644 --- a/torchtext/models/multiheadattention.py +++ 
b/torchtext/models/multiheadattention.py @@ -1,5 +1,5 @@ import torch -from torch._jit_internal import Tuple +from torch._jit_internal import Tuple, Optional Tensor = torch.Tensor @@ -33,6 +33,7 @@ def __init__(self, in_proj, attention_layer, out_proj): self.out_proj = out_proj def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): + # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. @@ -81,6 +82,7 @@ def __init__(self, embed_dim, num_heads): self.proj_layer = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) def forward(self, seq): + # type: (Tensor) -> Tensor seq_len, bsz, proj_dim = seq.size() seq = self.proj_layer(seq) seq = seq.reshape(seq_len, bsz * self.num_heads, self.head_dim) @@ -95,6 +97,7 @@ def __init__(self, head_dim, num_heads): self.proj_layer = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False) def forward(self, seq): + # type: (Tensor) -> Tensor seq_len, bsz_num_head, head_dim = seq.size() assert bsz_num_head % self.num_heads == 0, \ "Dimension -2 of MultiheadOutProject input must be divisible by num_heads" @@ -128,7 +131,7 @@ def __init__(self, num_heads, dropout=0.0): self.dropout = dropout def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): - # type: (...) -> Tuple[Tensor, Tensor] + # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. From f3ed887b11de6d22a6ab0b194339ad69efde7334 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 15 Apr 2020 14:16:14 -0700 Subject: [PATCH 18/56] standardlize attn_mask --- torchtext/models/multiheadattention.py | 69 ++++++-------------------- 1 file changed, 15 insertions(+), 54 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index ac5693efb8..30ab553e01 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -32,29 +32,22 @@ def __init__(self, in_proj, attention_layer, out_proj): self.attention_layer = attention_layer self.out_proj = out_proj - def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query, key, value, attn_mask=None): + # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. Args: query, key, value (Tensor): map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. - key_padding_mask (Tensor, optional): if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - attn_mask (Tensor, optional): 2D or 3D mask that prevents attention to certain positions. - This is an additive mask (i.e. the values will be added to the attention layer). A 2D mask - will be broadcasted for all the batches while a 3D mask allows to specify a different mask - for the entries of each batch. + attn_mask (Bool Tensor, optional): 3D mask that prevents attention to certain positions. 
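Condensed, the TorchScript round trip that the new JIT test exercises looks like the sketch below (sizes assumed; at this point ScaledDotProduct still takes num_heads).

import torch
from torchtext.models import MultiheadAttentionContainer, ScaledDotProduct, \
    MultiheadInProject, MultiheadOutProject

embed_dim, nhead = 10, 5
MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead),
                                   MultiheadInProject(embed_dim, nhead),
                                   MultiheadInProject(embed_dim, nhead)),
                                  ScaledDotProduct(nhead),
                                  MultiheadOutProject(embed_dim // nhead, nhead))
ts_MHA = torch.jit.script(MHA)              # compile the whole container
query = torch.rand(6, 64, embed_dim)
key = value = torch.rand(10, 64, embed_dim)
attn_output, attn_weights = ts_MHA(query, key, value)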
Shape: - Inputs: - query: :math:`(L, N, E)` - key: :math:`(S, N, E)` - value: :math:`(S, N, E)` - - attn_mask: 3D mask :math:`(N*num_heads, L, S)` - - key_padding_mask: :math:`(N, S)` + - attn_mask: :math:`(N * H, L, S)` - Outputs: - attn_output: :math:`(L, N, E)` @@ -66,8 +59,7 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): q = self.query_in_proj(query) k = self.key_in_proj(key) v = self.value_in_proj(value) - attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask, - key_padding_mask=key_padding_mask) + attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask) attn_output = self.out_proj(attn_output) return attn_output, attn_output_weights @@ -130,8 +122,8 @@ def __init__(self, num_heads, dropout=0.0): self.num_heads = num_heads self.dropout = dropout - def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query, key, value, attn_mask=None): + # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. @@ -139,21 +131,13 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): query (Tensor): Projected query key (Tensor): Projected key value (Tensor): Projected value - key_padding_mask (Tensor, optional): Specified padding elements in the - key will be ignored by the attention. This is a binary mask. When - the value is True, the corresponding value on the attention layer - will be set to :math:`-\inf`. - attn_mask (Tensor, optional): 2D or 3D mask that prevents attention to - certain positions. This is an additive mask (i.e. the values will - be added to the attention layer). A 2D mask will be broadcasted for - all the batches while a 3D mask allows to specify a different mask - for the entries of each batch. + attn_mask (Bool Tensor, optional): 3D mask that prevents attention to certain positions. + Shape: - query: :math:`(L, N * H, E / H)` - key: :math:`(S, N * H, E / H)` - value: :math:`(S, N * H, E / H)` - - key_padding_mask: :math:`(N, S)` - - attn_mask: :math:`(L, S)` or :math:`(N * H, L, S)` + - attn_mask: :math:`(N * H, L, S)` - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)` where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. @@ -161,48 +145,25 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): tgt_len, batch_heads, head_dim = query.size() assert query.size(1) == key.size(1) == value.size(1), "Dimension 0 of query, key, value must be equal." 
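Under the convention standardised in this patch, callers build the mask themselves as a 3D bool tensor of shape (bsz * num_heads, tgt_len, src_len); a sketch mirroring the tests, with assumed sizes:

import torch

bsz, nhead, tgt_len, src_len = 64, 5, 6, 10
mask_2d = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
# True marks a position that must not be attended; expand adds the (bsz * nhead) leading dim.
attn_mask = mask_2d.expand(bsz * nhead, tgt_len, src_len)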
assert batch_heads % self.num_heads == 0, "Dimension 0 of query, key, value must be divisible by num_heads" - bsz = batch_heads // self.num_heads assert key.size() == value.size(), "Shape of key, value must match" assert query.size(-1) == key.size(-1), "The head dimension of query must be equal to that of key" - src_len = key.size(0) # Scale query query, key, value = query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1) query = query * (float(head_dim) ** -0.5) if attn_mask is not None: - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, tgt_len, src_len]: - raise RuntimeError('The size of the 2D attn_mask is not correct.') - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]: - raise RuntimeError('The size of the 3D attn_mask is not correct.') - else: - raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) - # attn_mask's dim is 3 now. - if attn_mask.dtype == torch.bool: - attn_mask = torch.where( - attn_mask, torch.tensor(float('-inf')), torch.tensor(0.)).to(dtype=query.dtype, device=query.device) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz - assert key_padding_mask.size(1) == src_len + if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + if attn_mask.dtype != torch.bool: + raise RuntimeError('Only bool tensor is supported for attn_mask') # Dot product of q, k attn_output_weights = torch.matmul(query, key.transpose(-2, -1)) assert list(attn_output_weights.size()) == [batch_heads, tgt_len, src_len] if attn_mask is not None: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float('-inf'), - ) - attn_output_weights = attn_output_weights.reshape(batch_heads, tgt_len, src_len) + attn_output_weights.masked_fill_(attn_mask, float('-inf'),) attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) From 4a388022b3db0d9fc0b20eb5800a961990db1de7 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 15 Apr 2020 14:27:55 -0700 Subject: [PATCH 19/56] update docs --- torchtext/models/multiheadattention.py | 61 ++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 30ab553e01..4fadbda8e4 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -6,11 +6,13 @@ class MultiheadAttentionContainer(torch.nn.Module): - def __init__(self, in_proj, attention_layer, out_proj): - r"""Process input using multi-head attention. + def __init__(self, in_proj_tuple, attention_layer, out_proj): + r""" A multi-head attention container + Args: - attention_layer: The attention layer. The default is None and scaled dot product - attention will be used. + in_proj_tuple: A tuple of multi-head in-projection layers + attention_layer: The attention layer. 
+ out_proj: The multi-head out-projection layer Examples:: >>> embed_dim, num_heads, bsz = 10, 5, 64 @@ -26,16 +28,15 @@ def __init__(self, in_proj, attention_layer, out_proj): >>> torch.Size([21, 64, 10]) """ super(MultiheadAttentionContainer, self).__init__() - self.query_in_proj = in_proj[0] - self.key_in_proj = in_proj[1] - self.value_in_proj = in_proj[2] + self.query_in_proj = in_proj_tuple[0] + self.key_in_proj = in_proj_tuple[1] + self.value_in_proj = in_proj_tuple[2] self.attention_layer = attention_layer self.out_proj = out_proj def forward(self, query, key, value, attn_mask=None): # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] - r"""Uses a scaled dot product with the projected key-value pair to update - the projected query. + r""" Args: query, key, value (Tensor): map a query and a set of key-value pairs to an output. @@ -66,6 +67,13 @@ def forward(self, query, key, value, attn_mask=None): class MultiheadInProject(torch.nn.Module): def __init__(self, embed_dim, num_heads): + r"""Process input using multi-head attention. + + Args: + embed_dim (int): Input embedding dimension + num_heads (int): Number of parallel attention heads. + """ + super(MultiheadInProject, self).__init__() assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.head_dim = embed_dim // num_heads @@ -75,6 +83,19 @@ def __init__(self, embed_dim, num_heads): def forward(self, seq): # type: (Tensor) -> Tensor + r"""Projects an input sequence using parallel attention heads. + + Args: + seq (Tensor): sequence to be projected + + Shape: + - seq: :math:`(S, N, E)` + + - Output: :math:`(S, N * H, E / H)` + + where S is the sequence length, H is the number of attention heads, N is the + batch size, and E is the embedding dimension. + """ seq_len, bsz, proj_dim = seq.size() seq = self.proj_layer(seq) seq = seq.reshape(seq_len, bsz * self.num_heads, self.head_dim) @@ -83,6 +104,13 @@ def forward(self, seq): class MultiheadOutProject(torch.nn.Module): def __init__(self, head_dim, num_heads): + r"""Process attention output using multi-head attention. + + Args: + head_dim (int): Dimension of embedding for each attention head. + num_heads (int): Number of parallel attention heads. + + """ super(MultiheadOutProject, self).__init__() self.head_dim = head_dim self.num_heads = num_heads @@ -90,6 +118,19 @@ def __init__(self, head_dim, num_heads): def forward(self, seq): # type: (Tensor) -> Tensor + r"""Projects an output sequence using parallel attention heads. + + Args: + seq (Tensor): Projection to be decoded to an embedding. + + Shape: + - seq: :math:`(S, N * H, E / H)` + + - Output: :math:`(S, N, E)` + + where S is the sequence length, H is the number of attention heads, N is the + batch size, and E is the embedding dimension. + """ seq_len, bsz_num_head, head_dim = seq.size() assert bsz_num_head % self.num_heads == 0, \ "Dimension -2 of MultiheadOutProject input must be divisible by num_heads" @@ -138,7 +179,9 @@ def forward(self, query, key, value, attn_mask=None): - key: :math:`(S, N * H, E / H)` - value: :math:`(S, N * H, E / H)` - attn_mask: :math:`(N * H, L, S)` + - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)` + where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. 
""" From a5bfdee41c26157dfcff503e142a2a4a6ad2f1f4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 15 Apr 2020 14:48:23 -0700 Subject: [PATCH 20/56] fix a bug in torchscript test --- test/data/test_jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/data/test_jit.py b/test/data/test_jit.py index c74b7d6cb1..0a77fb4f67 100644 --- a/test/data/test_jit.py +++ b/test/data/test_jit.py @@ -23,4 +23,4 @@ def test_torchscript_multiheadattention(self): ts_MHA = torch.jit.script(MHA) ts_mha_output, ts_attn_weights = ts_MHA(query, key, value) assert_allclose(mha_output, ts_mha_output) - assert_allclose(attn_weights, ts_mha_output) + assert_allclose(attn_weights, ts_attn_weights) From e81c4b3d423c69cba4606c4c7a9f84b7c5773e27 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 16 Apr 2020 09:39:29 -0700 Subject: [PATCH 21/56] add attn_mask in test_multiheadattention and test_torchscript_multiheadattention --- test/data/test_jit.py | 7 ++++--- test/data/test_models.py | 8 ++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/test/data/test_jit.py b/test/data/test_jit.py index 0a77fb4f67..0a1ef83708 100644 --- a/test/data/test_jit.py +++ b/test/data/test_jit.py @@ -15,12 +15,13 @@ def test_torchscript_multiheadattention(self): MultiheadInProject(embed_dim, nhead)), ScaledDotProduct(nhead), MultiheadOutProject(embed_dim // nhead, nhead)) - query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) - mha_output, attn_weights = MHA(query, key, value) + attn_mask = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) + attn_mask = torch.stack([attn_mask] * (bsz * nhead)) + mha_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask) ts_MHA = torch.jit.script(MHA) - ts_mha_output, ts_attn_weights = ts_MHA(query, key, value) + ts_mha_output, ts_attn_weights = ts_MHA(query, key, value, attn_mask=attn_mask) assert_allclose(mha_output, ts_mha_output) assert_allclose(attn_weights, ts_attn_weights) diff --git a/test/data/test_models.py b/test/data/test_models.py index 493e86c251..2c37d609c6 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -19,9 +19,12 @@ def test_multiheadattention(self): query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) - mha_output, attn_weights = MHA(query, key, value) + attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) + mha_output, attn_weights = MHA(query, key, value, + attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead))) # Use torch.nn.functional.multi_head_attention_forward + torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf')) in_proj_weight = torch.cat([MHA.query_in_proj.proj_layer.weight, MHA.key_in_proj.proj_layer.weight, MHA.value_in_proj.proj_layer.weight]) @@ -30,7 +33,8 @@ def test_multiheadattention(self): in_proj_weight, None, None, None, False, 0.0, MHA.out_proj.proj_layer.weight, - MHA.out_proj.proj_layer.bias) + MHA.out_proj.proj_layer.bias, + attn_mask=torch_attn_mask) assert_allclose(mha_output, torch_mha_output) attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len).sum(dim=1) / nhead From 66b71ac417bd353a47641f36e780edaa47c7e743 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 16 Apr 2020 16:06:13 -0700 Subject: [PATCH 22/56] add partial broadcast support for ScaledDotProduct. 
Only allow the batch dim of either query or key/value to be 1 --- test/data/test_models.py | 26 ++++++++++++++++++++++++++ torchtext/models/multiheadattention.py | 20 +++++++++++--------- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/test/data/test_models.py b/test/data/test_models.py index 2c37d609c6..af39f00017 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -39,3 +39,29 @@ def test_multiheadattention(self): assert_allclose(mha_output, torch_mha_output) attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len).sum(dim=1) / nhead assert_allclose(attn_weights, torch_mha_weights) + + def test_broadcast_scaled_dot_product(self): + embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 + SDP = ScaledDotProduct(nhead) + query = torch.rand((tgt_len, 1, embed_dim)) + key = value = torch.rand((src_len, 1, embed_dim)) + attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) + + sdp_attn_output_full, sdp_attn_weights_full = SDP(query.expand(tgt_len, bsz * nhead, embed_dim), + key.expand(src_len, bsz * nhead, embed_dim), + value.expand(src_len, bsz * nhead, embed_dim), + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) + + # query has a batch size of 1 while key/value have a batch size of bsz * nhead + sdp_attn_output, sdp_attn_weights = SDP(query, key.expand(src_len, bsz * nhead, embed_dim), + value.expand(src_len, bsz * nhead, embed_dim), + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) + assert_allclose(sdp_attn_output, sdp_attn_output_full) + assert_allclose(sdp_attn_weights, sdp_attn_weights_full) + + # key/value have a batch size of 1 while query has a batch size of bsz * nhead + sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim), + key, value, + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) + assert_allclose(sdp_attn_output, sdp_attn_output_full) + assert_allclose(sdp_attn_weights, sdp_attn_weights_full) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 4fadbda8e4..d5c382e257 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -185,19 +185,21 @@ def forward(self, query, key, value, attn_mask=None): where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. """ - tgt_len, batch_heads, head_dim = query.size() - assert query.size(1) == key.size(1) == value.size(1), "Dimension 0 of query, key, value must be equal." - assert batch_heads % self.num_heads == 0, "Dimension 0 of query, key, value must be divisible by num_heads" + tgt_len, head_dim = query.size(-3), query.size(-1) + assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal." 
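The broadcasting added here lets either the query or the key/value pair carry a singleton batch dimension; a sketch mirroring the new test, with assumed sizes:

import torch
from torchtext.models import ScaledDotProduct

nhead, bsz, tgt_len, src_len, head_dim = 5, 64, 6, 10, 2
SDP = ScaledDotProduct(nhead)                              # num_heads argument is dropped a few patches later
query = torch.rand(tgt_len, 1, head_dim)                   # batch dim of 1 ...
key = value = torch.rand(src_len, bsz * nhead, head_dim)   # ... broadcast against bsz * nhead
attn_output, attn_weights = SDP(query, key, value)
print(attn_output.shape)                                   # torch.Size([6, 320, 2])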
assert key.size() == value.size(), "Shape of key, value must match" - assert query.size(-1) == key.size(-1), "The head dimension of query must be equal to that of key" - src_len = key.size(0) + src_len = key.size(-3) + batch_heads = max(query.size(-2), key.size(-2)) # Scale query - query, key, value = query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1) + query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3) query = query * (float(head_dim) ** -0.5) if attn_mask is not None: - if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]: - raise RuntimeError('The size of the 3D attn_mask is not correct.') + if attn_mask.dim() != 3: + raise RuntimeError('attn_mask must be a 3D tensor.') + if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \ + (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads): + raise RuntimeError('The size of the attn_mask is not correct.') if attn_mask.dtype != torch.bool: raise RuntimeError('Only bool tensor is supported for attn_mask') @@ -211,4 +213,4 @@ def forward(self, query, key, value, attn_mask=None): attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_output_weights, value) - return attn_output.transpose(0, 1), attn_output_weights + return attn_output.transpose(-2, -3), attn_output_weights From da1bc7a8963af48223616b0c444a11415e35d2ed Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 17 Apr 2020 09:31:38 -0700 Subject: [PATCH 23/56] add more broadcast tests for scaled dot product model --- test/data/test_models.py | 31 ++++++++++++++++++++++++++ torchtext/models/multiheadattention.py | 4 +++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/test/data/test_models.py b/test/data/test_models.py index af39f00017..fd2e4b948e 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -65,3 +65,34 @@ def test_broadcast_scaled_dot_product(self): attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) assert_allclose(sdp_attn_output, sdp_attn_output_full) assert_allclose(sdp_attn_weights, sdp_attn_weights_full) + + # key/value have a size of (3, 3, src_len, bsz * nhead, embed_dim) + # while query has a size of (tgt_len, 1, embed_dim) + sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, 1, embed_dim), + key.expand(3, 3, src_len, bsz * nhead, embed_dim), + value.expand(3, 3, src_len, bsz * nhead, embed_dim), + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) + assert list(sdp_attn_output.size()) == [3, 3, tgt_len, bsz * nhead, embed_dim] + assert list(sdp_attn_weights.size()) == [3, 3, bsz * nhead, tgt_len, embed_dim] + assert_allclose(sdp_attn_output[2][2], sdp_attn_output_full) + assert_allclose(sdp_attn_weights[2][2], sdp_attn_weights_full) + + # key/value have a size of (src_len, 1, embed_dim) + # while query has a size of (1, 2, 3, tgt_len, bsz * nhead, embed_dim) + sdp_attn_output, sdp_attn_weights = SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), + key.expand(src_len, 1, embed_dim), + value.expand(src_len, 1, embed_dim), + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) + assert list(sdp_attn_output.size()) == [1, 2, 3, tgt_len, bsz * nhead, embed_dim] + assert list(sdp_attn_weights.size()) == [1, 2, 3, bsz * nhead, tgt_len, embed_dim] + assert_allclose(sdp_attn_output[0][1][2], sdp_attn_output_full) + 
assert_allclose(sdp_attn_weights[0][1][2], sdp_attn_weights_full) + + # attn_mask in a size of (1, tgt_len, src_len) + # 2D tensor is not supported for attn_mask + sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim), + key.expand(src_len, bsz * nhead, embed_dim), + value.expand(src_len, bsz * nhead, embed_dim), + attn_mask=attn_mask_2D.expand(1, tgt_len, src_len)) + assert_allclose(sdp_attn_output, sdp_attn_output_full) + assert_allclose(sdp_attn_weights, sdp_attn_weights_full) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index d5c382e257..03b4e10600 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -205,7 +205,9 @@ def forward(self, query, key, value, attn_mask=None): # Dot product of q, k attn_output_weights = torch.matmul(query, key.transpose(-2, -1)) - assert list(attn_output_weights.size()) == [batch_heads, tgt_len, src_len] + assert attn_output_weights.size(-3) == batch_heads + assert attn_output_weights.size(-2) == tgt_len + assert attn_output_weights.size(-1) == src_len if attn_mask is not None: attn_output_weights.masked_fill_(attn_mask, float('-inf'),) From accceebda0f72d691398f6e2ff4b9df19e704963 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 09:38:07 -0700 Subject: [PATCH 24/56] add support for incremental decoding --- test/data/test_models.py | 8 ++++++-- torchtext/models/multiheadattention.py | 25 ++++++++++++++++++++----- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/test/data/test_models.py b/test/data/test_models.py index fd2e4b948e..13351aacb6 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -20,8 +20,11 @@ def test_multiheadattention(self): query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) + bias_k = bias_v = torch.rand((1, 1, embed_dim)) mha_output, attn_weights = MHA(query, key, value, - attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead))) + attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)), + bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1), + bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)) # Use torch.nn.functional.multi_head_attention_forward torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf')) @@ -31,7 +34,8 @@ def test_multiheadattention(self): torch_mha_output, torch_mha_weights = mha_forward(query, key, value, embed_dim, nhead, in_proj_weight, None, - None, None, False, 0.0, + bias_k, bias_v, + False, 0.0, MHA.out_proj.proj_layer.weight, MHA.out_proj.proj_layer.bias, attn_mask=torch_attn_mask) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 03b4e10600..2da6a2a078 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -34,14 +34,16 @@ def __init__(self, in_proj_tuple, attention_layer, out_proj): self.attention_layer = attention_layer self.out_proj = out_proj - def forward(self, query, key, value, attn_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] r""" Args: query, key, value (Tensor): map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. 
attn_mask (Bool Tensor, optional): 3D mask that prevents attention to certain positions. + bias_k and bias_v:bias (Tensor, optional): one more key and value sequence to be added at + sequence dim (dim=-3). Those are used for incremental decoding. Shape: - Inputs: @@ -49,6 +51,7 @@ def forward(self, query, key, value, attn_mask=None): - key: :math:`(S, N, E)` - value: :math:`(S, N, E)` - attn_mask: :math:`(N * H, L, S)` + - bias_k and bias_v:bias: :math:`(1, N * H, E / H)` - Outputs: - attn_output: :math:`(L, N, E)` @@ -60,7 +63,8 @@ def forward(self, query, key, value, attn_mask=None): q = self.query_in_proj(query) k = self.key_in_proj(key) v = self.value_in_proj(value) - attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask) + attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask, + bias_k=bias_k, bias_v=bias_v) attn_output = self.out_proj(attn_output) return attn_output, attn_output_weights @@ -163,8 +167,8 @@ def __init__(self, num_heads, dropout=0.0): self.num_heads = num_heads self.dropout = dropout - def forward(self, query, key, value, attn_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. @@ -173,18 +177,29 @@ def forward(self, query, key, value, attn_mask=None): key (Tensor): Projected key value (Tensor): Projected value attn_mask (Bool Tensor, optional): 3D mask that prevents attention to certain positions. + bias_k and bias_v:bias: the additional key and value sequence to be added at sequence dim (dim=-3). + Those are used for incremental decoding. Shape: - query: :math:`(L, N * H, E / H)` - key: :math:`(S, N * H, E / H)` - value: :math:`(S, N * H, E / H)` - attn_mask: :math:`(N * H, L, S)` + - bias_k and bias_v:bias: :math:`(1, N * H, E / H)` - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)` where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. """ + if bias_k is not None and bias_v is not None: + assert key.size(-1) == bias_k.size(-1) and key.size(-2) == bias_k.size(-2) and bias_k.size(-3) == 1, \ + "Shape of bias_k is not supported" + assert value.size(-1) == bias_v.size(-1) and value.size(-2) == bias_v.size(-2) and bias_v.size(-3) == 1, \ + "Shape of bias_v is not supported" + key = torch.cat([key, bias_k]) + value = torch.cat([value, bias_v]) + tgt_len, head_dim = query.size(-3), query.size(-1) assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal." 
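In isolation, the bias_k/bias_v path added here appends one extra key/value position before the dot product; a sketch with assumed sizes:

import torch
from torchtext.models import ScaledDotProduct

nhead, tgt_len, src_len, batch_heads, head_dim = 5, 6, 10, 320, 2
SDP = ScaledDotProduct(nhead)
query = torch.rand(tgt_len, batch_heads, head_dim)
key = value = torch.rand(src_len, batch_heads, head_dim)
bias_k = bias_v = torch.rand(1, batch_heads, head_dim)     # one extra position, (1, N * H, E / H)
attn_output, attn_weights = SDP(query, key, value, bias_k=bias_k, bias_v=bias_v)
print(attn_weights.shape)                                  # torch.Size([320, 6, 11]): src_len grows by one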
assert key.size() == value.size(), "Shape of key, value must match" From 7bd3beb9101d9f6f8e46654816d2eda0b4290066 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 09:41:04 -0700 Subject: [PATCH 25/56] remove nheads from ScaledDotProduct --- test/data/test_models.py | 4 ++-- torchtext/models/multiheadattention.py | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/test/data/test_models.py b/test/data/test_models.py index 13351aacb6..984e736b6f 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -14,7 +14,7 @@ def test_multiheadattention(self): MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead), MultiheadInProject(embed_dim, nhead), MultiheadInProject(embed_dim, nhead)), - ScaledDotProduct(nhead), + ScaledDotProduct(), MultiheadOutProject(embed_dim // nhead, nhead)) query = torch.rand((tgt_len, bsz, embed_dim)) @@ -46,7 +46,7 @@ def test_multiheadattention(self): def test_broadcast_scaled_dot_product(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 - SDP = ScaledDotProduct(nhead) + SDP = ScaledDotProduct() query = torch.rand((tgt_len, 1, embed_dim)) key = value = torch.rand((src_len, 1, embed_dim)) attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 2da6a2a078..d3a19d0d7e 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -19,7 +19,7 @@ def __init__(self, in_proj_tuple, attention_layer, out_proj): >>> MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, num_heads), MultiheadInProject(embed_dim, num_heads), MultiheadInProject(embed_dim, num_heads)), - ScaledDotProduct(num_heads), + ScaledDotProduct(), MultiheadOutProject(embed_dim // num_heads, num_heads)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) @@ -145,18 +145,17 @@ def forward(self, seq): class ScaledDotProduct(torch.nn.Module): - __constants__ = ['num_heads', 'dropout'] + __constants__ = ['dropout'] - def __init__(self, num_heads, dropout=0.0): + def __init__(self, dropout=0.0): r"""Processes a projected query and key-value pair to apply scaled dot product attention. Args: - num_heads (int): Number of parallel attention heads. dropout (float): probability of dropping an attention weight. 
Examples:: - >>> SDP = torchtext.models.ScaledDotProduct(4, 0.1) + >>> SDP = torchtext.models.ScaledDotProduct(0.1) >>> q = torch.randn(256, 21, 3) >>> k = v = torch.randn(256, 21, 3) >>> attn_output, attn_weights = SDP(q, k, v) @@ -164,7 +163,6 @@ def __init__(self, num_heads, dropout=0.0): torch.Size([256, 21, 3]) torch.Size([256, 21, 21]) """ super(ScaledDotProduct, self).__init__() - self.num_heads = num_heads self.dropout = dropout def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): From 14da915a1fd19b3564bec9fb91236224f45a3dd4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 11:04:46 -0700 Subject: [PATCH 26/56] minor fix in jit test --- test/data/test_jit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/data/test_jit.py b/test/data/test_jit.py index 0a1ef83708..6ad101940f 100644 --- a/test/data/test_jit.py +++ b/test/data/test_jit.py @@ -5,7 +5,7 @@ from ..common.torchtext_test_case import TorchtextTestCase -class TestJit(TorchtextTestCase): +class TestJIT(TorchtextTestCase): def test_torchscript_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 @@ -13,7 +13,7 @@ def test_torchscript_multiheadattention(self): MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead), MultiheadInProject(embed_dim, nhead), MultiheadInProject(embed_dim, nhead)), - ScaledDotProduct(nhead), + ScaledDotProduct(), MultiheadOutProject(embed_dim // nhead, nhead)) query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) From 032d74916da5e989f10756d2ea7f1620441f85e6 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 11:21:02 -0700 Subject: [PATCH 27/56] adjust attn_mask --- torchtext/models/multiheadattention.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index d3a19d0d7e..a788c24ccc 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -41,7 +41,7 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): Args: query, key, value (Tensor): map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. - attn_mask (Bool Tensor, optional): 3D mask that prevents attention to certain positions. + attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions. bias_k and bias_v:bias (Tensor, optional): one more key and value sequence to be added at sequence dim (dim=-3). Those are used for incremental decoding. @@ -50,7 +50,8 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): - query: :math:`(L, N, E)` - key: :math:`(S, N, E)` - value: :math:`(S, N, E)` - - attn_mask: :math:`(N * H, L, S)` + - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend + while ``False`` values will be unchanged. - bias_k and bias_v:bias: :math:`(1, N * H, E / H)` - Outputs: @@ -174,7 +175,7 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): query (Tensor): Projected query key (Tensor): Projected key value (Tensor): Projected value - attn_mask (Bool Tensor, optional): 3D mask that prevents attention to certain positions. + attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions. bias_k and bias_v:bias: the additional key and value sequence to be added at sequence dim (dim=-3). 
Those are used for incremental decoding. @@ -182,7 +183,8 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): - query: :math:`(L, N * H, E / H)` - key: :math:`(S, N * H, E / H)` - value: :math:`(S, N * H, E / H)` - - attn_mask: :math:`(N * H, L, S)` + - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend + while ``False`` values will be unchanged. - bias_k and bias_v:bias: :math:`(1, N * H, E / H)` - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)` @@ -197,6 +199,7 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): "Shape of bias_v is not supported" key = torch.cat([key, bias_k]) value = torch.cat([value, bias_v]) + torch.nn.functional.pad(attn_mask, (0, 1)) tgt_len, head_dim = query.size(-3), query.size(-1) assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal." From 5c1198c8bf76e5192bcbba31ca70b3ef673a3a39 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 11:55:40 -0700 Subject: [PATCH 28/56] fix jit annotation --- .flake8 | 3 ++- torchtext/models/multiheadattention.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.flake8 b/.flake8 index 50ecc8aa11..c5675e16e2 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] -ignore = E402,E722,W503,W504,F821 +# E501 is not flexible enough, we're using B950 instead. Consistent with pytorch +ignore = E402,E722,W503,W504,F821,E501 max-line-length = 120 exclude = docs/source diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index a788c24ccc..5dbcde52bd 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -35,7 +35,7 @@ def __init__(self, in_proj_tuple, attention_layer, out_proj): self.out_proj = out_proj def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): - # type: (...) -> Tuple[Tensor, Optional[Tensor]] + # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] r""" Args: @@ -167,7 +167,7 @@ def __init__(self, dropout=0.0): self.dropout = dropout def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): - # type: (...) -> Tuple[Tensor, Optional[Tensor]] + # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] r"""Uses a scaled dot product with the projected key-value pair to update the projected query. From c4ccac7c7e0d5ad3335332d1adeae6f27ec01f7f Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 11:58:04 -0700 Subject: [PATCH 29/56] minor --- torchtext/models/multiheadattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 5dbcde52bd..5e5a40c085 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -199,7 +199,7 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): "Shape of bias_v is not supported" key = torch.cat([key, bias_k]) value = torch.cat([value, bias_v]) - torch.nn.functional.pad(attn_mask, (0, 1)) + attn_mask = torch.nn.functional.pad(attn_mask, (0, 1)) tgt_len, head_dim = query.size(-3), query.size(-1) assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal." 
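Because bias_k/bias_v extend the source length by one, the boolean attn_mask has to grow a matching column, which is what the pad call being fixed here provides; standalone, with assumed sizes:

import torch

batch_heads, tgt_len, src_len = 320, 6, 10
attn_mask = torch.randint(0, 2, (batch_heads, tgt_len, src_len)).to(torch.bool)
attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))     # pad the last dim on the right
print(attn_mask.shape)                                     # torch.Size([320, 6, 11]); the new column is False (attendable)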
From 295ab13a9cf180c20e926a85c6eb0c533e29b8b3 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 12:46:50 -0700 Subject: [PATCH 30/56] refine optional tensor in torchscript --- torchtext/models/multiheadattention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 5e5a40c085..c9692f3d39 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -199,7 +199,9 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): "Shape of bias_v is not supported" key = torch.cat([key, bias_k]) value = torch.cat([value, bias_v]) - attn_mask = torch.nn.functional.pad(attn_mask, (0, 1)) + if attn_mask is not None: + _attn_mask = attn_mask + attn_mask = torch.nn.functional.pad(_attn_mask, (0, 1)) tgt_len, head_dim = query.size(-3), query.size(-1) assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal." From 96790220f55c0dc3b73a4b9d99070b11cde18bb1 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 23 Apr 2020 12:51:26 -0700 Subject: [PATCH 31/56] minor fix in mha test --- test/data/test_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/data/test_models.py b/test/data/test_models.py index 984e736b6f..58c5257d0a 100644 --- a/test/data/test_models.py +++ b/test/data/test_models.py @@ -41,7 +41,8 @@ def test_multiheadattention(self): attn_mask=torch_attn_mask) assert_allclose(mha_output, torch_mha_output) - attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len).sum(dim=1) / nhead + # With bias_k and bias_v, src_len needs to plus 1 + attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len + 1).sum(dim=1) / nhead assert_allclose(attn_weights, torch_mha_weights) def test_broadcast_scaled_dot_product(self): From bc8a75f9e4c860419c3cbed408bb4245c02e0bda Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 24 Apr 2020 07:35:57 -0700 Subject: [PATCH 32/56] remove a few assert statements --- torchtext/models/multiheadattention.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index c9692f3d39..7cc8681e64 100644 --- a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -223,13 +223,8 @@ def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): # Dot product of q, k attn_output_weights = torch.matmul(query, key.transpose(-2, -1)) - assert attn_output_weights.size(-3) == batch_heads - assert attn_output_weights.size(-2) == tgt_len - assert attn_output_weights.size(-1) == src_len - if attn_mask is not None: attn_output_weights.masked_fill_(attn_mask, float('-inf'),) - attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_output_weights, value) From 3a7d70db1bbb3e09d47f0ee931b9c9d0b97cf50a Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 24 Apr 2020 08:07:16 -0700 Subject: [PATCH 33/56] a few changes to for torchscript in python 3 --- torchtext/models/multiheadattention.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/torchtext/models/multiheadattention.py b/torchtext/models/multiheadattention.py index 7cc8681e64..b92f00ebb0 100644 --- 
a/torchtext/models/multiheadattention.py +++ b/torchtext/models/multiheadattention.py @@ -1,8 +1,5 @@ import torch -from torch._jit_internal import Tuple, Optional - - -Tensor = torch.Tensor +from typing import Tuple, Optional class MultiheadAttentionContainer(torch.nn.Module): @@ -34,8 +31,10 @@ def __init__(self, in_proj_tuple, attention_layer, out_proj): self.attention_layer = attention_layer self.out_proj = out_proj - def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + bias_k: Optional[torch.Tensor] = None, + bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: @@ -86,8 +85,7 @@ def __init__(self, embed_dim, num_heads): self.num_heads = num_heads self.proj_layer = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) - def forward(self, seq): - # type: (Tensor) -> Tensor + def forward(self, seq: torch.Tensor) -> torch.Tensor: r"""Projects an input sequence using parallel attention heads. Args: @@ -121,8 +119,7 @@ def __init__(self, head_dim, num_heads): self.num_heads = num_heads self.proj_layer = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False) - def forward(self, seq): - # type: (Tensor) -> Tensor + def forward(self, seq: torch.Tensor) -> torch.Tensor: r"""Projects an output sequence using parallel attention heads. Args: @@ -146,7 +143,6 @@ def forward(self, seq): class ScaledDotProduct(torch.nn.Module): - __constants__ = ['dropout'] def __init__(self, dropout=0.0): r"""Processes a projected query and key-value pair to apply @@ -166,8 +162,10 @@ def __init__(self, dropout=0.0): super(ScaledDotProduct, self).__init__() self.dropout = dropout - def forward(self, query, key, value, attn_mask=None, bias_k=None, bias_v=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + bias_k: Optional[torch.Tensor] = None, + bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: r"""Uses a scaled dot product with the projected key-value pair to update the projected query. 
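[Editor's note — a minimal sketch of the pattern PATCH 33 adopts; TinyAttention below is a hypothetical stand-in, not a module from the patches. Replacing "# type:" comments with Python 3 annotations from typing lets torch.jit.script read Optional/Tuple signatures directly from the code.]

from typing import Optional, Tuple
import torch

class TinyAttention(torch.nn.Module):
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        # Same annotation style as the patched forward() above; no "# type:" comment needed.
        scores = torch.matmul(query, key.transpose(-2, -1))
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))
        weights = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(weights, value), weights

scripted = torch.jit.script(TinyAttention())
attn_output, attn_weights = scripted(torch.rand(4, 6, 8), torch.rand(4, 6, 8), torch.rand(4, 6, 8))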
From 80087989339f8d31d2fb4a24fcd099f042f84cf4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 24 Apr 2020 08:10:10 -0700 Subject: [PATCH 34/56] switch the name from models to modules --- test/data/{test_models.py => test_modules.py} | 4 ++-- torchtext/{models => modules}/__init__.py | 0 torchtext/{models => modules}/multiheadattention.py | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename test/data/{test_models.py => test_modules.py} (98%) rename torchtext/{models => modules}/__init__.py (100%) rename torchtext/{models => modules}/multiheadattention.py (100%) diff --git a/test/data/test_models.py b/test/data/test_modules.py similarity index 98% rename from test/data/test_models.py rename to test/data/test_modules.py index 58c5257d0a..affcd6215b 100644 --- a/test/data/test_models.py +++ b/test/data/test_modules.py @@ -1,5 +1,5 @@ import torch -from torchtext.models import MultiheadAttentionContainer, \ +from torchtext.modules import MultiheadAttentionContainer, \ ScaledDotProduct, MultiheadInProject, MultiheadOutProject from torch.nn.functional import multi_head_attention_forward as mha_forward from torch.testing import assert_allclose @@ -10,7 +10,7 @@ class TestModels(TorchtextTestCase): def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 - # Build torchtext MultiheadAttention models + # Build torchtext MultiheadAttention module MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead), MultiheadInProject(embed_dim, nhead), MultiheadInProject(embed_dim, nhead)), diff --git a/torchtext/models/__init__.py b/torchtext/modules/__init__.py similarity index 100% rename from torchtext/models/__init__.py rename to torchtext/modules/__init__.py diff --git a/torchtext/models/multiheadattention.py b/torchtext/modules/multiheadattention.py similarity index 100% rename from torchtext/models/multiheadattention.py rename to torchtext/modules/multiheadattention.py From e12e131d9ce59b61dae5b424e617b59852bdbcb4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 24 Apr 2020 09:25:41 -0700 Subject: [PATCH 35/56] minor fix --- test/data/test_jit.py | 2 +- torchtext/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/data/test_jit.py b/test/data/test_jit.py index 6ad101940f..e835960ab5 100644 --- a/test/data/test_jit.py +++ b/test/data/test_jit.py @@ -1,5 +1,5 @@ import torch -from torchtext.models import MultiheadAttentionContainer, \ +from torchtext.modules import MultiheadAttentionContainer, \ ScaledDotProduct, MultiheadInProject, MultiheadOutProject from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase diff --git a/torchtext/__init__.py b/torchtext/__init__.py index 5c210c5601..74eda2682c 100644 --- a/torchtext/__init__.py +++ b/torchtext/__init__.py @@ -1,5 +1,5 @@ from . import data -from . import models +from . import modules from . import datasets from . import utils from . 
import vocab @@ -8,7 +8,7 @@ __version__ = '0.6.0' __all__ = ['data', - 'models', + 'modules', 'datasets', 'utils', 'vocab', From 659db7ae4c95f228a9a248309143eb1a2aa59b2c Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 27 Apr 2020 13:51:24 -0700 Subject: [PATCH 36/56] move reshape to MHA container --- test/data/test_jit.py | 12 +-- test/data/test_modules.py | 18 ++-- torchtext/modules/__init__.py | 7 +- torchtext/modules/multiheadattention.py | 105 ++++++------------------ 4 files changed, 42 insertions(+), 100 deletions(-) diff --git a/test/data/test_jit.py b/test/data/test_jit.py index e835960ab5..dcd2a4bd9e 100644 --- a/test/data/test_jit.py +++ b/test/data/test_jit.py @@ -1,6 +1,5 @@ import torch -from torchtext.modules import MultiheadAttentionContainer, \ - ScaledDotProduct, MultiheadInProject, MultiheadOutProject +from torchtext.modules import MultiheadAttentionContainer, ScaledDotProduct from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase @@ -10,11 +9,12 @@ class TestJIT(TorchtextTestCase): def test_torchscript_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention models - MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead), - MultiheadInProject(embed_dim, nhead), - MultiheadInProject(embed_dim, nhead)), + MHA = MultiheadAttentionContainer(nhead, + (torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim)), ScaledDotProduct(), - MultiheadOutProject(embed_dim // nhead, nhead)) + torch.nn.Linear(embed_dim, embed_dim)) query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) attn_mask = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index affcd6215b..7734f5293a 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -1,6 +1,5 @@ import torch -from torchtext.modules import MultiheadAttentionContainer, \ - ScaledDotProduct, MultiheadInProject, MultiheadOutProject +from torchtext.modules import MultiheadAttentionContainer, ScaledDotProduct from torch.nn.functional import multi_head_attention_forward as mha_forward from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase @@ -11,11 +10,12 @@ class TestModels(TorchtextTestCase): def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention module - MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, nhead), - MultiheadInProject(embed_dim, nhead), - MultiheadInProject(embed_dim, nhead)), + MHA = MultiheadAttentionContainer(nhead, + (torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim),), ScaledDotProduct(), - MultiheadOutProject(embed_dim // nhead, nhead)) + torch.nn.Linear(embed_dim, embed_dim)) query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) @@ -28,9 +28,9 @@ def test_multiheadattention(self): # Use torch.nn.functional.multi_head_attention_forward torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf')) - in_proj_weight = torch.cat([MHA.query_in_proj.proj_layer.weight, - MHA.key_in_proj.proj_layer.weight, - MHA.value_in_proj.proj_layer.weight]) + in_proj_weight = torch.cat([MHA.query_in_proj.weight, + MHA.key_in_proj.weight, + 
MHA.value_in_proj.weight]) torch_mha_output, torch_mha_weights = mha_forward(query, key, value, embed_dim, nhead, in_proj_weight, None, diff --git a/torchtext/modules/__init__.py b/torchtext/modules/__init__.py index f821fb9a8f..70ba44939b 100644 --- a/torchtext/modules/__init__.py +++ b/torchtext/modules/__init__.py @@ -1,7 +1,4 @@ -from .multiheadattention import MultiheadInProject, MultiheadOutProject, \ - MultiheadAttentionContainer, ScaledDotProduct +from .multiheadattention import MultiheadAttentionContainer, ScaledDotProduct -__all__ = ['MultiheadInProject', - 'MultiheadOutProject', - 'MultiheadAttentionContainer', +__all__ = ['MultiheadAttentionContainer', 'ScaledDotProduct'] diff --git a/torchtext/modules/multiheadattention.py b/torchtext/modules/multiheadattention.py index b92f00ebb0..1bc5cf01db 100644 --- a/torchtext/modules/multiheadattention.py +++ b/torchtext/modules/multiheadattention.py @@ -3,21 +3,24 @@ class MultiheadAttentionContainer(torch.nn.Module): - def __init__(self, in_proj_tuple, attention_layer, out_proj): + def __init__(self, nhead, in_proj_tuple, attention_layer, out_proj): r""" A multi-head attention container Args: - in_proj_tuple: A tuple of multi-head in-projection layers + nhead: the number of heads in the multiheadattention model + in_proj_tuple: A tuple of multi-head in-projection linear layers (a.k.a nn.Linear). attention_layer: The attention layer. - out_proj: The multi-head out-projection layer + out_proj: The multi-head out-projection layer (a.k.a nn.Linear). Examples:: + >>> import torch >>> embed_dim, num_heads, bsz = 10, 5, 64 - >>> MHA = MultiheadAttentionContainer((MultiheadInProject(embed_dim, num_heads), - MultiheadInProject(embed_dim, num_heads), - MultiheadInProject(embed_dim, num_heads)), + >>> MHA = MultiheadAttentionContainer(num_heads, + (torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim)), ScaledDotProduct(), - MultiheadOutProject(embed_dim // num_heads, num_heads)) + torch.nn.Linear(embed_dim, embed_dim)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) @@ -25,6 +28,7 @@ def __init__(self, in_proj_tuple, attention_layer, out_proj): >>> torch.Size([21, 64, 10]) """ super(MultiheadAttentionContainer, self).__init__() + self.nhead = nhead self.query_in_proj = in_proj_tuple[0] self.key_in_proj = in_proj_tuple[1] self.value_in_proj = in_proj_tuple[2] @@ -60,88 +64,29 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, where where L is the target length, S is the sequence length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. 
""" + tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1) q = self.query_in_proj(query) + assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads" + head_dim = q.size(-1) // self.nhead + q = q.reshape(tgt_len, bsz * self.nhead, head_dim) + k = self.key_in_proj(key) + assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads" + head_dim = k.size(-1) // self.nhead + k = k.reshape(src_len, bsz * self.nhead, head_dim) + v = self.value_in_proj(value) + assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads" + head_dim = v.size(-1) // self.nhead + v = v.reshape(src_len, bsz * self.nhead, head_dim) + attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v) + attn_output = attn_output.reshape(tgt_len, bsz, embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_output_weights -class MultiheadInProject(torch.nn.Module): - def __init__(self, embed_dim, num_heads): - r"""Process input using multi-head attention. - - Args: - embed_dim (int): Input embedding dimension - num_heads (int): Number of parallel attention heads. - """ - - super(MultiheadInProject, self).__init__() - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - self.head_dim = embed_dim // num_heads - self.embed_dim = embed_dim - self.num_heads = num_heads - self.proj_layer = torch.nn.Linear(embed_dim, self.num_heads * self.head_dim, bias=False) - - def forward(self, seq: torch.Tensor) -> torch.Tensor: - r"""Projects an input sequence using parallel attention heads. - - Args: - seq (Tensor): sequence to be projected - - Shape: - - seq: :math:`(S, N, E)` - - - Output: :math:`(S, N * H, E / H)` - - where S is the sequence length, H is the number of attention heads, N is the - batch size, and E is the embedding dimension. - """ - seq_len, bsz, proj_dim = seq.size() - seq = self.proj_layer(seq) - seq = seq.reshape(seq_len, bsz * self.num_heads, self.head_dim) - return seq - - -class MultiheadOutProject(torch.nn.Module): - def __init__(self, head_dim, num_heads): - r"""Process attention output using multi-head attention. - - Args: - head_dim (int): Dimension of embedding for each attention head. - num_heads (int): Number of parallel attention heads. - - """ - super(MultiheadOutProject, self).__init__() - self.head_dim = head_dim - self.num_heads = num_heads - self.proj_layer = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False) - - def forward(self, seq: torch.Tensor) -> torch.Tensor: - r"""Projects an output sequence using parallel attention heads. - - Args: - seq (Tensor): Projection to be decoded to an embedding. - - Shape: - - seq: :math:`(S, N * H, E / H)` - - - Output: :math:`(S, N, E)` - - where S is the sequence length, H is the number of attention heads, N is the - batch size, and E is the embedding dimension. 
- """ - seq_len, bsz_num_head, head_dim = seq.size() - assert bsz_num_head % self.num_heads == 0, \ - "Dimension -2 of MultiheadOutProject input must be divisible by num_heads" - bsz = bsz_num_head // self.num_heads - seq = seq.reshape(seq_len, bsz, self.num_heads * self.head_dim) - seq = self.proj_layer(seq) - return seq - - class ScaledDotProduct(torch.nn.Module): def __init__(self, dropout=0.0): From a3a21e78cb9f3ce34a703ec2abfd272a2de6c632 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 27 Apr 2020 13:56:34 -0700 Subject: [PATCH 37/56] udpdate doc --- docs/source/{models.rst => modules.rst} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename docs/source/{models.rst => modules.rst} (90%) diff --git a/docs/source/models.rst b/docs/source/modules.rst similarity index 90% rename from docs/source/models.rst rename to docs/source/modules.rst index 0af2fa2e5e..85af9c7026 100644 --- a/docs/source/models.rst +++ b/docs/source/modules.rst @@ -8,7 +8,7 @@ torchtext.models.multiheadattention .. currentmodule:: torchtext.models.multiheadattention :hidden:`MultiheadAttentionContainer` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: MultiheadAttentionContainer From f7e75d1a7fb34d430b53b890055a46d95babe73f Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 27 Apr 2020 14:50:10 -0700 Subject: [PATCH 38/56] minor --- test/data/test_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index 7734f5293a..9d45316ffd 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -36,8 +36,8 @@ def test_multiheadattention(self): in_proj_weight, None, bias_k, bias_v, False, 0.0, - MHA.out_proj.proj_layer.weight, - MHA.out_proj.proj_layer.bias, + MHA.out_proj.weight, + MHA.out_proj.bias, attn_mask=torch_attn_mask) assert_allclose(mha_output, torch_mha_output) From 11e302773cc0cba530cd44a0cef1bdca0e7961b5 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 27 Apr 2020 15:29:39 -0700 Subject: [PATCH 39/56] asserRaises tests in broadcast --- test/data/test_modules.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index 9d45316ffd..b7d9a4d465 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -81,6 +81,11 @@ def test_broadcast_scaled_dot_product(self): assert list(sdp_attn_weights.size()) == [3, 3, bsz * nhead, tgt_len, embed_dim] assert_allclose(sdp_attn_output[2][2], sdp_attn_output_full) assert_allclose(sdp_attn_weights[2][2], sdp_attn_weights_full) + # dim -2 is not equal to neither key/value's dim -2 or 1 + with self.assertRaises(RuntimeError): + SDP(query.expand(tgt_len, 1, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim), + value.expand(3, 3, src_len, bsz * nhead, embed_dim), + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len) # key/value have a size of (src_len, 1, embed_dim) # while query has a size of (1, 2, 3, tgt_len, bsz * nhead, embed_dim) @@ -92,6 +97,16 @@ def test_broadcast_scaled_dot_product(self): assert list(sdp_attn_weights.size()) == [1, 2, 3, bsz * nhead, tgt_len, embed_dim] assert_allclose(sdp_attn_output[0][1][2], sdp_attn_output_full) assert_allclose(sdp_attn_weights[0][1][2], sdp_attn_weights_full) + # key dim -2 is not equal to value dim -2 + with self.assertRaisesRegex(AssertionError, "Shape of key, value must match"): + SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), 
key.expand(src_len, 2, embed_dim), + value.expand(src_len, 1, embed_dim), + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) + # key/value dim -2 is not equal to neither query's dim -2 or 1 + with self.assertRaises(RuntimeError): + SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), key.expand(src_len, 2, embed_dim), + value.expand(src_len, 2, embed_dim), + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) # attn_mask in a size of (1, tgt_len, src_len) # 2D tensor is not supported for attn_mask @@ -101,3 +116,8 @@ def test_broadcast_scaled_dot_product(self): attn_mask=attn_mask_2D.expand(1, tgt_len, src_len)) assert_allclose(sdp_attn_output, sdp_attn_output_full) assert_allclose(sdp_attn_weights, sdp_attn_weights_full) + # attn_mask's dim -3 is not equal to neither batch size or 1 + with self.assertRaisesRegex(RuntimeError, "The size of the attn_mask is not correct."): + SDP(query.expand(tgt_len, bsz * nhead, embed_dim), key.expand(src_len, bsz * nhead, embed_dim), + value.expand(src_len, bsz * nhead, embed_dim), + attn_mask=attn_mask_2D.expand(2, tgt_len, src_len)) From 5a709a5e232e824a3f4e7f52c5feae3b14ced861 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 28 Apr 2020 07:02:52 -0700 Subject: [PATCH 40/56] fix typo --- test/data/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index b7d9a4d465..bf2da8a22c 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -85,7 +85,7 @@ def test_broadcast_scaled_dot_product(self): with self.assertRaises(RuntimeError): SDP(query.expand(tgt_len, 1, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim), value.expand(3, 3, src_len, bsz * nhead, embed_dim), - attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len) + attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) # key/value have a size of (src_len, 1, embed_dim) # while query has a size of (1, 2, 3, tgt_len, bsz * nhead, embed_dim) From a90c8268340a00c2e9e73c09fc102ad7ebbccbd5 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 28 Apr 2020 08:17:53 -0700 Subject: [PATCH 41/56] minor fix --- test/data/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index bf2da8a22c..a58338d112 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -83,7 +83,7 @@ def test_broadcast_scaled_dot_product(self): assert_allclose(sdp_attn_weights[2][2], sdp_attn_weights_full) # dim -2 is not equal to neither key/value's dim -2 or 1 with self.assertRaises(RuntimeError): - SDP(query.expand(tgt_len, 1, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim), + SDP(query.expand(tgt_len, 2, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim), value.expand(3, 3, src_len, bsz * nhead, embed_dim), attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len)) From 04096367063a6daebee261c1e5f4a6f360b01d5f Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 28 Apr 2020 08:44:40 -0700 Subject: [PATCH 42/56] add benchmark case --- benchmark/mha_block.py | 59 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 benchmark/mha_block.py diff --git a/benchmark/mha_block.py b/benchmark/mha_block.py new file mode 100644 index 0000000000..e8f2ebeecd --- /dev/null +++ b/benchmark/mha_block.py @@ -0,0 +1,59 @@ +import torch +from torchtext.modules import MultiheadAttentionContainer, ScaledDotProduct +from 
torch.nn import MultiheadAttention +from torch.nn.functional import multi_head_attention_forward as mha_forward +import time + + +def benchmark_mha_block(): + embed_dim, nhead, tgt_len, src_len, bsz = 768, 12, 128, 128, 72 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Build torchtext MultiheadAttention module + MHA = MultiheadAttentionContainer(nhead, + (torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim),), + ScaledDotProduct(), + torch.nn.Linear(embed_dim, embed_dim)).to(device) + + query = torch.rand((tgt_len, bsz, embed_dim)).to(device) + key = value = torch.rand((src_len, bsz, embed_dim)).to(device) + attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device) + bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device) + print("starting") + t0 = time.monotonic() + for _ in range(100): + mha_output, attn_weights = MHA(query, key, value, + attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)), + bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1), + bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)) + print(time.monotonic() - t0) + + # Use torch.nn.functional.multi_head_attention_forward + torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float('-inf')) + torch_MHA = MultiheadAttention(embed_dim, nhead).to(device) + print("starting") + torch_MHA.bias_k = bias_k + torch_MHA.bias_v = bias_v + t0 = time.monotonic() + for _ in range(100): + torch_mha_output, torch_mha_weights = torch_MHA(query, key, value, attn_mask=torch_attn_mask) + print(time.monotonic() - t0) + + print("starting") + t0 = time.monotonic() + in_proj_weight = torch.cat([MHA.query_in_proj.weight, MHA.key_in_proj.weight, MHA.value_in_proj.weight]) + for _ in range(100): + torch_mha_output, torch_mha_weights = mha_forward(query, key, value, + embed_dim, nhead, + in_proj_weight, None, + bias_k, bias_v, + False, 0.0, + MHA.out_proj.weight, + MHA.out_proj.bias, + attn_mask=torch_attn_mask) + print(time.monotonic() - t0) + + +if __name__ == "__main__": + benchmark_mha_block() From 6c9a7a3017235d3e0a9ed5bb5ce556f63e614580 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 29 Apr 2020 07:12:17 -0700 Subject: [PATCH 43/56] remove bias from test --- test/data/test_modules.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index a58338d112..a257d76161 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -11,11 +11,11 @@ def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention module MHA = MultiheadAttentionContainer(nhead, - (torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim),), + (torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False),), ScaledDotProduct(), - torch.nn.Linear(embed_dim, embed_dim)) + torch.nn.Linear(embed_dim, embed_dim, bias=False)) query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) @@ -36,8 +36,7 @@ def test_multiheadattention(self): in_proj_weight, None, bias_k, bias_v, False, 0.0, - MHA.out_proj.weight, - MHA.out_proj.bias, + MHA.out_proj.weight, None, attn_mask=torch_attn_mask) assert_allclose(mha_output, torch_mha_output) From 
45d28b53a12c8dbf9ee79a19eddf658ea3ce2084 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 29 Apr 2020 08:27:03 -0700 Subject: [PATCH 44/56] update benchmark case --- benchmark/mha_block.py | 102 ++++++++++++++++++++++------------------- 1 file changed, 54 insertions(+), 48 deletions(-) diff --git a/benchmark/mha_block.py b/benchmark/mha_block.py index e8f2ebeecd..c15efab5a3 100644 --- a/benchmark/mha_block.py +++ b/benchmark/mha_block.py @@ -1,58 +1,64 @@ import torch from torchtext.modules import MultiheadAttentionContainer, ScaledDotProduct -from torch.nn import MultiheadAttention from torch.nn.functional import multi_head_attention_forward as mha_forward import time def benchmark_mha_block(): - embed_dim, nhead, tgt_len, src_len, bsz = 768, 12, 128, 128, 72 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # Build torchtext MultiheadAttention module - MHA = MultiheadAttentionContainer(nhead, - (torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim),), - ScaledDotProduct(), - torch.nn.Linear(embed_dim, embed_dim)).to(device) - - query = torch.rand((tgt_len, bsz, embed_dim)).to(device) - key = value = torch.rand((src_len, bsz, embed_dim)).to(device) - attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device) - bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device) - print("starting") - t0 = time.monotonic() - for _ in range(100): - mha_output, attn_weights = MHA(query, key, value, - attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)), - bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1), - bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)) - print(time.monotonic() - t0) - - # Use torch.nn.functional.multi_head_attention_forward - torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float('-inf')) - torch_MHA = MultiheadAttention(embed_dim, nhead).to(device) - print("starting") - torch_MHA.bias_k = bias_k - torch_MHA.bias_v = bias_v - t0 = time.monotonic() - for _ in range(100): - torch_mha_output, torch_mha_weights = torch_MHA(query, key, value, attn_mask=torch_attn_mask) - print(time.monotonic() - t0) - - print("starting") - t0 = time.monotonic() - in_proj_weight = torch.cat([MHA.query_in_proj.weight, MHA.key_in_proj.weight, MHA.value_in_proj.weight]) - for _ in range(100): - torch_mha_output, torch_mha_weights = mha_forward(query, key, value, - embed_dim, nhead, - in_proj_weight, None, - bias_k, bias_v, - False, 0.0, - MHA.out_proj.weight, - MHA.out_proj.bias, - attn_mask=torch_attn_mask) - print(time.monotonic() - t0) + + def _run_benchmark(embed_dim, nhead, tgt_len, src_len, bsz, device): + # Build torchtext MultiheadAttention module + MHA = MultiheadAttentionContainer(nhead, + (torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim),), + ScaledDotProduct(), + torch.nn.Linear(embed_dim, embed_dim)).to(device) + + query = torch.rand((tgt_len, bsz, embed_dim)).to(device) + key = value = torch.rand((src_len, bsz, embed_dim)).to(device) + attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device) + bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device) + print("starting torchtext.modules.MultiheadAttentionContainer") + t0 = time.monotonic() + for _ in range(100): + mha_output, attn_weights = MHA(query, key, value, + attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)), + bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * 
nhead, -1), + bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)) + print(time.monotonic() - t0) + + # Use torch.nn.functional.multi_head_attention_forward + torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float('-inf')) + print("starting torch.nn.functional.multi_head_attention_forward") + in_proj_weight = torch.cat([MHA.query_in_proj.weight, MHA.key_in_proj.weight, MHA.value_in_proj.weight]) + t0 = time.monotonic() + for _ in range(100): + torch_mha_output, torch_mha_weights = mha_forward(query, key, value, + embed_dim, nhead, + in_proj_weight, None, + bias_k, bias_v, + False, 0.0, + MHA.out_proj.weight, + MHA.out_proj.bias, + attn_mask=torch_attn_mask) + print(time.monotonic() - t0) + + print("*" * 80) + print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72) + _run_benchmark(768, 12, 128, 128, 72, torch.device("cuda")) + + print("*" * 80) + print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 64, 2, 10, 10, 8) + _run_benchmark(64, 2, 10, 10, 8, torch.device("cuda")) + + print("*" * 80) + print("test case CPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72) + _run_benchmark(768, 12, 128, 128, 72, torch.device("cpu")) + + print("*" * 80) + print("test case CPU with embed_dim, nhead, tgt_len, src_len, bsz:", 64, 2, 10, 10, 8) + _run_benchmark(64, 2, 10, 10, 8, torch.device("cpu")) if __name__ == "__main__": From 4f3b458d2b6c98f2083a878cf82ca12c48e8a439 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 29 Apr 2020 13:14:46 -0700 Subject: [PATCH 45/56] add InProjContainer --- test/data/test_jit.py | 11 ++--- test/data/test_modules.py | 17 ++++---- torchtext/modules/__init__.py | 6 ++- torchtext/modules/multiheadattention.py | 57 +++++++++++++++++++------ 4 files changed, 63 insertions(+), 28 deletions(-) diff --git a/test/data/test_jit.py b/test/data/test_jit.py index dcd2a4bd9e..dd2ad8469c 100644 --- a/test/data/test_jit.py +++ b/test/data/test_jit.py @@ -1,5 +1,5 @@ import torch -from torchtext.modules import MultiheadAttentionContainer, ScaledDotProduct +from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct from torch.testing import assert_allclose from ..common.torchtext_test_case import TorchtextTestCase @@ -9,10 +9,11 @@ class TestJIT(TorchtextTestCase): def test_torchscript_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention models - MHA = MultiheadAttentionContainer(nhead, - (torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim)), + in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim)) + + MHA = MultiheadAttentionContainer(nhead, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)) query = torch.rand((tgt_len, bsz, embed_dim)) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index a257d76161..9f1d50bd1a 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -1,5 +1,5 @@ import torch -from torchtext.modules import MultiheadAttentionContainer, ScaledDotProduct +from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct from torch.nn.functional import multi_head_attention_forward as mha_forward from torch.testing import assert_allclose from ..common.torchtext_test_case import 
TorchtextTestCase @@ -10,10 +10,11 @@ class TestModels(TorchtextTestCase): def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention module - MHA = MultiheadAttentionContainer(nhead, - (torch.nn.Linear(embed_dim, embed_dim, bias=False), - torch.nn.Linear(embed_dim, embed_dim, bias=False), - torch.nn.Linear(embed_dim, embed_dim, bias=False),), + in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim)) + + MHA = MultiheadAttentionContainer(nhead, in_proj, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim, bias=False)) @@ -28,9 +29,9 @@ def test_multiheadattention(self): # Use torch.nn.functional.multi_head_attention_forward torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf')) - in_proj_weight = torch.cat([MHA.query_in_proj.weight, - MHA.key_in_proj.weight, - MHA.value_in_proj.weight]) + in_proj_weight = torch.cat([MHA.in_proj_container.query_in_proj.weight, + MHA.in_proj_container.key_in_proj.weight, + MHA.in_proj_container.value_in_proj.weight]) torch_mha_output, torch_mha_weights = mha_forward(query, key, value, embed_dim, nhead, in_proj_weight, None, diff --git a/torchtext/modules/__init__.py b/torchtext/modules/__init__.py index 70ba44939b..a55ced48fb 100644 --- a/torchtext/modules/__init__.py +++ b/torchtext/modules/__init__.py @@ -1,4 +1,6 @@ -from .multiheadattention import MultiheadAttentionContainer, ScaledDotProduct +from .multiheadattention import InProjContainer, \ + MultiheadAttentionContainer, ScaledDotProduct -__all__ = ['MultiheadAttentionContainer', +__all__ = ['InProjContainer', + 'MultiheadAttentionContainer', 'ScaledDotProduct'] diff --git a/torchtext/modules/multiheadattention.py b/torchtext/modules/multiheadattention.py index 1bc5cf01db..4ff2ed0722 100644 --- a/torchtext/modules/multiheadattention.py +++ b/torchtext/modules/multiheadattention.py @@ -3,24 +3,25 @@ class MultiheadAttentionContainer(torch.nn.Module): - def __init__(self, nhead, in_proj_tuple, attention_layer, out_proj): + def __init__(self, nhead, in_proj_container, attention_layer, out_proj): r""" A multi-head attention container Args: nhead: the number of heads in the multiheadattention model - in_proj_tuple: A tuple of multi-head in-projection linear layers (a.k.a nn.Linear). + in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear). attention_layer: The attention layer. out_proj: The multi-head out-projection layer (a.k.a nn.Linear). 
Examples:: >>> import torch >>> embed_dim, num_heads, bsz = 10, 5, 64 + >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim)) >>> MHA = MultiheadAttentionContainer(num_heads, - (torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim)), - ScaledDotProduct(), - torch.nn.Linear(embed_dim, embed_dim)) + in_proj_container, + ScaledDotProduct(), + torch.nn.Linear(embed_dim, embed_dim)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) @@ -29,9 +30,7 @@ def __init__(self, nhead, in_proj_tuple, attention_layer, out_proj): """ super(MultiheadAttentionContainer, self).__init__() self.nhead = nhead - self.query_in_proj = in_proj_tuple[0] - self.key_in_proj = in_proj_tuple[1] - self.value_in_proj = in_proj_tuple[2] + self.in_proj_container = in_proj_container self.attention_layer = attention_layer self.out_proj = out_proj @@ -65,17 +64,15 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, N is the batch size, and E is the embedding dimension. """ tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1) - q = self.query_in_proj(query) + q, k, v = self.in_proj_container(query, key, value) assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads" head_dim = q.size(-1) // self.nhead q = q.reshape(tgt_len, bsz * self.nhead, head_dim) - k = self.key_in_proj(key) assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads" head_dim = k.size(-1) // self.nhead k = k.reshape(src_len, bsz * self.nhead, head_dim) - v = self.value_in_proj(value) assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads" head_dim = v.size(-1) // self.nhead v = v.reshape(src_len, bsz * self.nhead, head_dim) @@ -172,3 +169,37 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_output_weights, value) return attn_output.transpose(-2, -3), attn_output_weights + + +class InProjContainer(torch.nn.Module): + def __init__(self, query_proj, key_proj, value_proj): + r"""A in-proj container to process inputs. + + Args: + query_proj: a proj layer for query. + key_proj: a proj layer for key. + value_proj: a proj layer for value. + """ + + super(InProjContainer, self).__init__() + self.query_proj = query_proj + self.key_proj = key_proj + self.value_proj = value_proj + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Projects the input sequences using in-proj layers. + + Args: + query, key, value (Tensors): sequence to be projected + + Shape: + - query, key, value: :math:`(S, N, E)` + - Output: :math:`(S, N, E)` + where S is the sequence length, N is the batch size, and E is the embedding dimension. 
+ """ + return self.query_proj(query), self.key_proj(key), self.value_proj(value) + + From da4b3024a8f2a203d5e1f717a02583a9a7a0806d Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 29 Apr 2020 13:24:23 -0700 Subject: [PATCH 46/56] update benchmark --- benchmark/mha_block.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/benchmark/mha_block.py b/benchmark/mha_block.py index c15efab5a3..e9eabe47f6 100644 --- a/benchmark/mha_block.py +++ b/benchmark/mha_block.py @@ -1,5 +1,5 @@ import torch -from torchtext.modules import MultiheadAttentionContainer, ScaledDotProduct +from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct from torch.nn.functional import multi_head_attention_forward as mha_forward import time @@ -8,10 +8,10 @@ def benchmark_mha_block(): def _run_benchmark(embed_dim, nhead, tgt_len, src_len, bsz, device): # Build torchtext MultiheadAttention module - MHA = MultiheadAttentionContainer(nhead, - (torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim),), + in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim), + torch.nn.Linear(embed_dim, embed_dim)) + MHA = MultiheadAttentionContainer(nhead, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)).to(device) @@ -31,7 +31,9 @@ def _run_benchmark(embed_dim, nhead, tgt_len, src_len, bsz, device): # Use torch.nn.functional.multi_head_attention_forward torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float('-inf')) print("starting torch.nn.functional.multi_head_attention_forward") - in_proj_weight = torch.cat([MHA.query_in_proj.weight, MHA.key_in_proj.weight, MHA.value_in_proj.weight]) + in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight, + MHA.in_proj_container.key_proj.weight, + MHA.in_proj_container.value_proj.weight]) t0 = time.monotonic() for _ in range(100): torch_mha_output, torch_mha_weights = mha_forward(query, key, value, From 4aeaf5eda5b121599e052687e6d281ae513feaa9 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 29 Apr 2020 13:25:38 -0700 Subject: [PATCH 47/56] minor test --- test/data/test_modules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index 9f1d50bd1a..f74db24289 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -29,9 +29,9 @@ def test_multiheadattention(self): # Use torch.nn.functional.multi_head_attention_forward torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf')) - in_proj_weight = torch.cat([MHA.in_proj_container.query_in_proj.weight, - MHA.in_proj_container.key_in_proj.weight, - MHA.in_proj_container.value_in_proj.weight]) + in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight, + MHA.in_proj_container.key_proj.weight, + MHA.in_proj_container.value_proj.weight]) torch_mha_output, torch_mha_weights = mha_forward(query, key, value, embed_dim, nhead, in_proj_weight, None, From 517b921c17794d504cad2a81a1ea31421f1fd095 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 29 Apr 2020 13:53:36 -0700 Subject: [PATCH 48/56] minor fix --- test/data/test_jit.py | 8 ++++---- test/data/test_modules.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/data/test_jit.py b/test/data/test_jit.py index dd2ad8469c..dff0d26b9a 100644 --- 
a/test/data/test_jit.py +++ b/test/data/test_jit.py @@ -9,13 +9,13 @@ class TestJIT(TorchtextTestCase): def test_torchscript_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention models - in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim)) + in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False)) MHA = MultiheadAttentionContainer(nhead, in_proj_container, ScaledDotProduct(), - torch.nn.Linear(embed_dim, embed_dim)) + torch.nn.Linear(embed_dim, embed_dim, bias=False)) query = torch.rand((tgt_len, bsz, embed_dim)) key = value = torch.rand((src_len, bsz, embed_dim)) attn_mask = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) diff --git a/test/data/test_modules.py b/test/data/test_modules.py index f74db24289..0de4e93239 100644 --- a/test/data/test_modules.py +++ b/test/data/test_modules.py @@ -10,9 +10,9 @@ class TestModels(TorchtextTestCase): def test_multiheadattention(self): embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 # Build torchtext MultiheadAttention module - in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim), - torch.nn.Linear(embed_dim, embed_dim)) + in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False), + torch.nn.Linear(embed_dim, embed_dim, bias=False)) MHA = MultiheadAttentionContainer(nhead, in_proj, ScaledDotProduct(), From 87801e2d92e39d2ff248ec68766e2f0037c45633 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 29 Apr 2020 14:42:12 -0700 Subject: [PATCH 49/56] flake8 --- torchtext/modules/multiheadattention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtext/modules/multiheadattention.py b/torchtext/modules/multiheadattention.py index 4ff2ed0722..98016954b7 100644 --- a/torchtext/modules/multiheadattention.py +++ b/torchtext/modules/multiheadattention.py @@ -201,5 +201,3 @@ def forward(self, where S is the sequence length, N is the batch size, and E is the embedding dimension. """ return self.query_proj(query), self.key_proj(key), self.value_proj(value) - - From 33800125011eef6c8d0f5f9e7b5170506b800981 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 4 May 2020 08:12:12 -0700 Subject: [PATCH 50/56] minor docs update --- torchtext/modules/multiheadattention.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchtext/modules/multiheadattention.py b/torchtext/modules/multiheadattention.py index 98016954b7..b180645d98 100644 --- a/torchtext/modules/multiheadattention.py +++ b/torchtext/modules/multiheadattention.py @@ -44,8 +44,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, query, key, value (Tensor): map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions. - bias_k and bias_v:bias (Tensor, optional): one more key and value sequence to be added at - sequence dim (dim=-3). Those are used for incremental decoding. + bias_k and bias_v: (Tensor, optional): one more key and value sequence to be added at + sequence dim (dim=-3). Those are used for incremental decoding. 
Users should provide + non-None to both arguments in order to activate them. Shape: - Inputs: @@ -116,8 +117,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key (Tensor): Projected key value (Tensor): Projected value attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions. - bias_k and bias_v:bias: the additional key and value sequence to be added at sequence dim (dim=-3). - Those are used for incremental decoding. + bias_k and bias_v: (Tensor, optional): one more key and value sequence to be added at + sequence dim (dim=-3). Those are used for incremental decoding. Users should provide + non-None to both arguments in order to activate them. Shape: - query: :math:`(L, N * H, E / H)` From f02073f7b5b20883b86e66beeaf60b68011b4ba5 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 5 May 2020 13:20:54 -0700 Subject: [PATCH 51/56] add self-attention in the benchmark --- benchmark/mha_block.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/benchmark/mha_block.py b/benchmark/mha_block.py index e9eabe47f6..0d9a9994ca 100644 --- a/benchmark/mha_block.py +++ b/benchmark/mha_block.py @@ -6,7 +6,7 @@ def benchmark_mha_block(): - def _run_benchmark(embed_dim, nhead, tgt_len, src_len, bsz, device): + def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None): # Build torchtext MultiheadAttention module in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), @@ -16,7 +16,11 @@ def _run_benchmark(embed_dim, nhead, tgt_len, src_len, bsz, device): torch.nn.Linear(embed_dim, embed_dim)).to(device) query = torch.rand((tgt_len, bsz, embed_dim)).to(device) - key = value = torch.rand((src_len, bsz, embed_dim)).to(device) + if src_len is None: + key = value = query + src_len = tgt_len + else: + key = value = torch.rand((src_len, bsz, embed_dim)).to(device) attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device) bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device) print("starting torchtext.modules.MultiheadAttentionContainer") @@ -48,19 +52,23 @@ def _run_benchmark(embed_dim, nhead, tgt_len, src_len, bsz, device): print("*" * 80) print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72) - _run_benchmark(768, 12, 128, 128, 72, torch.device("cuda")) + _run_benchmark(768, 12, 72, torch.device("cuda"), 128, 128) + + print("*" * 80) + print("self-attention test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 256, 8, 1000, 1000, 2) + _run_benchmark(256, 8, 2, torch.device("cuda"), 1000) print("*" * 80) print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 64, 2, 10, 10, 8) - _run_benchmark(64, 2, 10, 10, 8, torch.device("cuda")) + _run_benchmark(64, 2, 8, torch.device("cuda"), 10, 10) print("*" * 80) print("test case CPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72) - _run_benchmark(768, 12, 128, 128, 72, torch.device("cpu")) + _run_benchmark(768, 12, 72, torch.device("cpu"), 128, 128) print("*" * 80) print("test case CPU with embed_dim, nhead, tgt_len, src_len, bsz:", 64, 2, 10, 10, 8) - _run_benchmark(64, 2, 10, 10, 8, torch.device("cpu")) + _run_benchmark(64, 2, 8, torch.device("cpu"), 10, 10) if __name__ == "__main__": From 9a0a789e31dab9da3e88c7248d23413db4a28710 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 7 May 2020 08:16:06 -0700 Subject: [PATCH 52/56] update benchmark test with more cases --- 
benchmark/mha_block.py | 64 ++++++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/benchmark/mha_block.py b/benchmark/mha_block.py index 0d9a9994ca..eff568f5dd 100644 --- a/benchmark/mha_block.py +++ b/benchmark/mha_block.py @@ -22,14 +22,19 @@ def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None): else: key = value = torch.rand((src_len, bsz, embed_dim)).to(device) attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device) + attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead)) bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device) print("starting torchtext.modules.MultiheadAttentionContainer") + if device == torch.device("cuda"): + torch.cuda.synchronize() t0 = time.monotonic() for _ in range(100): mha_output, attn_weights = MHA(query, key, value, - attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)), + attn_mask=attn_mask, bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1), bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)) + if device == torch.device("cuda"): + torch.cuda.synchronize() print(time.monotonic() - t0) # Use torch.nn.functional.multi_head_attention_forward @@ -38,6 +43,8 @@ def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None): in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight, MHA.in_proj_container.key_proj.weight, MHA.in_proj_container.value_proj.weight]) + if device == torch.device("cuda"): + torch.cuda.synchronize() t0 = time.monotonic() for _ in range(100): torch_mha_output, torch_mha_weights = mha_forward(query, key, value, @@ -48,27 +55,48 @@ def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None): MHA.out_proj.weight, MHA.out_proj.bias, attn_mask=torch_attn_mask) + if device == torch.device("cuda"): + torch.cuda.synchronize() print(time.monotonic() - t0) - print("*" * 80) - print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72) - _run_benchmark(768, 12, 72, torch.device("cuda"), 128, 128) + # GPU test + device = torch.device("cuda") + for embed_dim in [64, 768]: + for nhead in [2, 16]: + for seq_len in [10, 128, 1000]: + for bsz in [2, 72]: + if seq_len == 1000 and bsz == 72: + continue + print("*" * 80) + print("test case GPU with embed_dim, nhead, seq_len, bsz:", + embed_dim, nhead, seq_len, seq_len, bsz) + _run_benchmark(embed_dim, nhead, bsz, device, seq_len, seq_len) - print("*" * 80) - print("self-attention test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 256, 8, 1000, 1000, 2) - _run_benchmark(256, 8, 2, torch.device("cuda"), 1000) + # GPU test for self-attention + device = torch.device("cuda") + for embed_dim in [64, 256]: + for nhead in [2, 16]: + for seq_len in [10, 128, 1000]: + for bsz in [2, 72]: + if seq_len == 1000 and bsz == 72: + continue + print("*" * 80) + print("self-attention test case GPU with embed_dim, nhead, seq_len, bsz:", + embed_dim, nhead, seq_len, seq_len, bsz) + _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None) - print("*" * 80) - print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 64, 2, 10, 10, 8) - _run_benchmark(64, 2, 8, torch.device("cuda"), 10, 10) - - print("*" * 80) - print("test case CPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72) - _run_benchmark(768, 12, 72, torch.device("cpu"), 128, 128) - - print("*" * 80) - print("test case CPU with embed_dim, nhead, tgt_len, src_len, bsz:", 64, 2, 10, 10, 8) - _run_benchmark(64, 2, 8, torch.device("cpu"), 10, 10) 
+ # CPU test for self-attention + device = torch.device("cpu") + for embed_dim in [64, 768]: + for nhead in [2, 16]: + for seq_len in [10, 128, 1000]: + for bsz in [2, 72]: + if seq_len == 1000 and bsz == 72: + continue + print("*" * 80) + print("test case CPU with embed_dim, nhead, seq_len, bsz:", + embed_dim, nhead, seq_len, seq_len, bsz) + _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None) if __name__ == "__main__": From 9f2491a5f204b524cd8a9f1fd6676ce7c815db3a Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 15 May 2020 12:43:26 -0700 Subject: [PATCH 53/56] update attn_mask --- torchtext/modules/multiheadattention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtext/modules/multiheadattention.py b/torchtext/modules/multiheadattention.py index b180645d98..3d23a9f2ce 100644 --- a/torchtext/modules/multiheadattention.py +++ b/torchtext/modules/multiheadattention.py @@ -43,7 +43,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, Args: query, key, value (Tensor): map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. - attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions. + attn_mask (Tensor, optional): 3D mask that prevents attention to certain positions. The type of attn_mask + should be same with that of attn_mask in the attention layer. bias_k and bias_v: (Tensor, optional): one more key and value sequence to be added at sequence dim (dim=-3). Those are used for incremental decoding. Users should provide non-None to both arguments in order to activate them. @@ -166,7 +167,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, # Dot product of q, k attn_output_weights = torch.matmul(query, key.transpose(-2, -1)) if attn_mask is not None: - attn_output_weights.masked_fill_(attn_mask, float('-inf'),) + attn_output_weights.masked_fill_(attn_mask, -1e8,) attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_output_weights, value) From 8771e3fc63e0f1c73df68e2927b55d1859ea9738 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 15 May 2020 12:56:54 -0700 Subject: [PATCH 54/56] add generate_square_subsequent_mask --- torchtext/modules/multiheadattention.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torchtext/modules/multiheadattention.py b/torchtext/modules/multiheadattention.py index 3d23a9f2ce..58c62c66b7 100644 --- a/torchtext/modules/multiheadattention.py +++ b/torchtext/modules/multiheadattention.py @@ -204,3 +204,15 @@ def forward(self, where S is the sequence length, N is the batch size, and E is the embedding dimension. """ return self.query_proj(query), self.key_proj(key), self.value_proj(value) + + +def generate_square_subsequent_mask(nbatch, sz): + r"""Generate a square mask for the sequence. The masked positions are filled with True. + Unmasked positions are filled with False. 
+ + Args: + nbatch: the number of batch size + sz: the size of square mask + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1).repeat(nbatch, 1, 1) + return mask From 8b50742c86f8f749560bd8bb0dbc3431720b35c7 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 21 May 2020 07:17:56 -0700 Subject: [PATCH 55/56] update docs in MHA container --- torchtext/modules/multiheadattention.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/torchtext/modules/multiheadattention.py b/torchtext/modules/multiheadattention.py index 58c62c66b7..f6d8e7675d 100644 --- a/torchtext/modules/multiheadattention.py +++ b/torchtext/modules/multiheadattention.py @@ -43,24 +43,19 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, Args: query, key, value (Tensor): map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. - attn_mask (Tensor, optional): 3D mask that prevents attention to certain positions. The type of attn_mask - should be same with that of attn_mask in the attention layer. - bias_k and bias_v: (Tensor, optional): one more key and value sequence to be added at - sequence dim (dim=-3). Those are used for incremental decoding. Users should provide - non-None to both arguments in order to activate them. + attn_mask, bias_k and bias_v (Tensor, optional): keyword arguments passed to the attention layer. + See the definitions in the attention. Shape: - Inputs: - query: :math:`(L, N, E)` - key: :math:`(S, N, E)` - value: :math:`(S, N, E)` - - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend - while ``False`` values will be unchanged. - - bias_k and bias_v:bias: :math:`(1, N * H, E / H)` + - attn_mask, bias_k and bias_v: same with the shape of the corresponding args in attention layer. - Outputs: - attn_output: :math:`(L, N, E)` - - attn_output_weights: :math:`(N*num_heads, L, S)` + - attn_output_weights: :math:`(N * H, L, S)` where where L is the target length, S is the sequence length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. From 7078c9303afdd96365f7aec4558375b1292d5ade Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 5 Jun 2020 07:23:10 -0700 Subject: [PATCH 56/56] add InProjContainer in docs --- docs/source/modules.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/modules.rst b/docs/source/modules.rst index 85af9c7026..ca8b30e8e4 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -12,6 +12,11 @@ torchtext.models.multiheadattention .. autofunction:: MultiheadAttentionContainer +:hidden:`InProjContainer` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: InProjContainer + :hidden:`ScaledDotProduct` ~~~~~~~~~~~~~~~~~~~~~~~~~~
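[Editor's note — a usage sketch of the API as it stands at the end of this series. The import path and constructor signatures are taken from the diffs above and may differ from the released package; the causal mask is built by hand here (True marks positions that may not attend) rather than via generate_square_subsequent_mask.]

import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead, seq_len, bsz = 16, 4, 10, 2
in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                          torch.nn.Linear(embed_dim, embed_dim),
                          torch.nn.Linear(embed_dim, embed_dim))
MHA = MultiheadAttentionContainer(nhead, in_proj,
                                  ScaledDotProduct(dropout=0.1),
                                  torch.nn.Linear(embed_dim, embed_dim))

query = key = value = torch.rand(seq_len, bsz, embed_dim)  # self-attention
# Causal mask broadcast to (N * H, L, S); True = not allowed to attend.
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(torch.bool)
attn_mask = attn_mask.repeat(bsz * nhead, 1, 1)
attn_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask)
print(attn_output.shape)   # torch.Size([10, 2, 16])  -> (L, N, E)
print(attn_weights.shape)  # torch.Size([8, 10, 10])  -> (N * H, L, S)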