From 7ff37d84105937197c8c9b99a0d8115cf9cb0ac7 Mon Sep 17 00:00:00 2001
From: pmabbo13
Date: Wed, 13 Jul 2022 13:24:18 -0400
Subject: [PATCH] computing relative attention bias

[ghstack-poisoned]
---
 torchtext/prototype/t5/modules.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/t5/modules.py
index ec67898c75..0dc3e4848b 100644
--- a/torchtext/prototype/t5/modules.py
+++ b/torchtext/prototype/t5/modules.py
@@ -66,6 +66,33 @@ def __init__(
     def forward():
         pass
 
+    # NOTE: modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+    def _compute_bias(
+        self,
+        query_length: int,
+        key_length: int,
+        relative_attention_bias: Tensor,
+        relative_attention_num_buckets: int = 32,
+        relative_attention_max_distance: int = 128,
+        bidirectional: bool = True,
+        device=None,
+    ):
+        """Compute binned relative position bias"""
+        if device is None:
+            device = relative_attention_bias.weight.device
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=bidirectional,
+            num_buckets=relative_attention_num_buckets,
+            max_distance=relative_attention_max_distance,
+        )
+        values = relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
     # NOTE: taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
     def _relative_position_bucket(
         self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
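Since `forward()` is still a stub in this patch, the following is a minimal standalone sketch of what `_compute_bias` produces, written as free functions rather than methods of the torchtext module. It assumes the bucketing logic from the Hugging Face reference that the patch points to, and that `relative_attention_bias` is an `nn.Embedding(num_buckets, num_heads)`; the names `compute_bias` and `_relative_position_bucket` below are illustrative, not part of the torchtext API.

```python
import math

import torch
import torch.nn as nn
from torch import Tensor


def _relative_position_bucket(
    relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
) -> Tensor:
    """Map signed relative positions to bucket indices, following the HF T5 reference."""
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        # Half the buckets are reserved for positions to the right of the query.
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # Small offsets get one bucket each; larger offsets share log-spaced buckets up to max_distance.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


def compute_bias(
    query_length: int,
    key_length: int,
    relative_attention_bias: nn.Embedding,
    relative_attention_num_buckets: int = 32,
    relative_attention_max_distance: int = 128,
    bidirectional: bool = True,
    device=None,
) -> Tensor:
    """Compute binned relative position bias of shape (1, num_heads, query_length, key_length)."""
    if device is None:
        device = relative_attention_bias.weight.device
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    relative_position = memory_position - context_position  # (query_length, key_length)
    relative_position_bucket = _relative_position_bucket(
        relative_position,
        bidirectional=bidirectional,
        num_buckets=relative_attention_num_buckets,
        max_distance=relative_attention_max_distance,
    )
    values = relative_attention_bias(relative_position_bucket)  # (query_length, key_length, num_heads)
    return values.permute([2, 0, 1]).unsqueeze(0)  # (1, num_heads, query_length, key_length)


if __name__ == "__main__":
    num_heads, num_buckets = 8, 32
    # One learned scalar per (bucket, head) pair, shared by every position pair that falls in that bucket.
    relative_attention_bias = nn.Embedding(num_buckets, num_heads)
    bias = compute_bias(query_length=5, key_length=7, relative_attention_bias=relative_attention_bias)
    print(bias.shape)  # torch.Size([1, 8, 5, 7])
```

The resulting `(1, num_heads, query_length, key_length)` tensor is laid out so it can be added directly to the pre-softmax attention scores, which is how the T5 reference consumes it.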