From e0076d30634cbbe69b35c05c92340fbb909f9681 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 29 Jun 2022 12:05:22 -0400 Subject: [PATCH 1/8] add relative attention bias implementation from HF --- torchtext/prototype/t5/modules.py | 76 +++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 torchtext/prototype/t5/modules.py diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py new file mode 100644 index 0000000000..63baa23926 --- /dev/null +++ b/torchtext/prototype/t5/modules.py @@ -0,0 +1,76 @@ +import math + +import torch + + +# NOTE: taken from HF; used to compute relative attention bias +def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + +# NOTE: modified from HF; used to compute relative attention bias +def _compute_bias( + query_length, + key_length, + relative_attention_bias, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + bidirectional=True, + device=None, +): + """Compute binned relative position bias""" + if device is None: + device = relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, 
key_length) + relative_position_bucket = _relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=bidirectional, + num_buckets=relative_attention_num_buckets, + max_distance=relative_attention_max_distance, + ) + values = relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values From 9dbef30fad21c5270bfe44734ecd2d4a74dafad0 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 29 Jun 2022 12:10:20 -0400 Subject: [PATCH 2/8] incoporate relative attention bias in attention computation --- torchtext/prototype/t5/modules.py | 60 +++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 63baa23926..3fcf757365 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -1,6 +1,9 @@ import math +from typing import Optional, Tuple import torch +import torch.nn.functional as F +from torch import Tensor # NOTE: taken from HF; used to compute relative attention bias @@ -74,3 +77,60 @@ def _compute_bias( values = relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + + +# NOTE: modified from torch.nn.functional._scaled_dot_product_attention to incorporate relative attention bias +def _t5_scaled_dot_product_attention( + q: Tensor, + k: Tensor, + v: Tensor, + position_bias: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, +) -> Tuple[Tensor, Tensor]: + r""" + Computes scaled dot product attention on query, key and value tensors, using + an optional attention mask if passed, and applying dropout if a probability + greater than 0.0 is specified. + Returns a tensor pair containing attended values and attention weights. + Args: + q, k, v: query, key and value tensors. See Shape section for shape details. + attn_mask: optional tensor containing mask values to be added to calculated + attention. May be 2D or 3D; see Shape section for details. + dropout_p: dropout probability. If greater than 0.0, dropout is applied. + Shape: + - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length, + and E is embedding dimension. + - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length, + and E is embedding dimension. + - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length, + and E is embedding dimension. + - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of + shape :math:`(Nt, Ns)`. 
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights + have shape :math:`(B, Nt, Ns)` + """ + B, Nt, E = q.shape + q = q / math.sqrt(E) + + n_heads, tgt_len, src_len = position_bias.size()[1:] + assert B % n_heads == 0 + assert tgt_len == Nt + + position_bias = position_bias.repeat(B // n_heads, 1, 1, 1) + position_bias = position_bias.view(B, tgt_len, src_len) + + # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) + if attn_mask is not None: + attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) + position_bias += attn_mask + else: + attn = torch.bmm(q, k.transpose(-2, -1)) + + attn += position_bias + attn = F.softmax(attn, dim=-1) + if dropout_p > 0.0: + attn = F.dropout(attn, p=dropout_p) + # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) + output = torch.bmm(attn, v) + return output, attn From e5158a6e6cc2eafc318c68db0e34ee33b60ca3ff Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 29 Jun 2022 12:12:38 -0400 Subject: [PATCH 3/8] incoporate relative attention bias in MultiHeadAttention module --- torchtext/prototype/t5/modules.py | 448 ++++++++++++++++++++++++++++++ 1 file changed, 448 insertions(+) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 3fcf757365..53748815df 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -134,3 +134,451 @@ def _t5_scaled_dot_product_attention( # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) output = torch.bmm(attn, v) return output, attn + + +# NOTE: modified from torch.nn.functional._multi_head_attention_forward to incorporate relative attention bias +def t5_multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + has_relative_attention_bias: bool, + relative_attention_num_buckets: Optional[int], + relative_attention_max_distance: Optional[int], + relative_attention_bias: Optional[Tensor], + position_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, +) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + has_relative_attention_bias: whether or not relative attention bias should be used in the layer + relative_attention_num_buckets: The number of buckets to use when computing the relative attention bias. + relative_attention_max_distance: The maximum distance of the longer sequences for the bucket separation. 
+ position_bias: relative attention bias tensor, is computed at first layer and passed up to subsequent layers + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. + Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect + when ``need_weights=True.``. Default: True + Shape: + Inputs: + - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + Outputs: + - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: Only returned when ``need_weights=True``. 
If ``average_attn_weights=True``, returns + attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per + head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. + """ + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + if F.has_torch_function(tens_ops): + return F.handle_torch_function( + t5_multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + has_relative_attention_bias, + relative_attention_num_buckets, + relative_attention_max_distance, + relative_attention_bias, + position_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attn_weights, + ) + + is_batched = F._mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. + if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, torch.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = F._in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" + assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" + assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = F._in_projection(query, key, value, q_proj_weight, k_proj_weight, 
v_proj_weight, b_q, b_k, b_v) + + # prep attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + else: + assert ( + attn_mask.is_floating_point() or attn_mask.dtype == torch.bool + ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + # prep key padding mask + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_k.size(0) == bsz * num_heads + ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_v.size(0) == bsz * num_heads + ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and 
attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) + ) + if attn_mask is None: + attn_mask = key_padding_mask + elif attn_mask.dtype == torch.bool: + attn_mask = attn_mask.logical_or(key_padding_mask) + else: + attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) + + # convert mask to float + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + if position_bias is None: + if not has_relative_attention_bias: + position_bias = torch.zeros((bsz * num_heads, tgt_len, src_len), device=k.device, dtype=k.dtype) + else: + position_bias = _compute_bias( + tgt_len, + src_len, + relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + bidirectional=True, + device=k.device, + ) + + # + # (deep breath) calculate attention and out projection + # + attn_output, attn_output_weights = _t5_scaled_dot_product_attention(q, k, v, position_bias, attn_mask, dropout_p) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + if need_weights: + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.sum(dim=1) / num_heads + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + + if has_relative_attention_bias: + return attn_output, attn_output_weights, position_bias + else: + return attn_output, attn_output_weights, None + else: + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + + if has_relative_attention_bias: + return attn_output, None, position_bias + else: + return attn_output, None, None + + +class T5MultiheadAttention(nn.MultiheadAttention): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + has_relative_attention_bias=False, + relative_attention_num_buckets=None, + relative_attention_max_distance=None, + ) -> None: + + super(T5MultiheadAttention, self).__init__( + embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype + ) + + self.has_relative_attention_bias = has_relative_attention_bias + if self.has_relative_attention_bias: + assert relative_attention_num_buckets + assert relative_attention_max_distance + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + 
key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + position_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + + is_batched = query.dim() == 3 + + if self.batch_first and is_batched: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights, position_bias = t5_multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.has_relative_attention_bias, + self.relative_attention_num_buckets, + self.relative_attention_max_distance, + self.relative_attention_bias, + position_bias=position_bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + ) + else: + attn_output, attn_output_weights, position_bias = t5_multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.has_relative_attention_bias, + self.relative_attention_num_buckets, + self.relative_attention_max_distance, + self.relative_attention_bias, + position_bias=position_bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights, position_bias + else: + return attn_output, attn_output_weights, position_bias From e9cef16c93d2fcf029de520171bac202fddaa313 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 29 Jun 2022 12:15:34 -0400 Subject: [PATCH 4/8] add t5 layer normalization module --- torchtext/prototype/t5/modules.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 53748815df..f62d0a7dbe 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -1,7 +1,9 @@ import math +import warnings from typing import Optional, Tuple import torch +import torch.nn as nn import torch.nn.functional as F from torch import Tensor @@ -582,3 +584,30 @@ def forward( return attn_output.transpose(1, 0), attn_output_weights, position_bias else: return attn_output, attn_output_weights, position_bias + + +# NOTE: Taken from HF +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states From faf32cc0b2c3a6f18bbad5999e54a0195ab1e035 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 29 Jun 2022 12:16:51 -0400 Subject: [PATCH 5/8] outline t5 encoder layer --- torchtext/prototype/t5/modules.py | 44 ++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index f62d0a7dbe..bba49ded25 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, Callable import torch import torch.nn as nn @@ -611,3 +611,45 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states + + +class T5EncoderLayer(nn.TransformerEncoderLayer): + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = True, + has_relative_attention_bias: bool = True, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + device=None, + dtype=None, + ) -> None: + super(T5EncoderLayer, self).__init__( + d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, device, dtype + ) + + self.self_attn = T5MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + has_relative_attention_bias=has_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + ) + self.norm1 = T5LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = T5LayerNorm(d_model, eps=layer_norm_eps) + + def forward(): + pass + + +class T5Encoder(nn.TransformerEncoder): + + pass From 32a43d3af12b4e3bc42b1a9b7b84b0f4c87e2971 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 29 Jun 2022 17:30:40 -0400 Subject: [PATCH 6/8] implement t5 encoder layer --- torchtext/prototype/t5/modules.py | 197 +++++++++++++++++++++++------- 1 file changed, 151 insertions(+), 46 deletions(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index bba49ded25..bbda1865bf 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -1,3 +1,4 @@ +import copy import math import warnings from typing import Optional, Tuple, Union, Callable @@ -461,19 +462,14 @@ def t5_multi_head_attention_forward( attn_output = attn_output.squeeze(1) attn_output_weights = attn_output_weights.squeeze(0) - if has_relative_attention_bias: - return attn_output, attn_output_weights, position_bias - else: - return attn_output, attn_output_weights, None + return attn_output, attn_output_weights, position_bias + else: if not is_batched: # squeeze the output if input was unbatched attn_output = attn_output.squeeze(1) - if has_relative_attention_bias: - return attn_output, None, position_bias - else: - 
return attn_output, None, None + return attn_output, None, position_bias class T5MultiheadAttention(nn.MultiheadAttention): @@ -481,7 +477,7 @@ def __init__( self, embed_dim, num_heads, - dropout=0.0, + dropout=0.1, bias=True, add_bias_kv=False, add_zero_attn=False, @@ -490,23 +486,12 @@ def __init__( batch_first=False, device=None, dtype=None, - has_relative_attention_bias=False, - relative_attention_num_buckets=None, - relative_attention_max_distance=None, ) -> None: super(T5MultiheadAttention, self).__init__( embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype ) - self.has_relative_attention_bias = has_relative_attention_bias - if self.has_relative_attention_bias: - assert relative_attention_num_buckets - assert relative_attention_max_distance - self.relative_attention_num_buckets = relative_attention_num_buckets - self.relative_attention_max_distance = relative_attention_max_distance - self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) - def forward( self, query: Tensor, @@ -516,6 +501,9 @@ def forward( need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True, + has_relative_attention_bias=False, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, position_bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: @@ -524,6 +512,11 @@ def forward( if self.batch_first and is_batched: query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + if has_relative_attention_bias: + relative_attention_bias = nn.Embedding(relative_attention_num_buckets, self.num_heads) + else: + relative_attention_bias = None + if not self._qkv_same_embed_dim: attn_output, attn_output_weights, position_bias = t5_multi_head_attention_forward( query, @@ -539,10 +532,10 @@ def forward( self.dropout, self.out_proj.weight, self.out_proj.bias, - self.has_relative_attention_bias, - self.relative_attention_num_buckets, - self.relative_attention_max_distance, - self.relative_attention_bias, + has_relative_attention_bias=has_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + relative_attention_bias=relative_attention_bias, position_bias=position_bias, training=self.training, key_padding_mask=key_padding_mask, @@ -569,10 +562,10 @@ def forward( self.dropout, self.out_proj.weight, self.out_proj.bias, - self.has_relative_attention_bias, - self.relative_attention_num_buckets, - self.relative_attention_max_distance, - self.relative_attention_bias, + has_relative_attention_bias=has_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + relative_attention_bias=relative_attention_bias, position_bias=position_bias, training=self.training, key_padding_mask=key_padding_mask, @@ -588,12 +581,12 @@ def forward( # NOTE: Taken from HF class T5LayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, d_model, eps=1e-6): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
""" super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.weight = nn.Parameter(torch.ones(d_model)) self.variance_epsilon = eps def forward(self, hidden_states): @@ -616,15 +609,15 @@ def forward(self, hidden_states): class T5EncoderLayer(nn.TransformerEncoderLayer): def __init__( self, - d_model: int, - nhead: int, + d_model: int = 512, + nhead: int = 8, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - layer_norm_eps: float = 1e-5, + layer_norm_eps: float = 1e-6, batch_first: bool = False, norm_first: bool = True, - has_relative_attention_bias: bool = True, + has_relative_attention_bias: bool = False, relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, device=None, @@ -634,22 +627,134 @@ def __init__( d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, device, dtype ) - self.self_attn = T5MultiheadAttention( - d_model, - nhead, - dropout=dropout, - batch_first=batch_first, - has_relative_attention_bias=has_relative_attention_bias, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, - ) + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + + self.self_attn = T5MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first) self.norm1 = T5LayerNorm(d_model, eps=layer_norm_eps) self.norm2 = T5LayerNorm(d_model, eps=layer_norm_eps) - def forward(): - pass + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + Shape: + see the docs in Transformer class. + """ + + # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf + x = src + if self.norm_first: + attn_out, position_bias = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, position_bias) + # residual connection + x = x + attn_out + x = x + self._ff_block(self.norm2(x)) + else: + attn_out, position_bias = self._sa_block(x, src_mask, src_key_padding_mask, position_bias) + x, position_bias = self.norm1(x + attn_out) + x = self.norm2(x + self._ff_block(x)) + + return x, position_bias + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + position_bias: Optional[Tensor], + ) -> Tensor: + attn = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + has_relative_attention_bias=self.has_relative_attention_bias, + relative_attention_num_buckets=self.relative_attention_num_buckets, + relative_attention_max_distance=self.relative_attention_max_distance, + position_bias=position_bias, + ) + + x = attn[0] + if self.has_relative_attention_bias and position_bias is None: + position_bias = attn[2] + + return self.dropout1(x), position_bias + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) class T5Encoder(nn.TransformerEncoder): - pass + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + Examples:: + >>> encoder_layer = T5EncoderLayer(d_model=512, nhead=8, batch_first=True) + >>> t5_norm = T5LayerNorm(d_model=512) + >>> t5_encoder = T5Encoder(encoder_layer, num_layers=6, norm=t5_norm) + >>> src = torch.rand(10, 32, 512) + >>> out = t5_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers=6, norm=None, enable_nested_tensor=True): + super(T5Encoder, self).__init__(encoder_layer, num_layers, norm, enable_nested_tensor) + + first_layer = copy.deepcopy(encoder_layer) + first_layer.has_relative_attention_bias = True + self.layers = nn.ModuleList([first_layer] + [copy.deepcopy(encoder_layer) for i in range(num_layers - 1)]) + self.num_layers = num_layers + self.norm = norm + self.enable_nested_tensor = enable_nested_tensor + + def forward( + self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + Shape: + see the docs in Transformer class. 
+ """ + output = src + convert_to_nested = False + + position_bias = None + for mod in self.layers: + if convert_to_nested: + output, position_bias = mod(output, src_mask=mask, position_bias=position_bias) + else: + output, position_bias = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, position_bias=position_bias + ) + + if convert_to_nested: + output = output.to_padded_tensor(0.0) + + if self.norm is not None: + output = self.norm(output) + + return output From 5a0f7bdd295264cace883e09344cb21f3c1d8615 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 6 Jul 2022 10:35:34 -0400 Subject: [PATCH 7/8] implement t5 encoder --- torchtext/prototype/t5/modules.py | 141 +++++++++++++++++++++--------- 1 file changed, 100 insertions(+), 41 deletions(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index bbda1865bf..50cf35d66a 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -154,7 +154,7 @@ def t5_multi_head_attention_forward( dropout_p: float, out_proj_weight: Tensor, out_proj_bias: Optional[Tensor], - has_relative_attention_bias: bool, + compute_relative_attention_bias: bool, relative_attention_num_buckets: Optional[int], relative_attention_max_distance: Optional[int], relative_attention_bias: Optional[Tensor], @@ -183,7 +183,7 @@ def t5_multi_head_attention_forward( value sequences at dim=1. dropout_p: probability of an element to be zeroed. out_proj_weight, out_proj_bias: the output projection weight and bias. - has_relative_attention_bias: whether or not relative attention bias should be used in the layer + compute_relative_attention_bias: whether or not relative attention bias should be computed in this layer. relative_attention_num_buckets: The number of buckets to use when computing the relative attention bias. relative_attention_max_distance: The maximum distance of the longer sequences for the bucket separation. 
position_bias: relative attention bias tensor, is computed at first layer and passed up to subsequent layers @@ -252,7 +252,7 @@ def t5_multi_head_attention_forward( dropout_p, out_proj_weight, out_proj_bias, - has_relative_attention_bias, + compute_relative_attention_bias, relative_attention_num_buckets, relative_attention_max_distance, relative_attention_bias, @@ -430,7 +430,7 @@ def t5_multi_head_attention_forward( dropout_p = 0.0 if position_bias is None: - if not has_relative_attention_bias: + if not compute_relative_attention_bias: position_bias = torch.zeros((bsz * num_heads, tgt_len, src_len), device=k.device, dtype=k.dtype) else: position_bias = _compute_bias( @@ -477,8 +477,8 @@ def __init__( self, embed_dim, num_heads, - dropout=0.1, - bias=True, + dropout=0.0, + bias=False, add_bias_kv=False, add_zero_attn=False, kdim=None, @@ -501,9 +501,10 @@ def forward( need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True, - has_relative_attention_bias=False, + compute_relative_attention_bias=False, relative_attention_num_buckets=32, relative_attention_max_distance=128, + relative_attention_bias: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: @@ -512,11 +513,6 @@ def forward( if self.batch_first and is_batched: query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - if has_relative_attention_bias: - relative_attention_bias = nn.Embedding(relative_attention_num_buckets, self.num_heads) - else: - relative_attention_bias = None - if not self._qkv_same_embed_dim: attn_output, attn_output_weights, position_bias = t5_multi_head_attention_forward( query, @@ -532,7 +528,7 @@ def forward( self.dropout, self.out_proj.weight, self.out_proj.bias, - has_relative_attention_bias=has_relative_attention_bias, + compute_relative_attention_bias=compute_relative_attention_bias, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, relative_attention_bias=relative_attention_bias, @@ -562,7 +558,7 @@ def forward( self.dropout, self.out_proj.weight, self.out_proj.bias, - has_relative_attention_bias=has_relative_attention_bias, + compute_relative_attention_bias=compute_relative_attention_bias, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, relative_attention_bias=relative_attention_bias, @@ -609,17 +605,18 @@ def forward(self, hidden_states): class T5EncoderLayer(nn.TransformerEncoderLayer): def __init__( self, - d_model: int = 512, - nhead: int = 8, - dim_feedforward: int = 2048, + d_model: int = 768, + nhead: int = 12, + dim_feedforward: int = 3072, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, batch_first: bool = False, norm_first: bool = True, - has_relative_attention_bias: bool = False, + compute_relative_attention_bias: bool = False, relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, + relative_attention_bias: Optional[Tensor] = None, device=None, dtype=None, ) -> None: @@ -627,11 +624,14 @@ def __init__( d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, device, dtype ) - self.has_relative_attention_bias = has_relative_attention_bias + self.compute_relative_attention_bias = compute_relative_attention_bias self.relative_attention_num_buckets = relative_attention_num_buckets 
self.relative_attention_max_distance = relative_attention_max_distance + self.relative_attention_bias = relative_attention_bias self.self_attn = T5MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first) + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) self.norm1 = T5LayerNorm(d_model, eps=layer_norm_eps) self.norm2 = T5LayerNorm(d_model, eps=layer_norm_eps) @@ -680,14 +680,15 @@ def _sa_block( attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, - has_relative_attention_bias=self.has_relative_attention_bias, + compute_relative_attention_bias=self.compute_relative_attention_bias, relative_attention_num_buckets=self.relative_attention_num_buckets, relative_attention_max_distance=self.relative_attention_max_distance, + relative_attention_bias=self.relative_attention_bias, position_bias=position_bias, ) x = attn[0] - if self.has_relative_attention_bias and position_bias is None: + if self.compute_relative_attention_bias and position_bias is None: position_bias = attn[2] return self.dropout1(x), position_bias @@ -710,22 +711,29 @@ class T5Encoder(nn.TransformerEncoder): (and convert back on output). This will improve the overall performance of TransformerEncoder when padding rate is high. Default: ``True`` (enabled). Examples:: - >>> encoder_layer = T5EncoderLayer(d_model=512, nhead=8, batch_first=True) - >>> t5_norm = T5LayerNorm(d_model=512) - >>> t5_encoder = T5Encoder(encoder_layer, num_layers=6, norm=t5_norm) + >>> encoder_layer = T5EncoderLayer(d_model=768, nhead=12, dim_feedfoward=3072, dropout=0.1, activation='relu', batch_first=True) + >>> t5_norm = T5LayerNorm(d_model=768) + >>> t5_encoder = T5Encoder(encoder_layer, num_layers=12, norm=t5_norm) >>> src = torch.rand(10, 32, 512) >>> out = t5_encoder(src) """ - __constants__ = ["norm"] - def __init__(self, encoder_layer, num_layers=6, norm=None, enable_nested_tensor=True): + def __init__( + self, + encoder_layer, + relative_attention_num_buckets, + num_heads, + num_layers=12, + norm=None, + enable_nested_tensor=True, + ): super(T5Encoder, self).__init__(encoder_layer, num_layers, norm, enable_nested_tensor) first_layer = copy.deepcopy(encoder_layer) - first_layer.has_relative_attention_bias = True + first_layer.compute_relative_attention_bias = True + first_layer.relative_attention_bias = nn.Embedding(relative_attention_num_buckets, num_heads) self.layers = nn.ModuleList([first_layer] + [copy.deepcopy(encoder_layer) for i in range(num_layers - 1)]) self.num_layers = num_layers - self.norm = norm self.enable_nested_tensor = enable_nested_tensor def forward( @@ -739,22 +747,73 @@ def forward( Shape: see the docs in Transformer class. 
""" - output = src - convert_to_nested = False + output = src position_bias = None for mod in self.layers: - if convert_to_nested: - output, position_bias = mod(output, src_mask=mask, position_bias=position_bias) - else: - output, position_bias = mod( - output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, position_bias=position_bias - ) + output, position_bias = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, position_bias=position_bias + ) - if convert_to_nested: - output = output.to_padded_tensor(0.0) + return output - if self.norm is not None: - output = self.norm(output) - return output +class T5EncoderModel(nn.Module): + def __init__( + self, + d_model: int, + d_feedforward: int, + dropout: float, + activation: Union[str, Callable[[Tensor], Tensor]], + layer_norm_eps: float, + num_heads: int, + num_layers: int, + batch_first: bool, + relative_attention_num_buckets: int, + relative_attention_max_distance: int, + padding_idx: int, + max_seq_len: int, + vocab_size: int, + ) -> None: + super().__init__() + + self.d_model = d_model + self.d_feedforward = d_feedforward + self.dropout = dropout + self.activation = activation + self.layer_norm_eps = layer_norm_eps + self.num_heads = num_heads + self.num_layers = num_layers + self.batch_first = batch_first + self.relative_attention_num_buckets = relative_attention_num_buckets + self.realtive_attention_max_distance = relative_attention_max_distance + self.padding_idx = padding_idx + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + + self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx) + self.encoder_layer = T5EncoderLayer( + d_model, + num_heads, + d_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first=True, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + ) + self.norm = T5LayerNorm(d_model) + self.encoder = T5Encoder(self.encoder_layer, relative_attention_num_buckets, num_heads, num_layers) + self.dropout = nn.Dropout(dropout) + + def forward(self, tokens: torch.Tensor): + + padding_mask = tokens.eq(self.padding_idx) + embeddings = self.dropout(self.token_embeddings(tokens)) + encoder_output = self.encoder(embeddings, src_key_padding_mask=padding_mask) + encoder_output = self.norm(encoder_output) + encoder_output = self.dropout(encoder_output) + + return encoder_output From 27738308b743e9c0dc91c2d29a7bca005e4ca228 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Thu, 7 Jul 2022 14:06:11 -0400 Subject: [PATCH 8/8] return hidden states from each layer of encoder --- torchtext/prototype/t5/modules.py | 111 +++++++++++++----------------- 1 file changed, 46 insertions(+), 65 deletions(-) diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py index 50cf35d66a..e8dfeab00d 100644 --- a/torchtext/prototype/t5/modules.py +++ b/torchtext/prototype/t5/modules.py @@ -114,7 +114,8 @@ def _t5_scaled_dot_product_attention( have shape :math:`(B, Nt, Ns)` """ B, Nt, E = q.shape - q = q / math.sqrt(E) + # NOTE: HF implementation does not perform this normalization. 
For the sake of matching test results, we have commented it out + # q = q / math.sqrt(E) n_heads, tgt_len, src_len = position_bias.size()[1:] assert B % n_heads == 0 @@ -491,6 +492,11 @@ def __init__( super(T5MultiheadAttention, self).__init__( embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype ) + factory_kwargs = {"device": device, "dtype": dtype} + self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs)) + self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs)) + self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs)) + self.register_parameter("in_proj_weight", None) def forward( self, @@ -513,62 +519,35 @@ def forward( if self.batch_first and is_batched: query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - if not self._qkv_same_embed_dim: - attn_output, attn_output_weights, position_bias = t5_multi_head_attention_forward( - query, - key, - value, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.bias_k, - self.bias_v, - self.add_zero_attn, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - compute_relative_attention_bias=compute_relative_attention_bias, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, - relative_attention_bias=relative_attention_bias, - position_bias=position_bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, - k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, - average_attn_weights=average_attn_weights, - ) - else: - attn_output, attn_output_weights, position_bias = t5_multi_head_attention_forward( - query, - key, - value, - self.embed_dim, - self.num_heads, - self.in_proj_weight, - self.in_proj_bias, - self.bias_k, - self.bias_v, - self.add_zero_attn, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - compute_relative_attention_bias=compute_relative_attention_bias, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, - relative_attention_bias=relative_attention_bias, - position_bias=position_bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - average_attn_weights=average_attn_weights, - ) + attn_output, attn_output_weights, position_bias = t5_multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + compute_relative_attention_bias=compute_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + relative_attention_bias=relative_attention_bias, + position_bias=position_bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + ) if self.batch_first and is_batched: return attn_output.transpose(1, 0), attn_output_weights, 
position_bias else: @@ -747,15 +726,15 @@ def forward( Shape: see the docs in Transformer class. """ - output = src + all_outputs = () position_bias = None for mod in self.layers: + all_outputs = all_outputs + (output,) output, position_bias = mod( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, position_bias=position_bias ) - - return output + return output, all_outputs, position_bias class T5EncoderModel(nn.Module): @@ -806,14 +785,16 @@ def __init__( ) self.norm = T5LayerNorm(d_model) self.encoder = T5Encoder(self.encoder_layer, relative_attention_num_buckets, num_heads, num_layers) - self.dropout = nn.Dropout(dropout) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) def forward(self, tokens: torch.Tensor): padding_mask = tokens.eq(self.padding_idx) - embeddings = self.dropout(self.token_embeddings(tokens)) - encoder_output = self.encoder(embeddings, src_key_padding_mask=padding_mask) + embeddings = self.dropout1(self.token_embeddings(tokens)) + encoder_output, all_hidden_states, position_bias = self.encoder(embeddings, src_key_padding_mask=padding_mask) encoder_output = self.norm(encoder_output) - encoder_output = self.dropout(encoder_output) + last_hidden_state = self.dropout2(encoder_output) + all_hidden_states = all_hidden_states + (last_hidden_state,) - return encoder_output + return last_hidden_state, all_hidden_states, position_bias
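
A few usage notes for reviewers trying the series locally. First, a minimal sketch of the relative-attention-bias helpers added in the first patch. It assumes the series is applied and the new file is importable as `torchtext.prototype.t5.modules`; the bucket and head counts simply mirror the defaults used throughout the patches.

```python
import torch
import torch.nn as nn

# Assumes the patch series is applied; the import path follows the new file's location.
from torchtext.prototype.t5.modules import _compute_bias, _relative_position_bucket

query_len, key_len, num_heads, num_buckets = 6, 6, 12, 32

# relative_position[i, j] = j - i, i.e. memory_position - query_position
relative_position = torch.arange(key_len)[None, :] - torch.arange(query_len)[:, None]
buckets = _relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=num_buckets, max_distance=128
)
print(buckets.shape)  # torch.Size([6, 6]); integer bucket ids in [0, num_buckets)

# One learned scalar per (bucket, head); _compute_bias reads the device from this embedding.
relative_attention_bias = nn.Embedding(num_buckets, num_heads)
bias = _compute_bias(
    query_len,
    key_len,
    relative_attention_bias,
    relative_attention_num_buckets=num_buckets,
    relative_attention_max_distance=128,
    bidirectional=True,
)
print(bias.shape)  # torch.Size([1, 12, 6, 6]) -> (1, num_heads, query_length, key_length)
```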
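
The `T5LayerNorm` module introduced in the fourth patch is a root-mean-square norm: it rescales by the RMS of the last dimension and applies a learned gain, with no mean subtraction and no bias. A quick check of that property, under the same import assumption:

```python
import torch
from torchtext.prototype.t5.modules import T5LayerNorm

norm = T5LayerNorm(d_model=4)  # weight is initialized to ones
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

# Matches x * rsqrt(mean(x^2) + eps) with the default eps of 1e-6.
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), expected))  # True

# Unlike nn.LayerNorm, adding a constant changes the output: the mean is never removed.
print(torch.allclose(norm(x + 10.0), norm(x)))  # False
```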
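
Finally, a sketch of the intended call pattern for the stacked encoder as it stands at the end of the series. The hyperparameters below are illustrative, roughly T5-base sized, and are not taken from any released configuration; the prototype also leaves the separate q/k/v projection weights uninitialized (`torch.empty`), so outputs are only meaningful once real weights are loaded, and the return signature may still change while the PR is under review.

```python
import torch
from torchtext.prototype.t5.modules import T5EncoderModel

# All constructor arguments are required by the prototype; values below are illustrative.
model = T5EncoderModel(
    d_model=768,
    d_feedforward=3072,
    dropout=0.1,
    activation="relu",
    layer_norm_eps=1e-6,
    num_heads=12,
    num_layers=12,
    batch_first=True,
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
    padding_idx=0,
    max_seq_len=512,
    vocab_size=32128,
)

tokens = torch.randint(1, 32128, (2, 16))  # (batch, seq_len) token ids
tokens[:, -4:] = 0                         # right-pad with padding_idx

# Per the last patch, forward returns the final hidden state, the per-layer
# hidden states, and the shared relative position bias.
last_hidden_state, all_hidden_states, position_bias = model(tokens)
print(last_hidden_state.shape)   # torch.Size([2, 16, 768])
print(len(all_hidden_states))    # num_layers + 1
print(position_bias.shape)       # torch.Size([1, 12, 16, 16]) -> (1, num_heads, seq_len, seq_len)
```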