From 104c4f8b1e4ba16261f488164f952966fd9ae7f0 Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Thu, 13 Oct 2022 10:57:00 -0400
Subject: [PATCH 1/5] Move relative_buckets Tensor to same device as relative_position

---
 torchtext/prototype/models/t5/modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py
index f66036b96e..0e7d5faad2 100644
--- a/torchtext/prototype/models/t5/modules.py
+++ b/torchtext/prototype/models/t5/modules.py
@@ -448,7 +448,7 @@ def _relative_position_bucket(
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long)
+        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=relative_position.device)
         if bidirectional:
             num_buckets = num_buckets // 2
             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets

From eb0073ffcb09e9fff5234c1686fe25ae18e0fa7d Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Thu, 13 Oct 2022 10:57:48 -0400
Subject: [PATCH 2/5] Update code pointer comments

---
 torchtext/prototype/models/t5/modules.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py
index 0e7d5faad2..aa0c8f0d58 100644
--- a/torchtext/prototype/models/t5/modules.py
+++ b/torchtext/prototype/models/t5/modules.py
@@ -402,7 +402,7 @@ def _t5_dot_product_attention(
         output = output.transpose(1, 2).contiguous().view(B, -1, H * E)
         return output, attn
 
-    # NOTE: modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
+    # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
     def _compute_bias(
         self,
         query_length: int,
@@ -427,7 +427,7 @@ def _compute_bias(
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
+    # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
     def _relative_position_bucket(
         self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
     ) -> Tensor:

From bf9e1f17e8b15c392c588f1ceab42642caf29865 Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Fri, 14 Oct 2022 14:25:25 -0400
Subject: [PATCH 3/5] Reference self.device from within MultiHeadedAttention private methods

---
 torchtext/prototype/models/t5/modules.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py
index aa0c8f0d58..c44f1a603a 100644
--- a/torchtext/prototype/models/t5/modules.py
+++ b/torchtext/prototype/models/t5/modules.py
@@ -408,14 +408,11 @@ def _compute_bias(
         query_length: int,
         key_length: int,
         bidirectional: bool = True,
-        device: Optional[torch.device] = None,
     ) -> Tensor:
         """Compute binned relative position bias"""
         assert self.relative_attention_bias is not None
-        if device is None:
-            device = self.relative_attention_bias.weight.device
-        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
-        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
@@ -448,7 +445,7 @@ def _relative_position_bucket(
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=relative_position.device)
+        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=self.device)
         if bidirectional:
             num_buckets = num_buckets // 2
             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets

From 1ef86615e4436941b363661e37d44647cc6900dc Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Fri, 14 Oct 2022 14:52:02 -0400
Subject: [PATCH 4/5] Remove faulty call with device to t5 forward method

---
 torchtext/prototype/models/t5/modules.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py
index c44f1a603a..df9dc611b2 100644
--- a/torchtext/prototype/models/t5/modules.py
+++ b/torchtext/prototype/models/t5/modules.py
@@ -258,7 +258,6 @@ def _t5_multi_head_attention_forward(
                 tgt_len,
                 src_len,
                 bidirectional=(not self.is_decoder),
-                device=k.device,
             )
 
         # Calculate attention and out projection

From 99b0872213518a1127f144e9f81b6a976ae74bd9 Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Fri, 14 Oct 2022 15:11:11 -0400
Subject: [PATCH 5/5] Add device to Attention obj

---
 torchtext/prototype/models/t5/modules.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py
index df9dc611b2..63ec17170d 100644
--- a/torchtext/prototype/models/t5/modules.py
+++ b/torchtext/prototype/models/t5/modules.py
@@ -74,6 +74,8 @@ def __init__(
         else:
             self.relative_attention_bias = None
 
+        self.device = device
+
     def forward(
         self,
         query: Tensor,
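Note (not part of the patches): the series boils down to one rule -- create intermediate tensors on the module's device instead of letting torch factory functions default to the CPU. Below is a minimal, self-contained Python sketch of that pattern; `ToyRelativeBias`, its sizes, and its crude bucketing are invented for illustration and are not torchtext's actual `T5MultiheadAttention`.

import torch
import torch.nn as nn


# Sketch only: a made-up stand-in for the pattern the patches apply to
# T5MultiheadAttention -- stash the construction-time device on the module
# and build index tensors on it, so nothing silently lands on the CPU.
class ToyRelativeBias(nn.Module):
    def __init__(self, num_buckets: int = 32, num_heads: int = 8, device=None):
        super().__init__()
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads, device=device)
        self.device = device  # same idea as PATCH 5/5

    def _compute_bias(self, query_length: int, key_length: int) -> torch.Tensor:
        # Same idea as PATCH 3/5: arange on self.device, not the default device.
        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
        relative_position = memory_position - context_position
        # Crude stand-in for the real bucketing logic: clamp into [0, num_buckets).
        buckets = relative_position.clamp(min=0, max=self.relative_attention_bias.num_embeddings - 1)
        return self.relative_attention_bias(buckets)  # (query_length, key_length, num_heads)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
bias = ToyRelativeBias(device=device)._compute_bias(4, 4)
print(bias.device)  # matches `device`; no CPU/GPU mismatch inside the module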