From 88743d6b4e327d2a85e58d34ed7b2e12524222fd Mon Sep 17 00:00:00 2001
From: pmabbo13
Date: Tue, 12 Jul 2022 17:35:29 -0400
Subject: [PATCH] compute relative position bias for t5 attention

[ghstack-poisoned]
---
 torchtext/prototype/t5/modules.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py
index de8d438e37..a3db63b9ae 100644
--- a/torchtext/prototype/t5/modules.py
+++ b/torchtext/prototype/t5/modules.py
@@ -19,6 +19,33 @@
 from torch import Tensor
 
 
+# NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+def _compute_bias(
+    query_length: int,
+    key_length: int,
+    relative_attention_bias: Tensor,
+    relative_attention_num_buckets: int = 32,
+    relative_attention_max_distance: int = 128,
+    bidirectional: bool = True,
+    device=None,
+):
+    """Compute binned relative position bias"""
+    if device is None:
+        device = relative_attention_bias.weight.device
+    context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+    relative_position = memory_position - context_position  # shape (query_length, key_length)
+    relative_position_bucket = _relative_position_bucket(
+        relative_position,  # shape (query_length, key_length)
+        bidirectional=bidirectional,
+        num_buckets=relative_attention_num_buckets,
+        max_distance=relative_attention_max_distance,
+    )
+    values = relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+    return values
+
+
 # NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 def _relative_position_bucket(
     relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128