diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py
index f66036b96e..63ec17170d 100644
--- a/torchtext/prototype/models/t5/modules.py
+++ b/torchtext/prototype/models/t5/modules.py
@@ -74,6 +74,8 @@ def __init__(
         else:
             self.relative_attention_bias = None
 
+        self.device = device
+
     def forward(
         self,
         query: Tensor,
@@ -258,7 +260,6 @@ def _t5_multi_head_attention_forward(
                 tgt_len,
                 src_len,
                 bidirectional=(not self.is_decoder),
-                device=k.device,
             )
 
         # Calculate attention and out projection
@@ -402,20 +403,17 @@ def _t5_dot_product_attention(
         output = output.transpose(1, 2).contiguous().view(B, -1, H * E)
         return output, attn
 
-    # NOTE: modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
+    # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
     def _compute_bias(
         self,
         query_length: int,
         key_length: int,
         bidirectional: bool = True,
-        device: Optional[torch.device] = None,
     ) -> Tensor:
         """Compute binned relative position bias"""
         assert self.relative_attention_bias is not None
-        if device is None:
-            device = self.relative_attention_bias.weight.device
-        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
-        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
@@ -427,7 +425,7 @@ def _compute_bias(
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
 
-    # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
+    # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L374
     def _relative_position_bucket(
         self, relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128
     ) -> Tensor:
@@ -448,7 +446,7 @@ def _relative_position_bucket(
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long)
+        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=self.device)
         if bidirectional:
             num_buckets = num_buckets // 2
             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
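
For context, a minimal standalone sketch of the pattern this diff adopts: the module captures the device once at construction (`self.device = device`) and creates every intermediate tensor in the bias computation directly on `self.device`, so no `Optional[torch.device]` argument has to be threaded through the forward path. The class and method names here (`ToyRelativeBias`, `compute_bias`) are hypothetical stand-ins, not the torchtext module itself; the bucketing math is adapted from the HF T5 reference the diff cites.

import math

import torch
import torch.nn as nn
from torch import Tensor


class ToyRelativeBias(nn.Module):
    """Hypothetical stand-in for the relative-attention-bias logic in the diff."""

    def __init__(self, num_buckets: int = 32, num_heads: int = 8, device=None):
        super().__init__()
        self.num_buckets = num_buckets
        # Capture the device once, mirroring the diff's `self.device = device`.
        self.device = device
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads, device=device)

    def _relative_position_bucket(
        self, relative_position: Tensor, num_buckets: int = 32, max_distance: int = 128
    ) -> Tensor:
        # Bidirectional binning: buckets are created on self.device, as in the diff.
        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=self.device)
        num_buckets = num_buckets // 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
        # Half the buckets map distances one-to-one; the rest bin logarithmically.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        large = torch.min(large, torch.full_like(large, num_buckets - 1))
        relative_buckets += torch.where(is_small, relative_position, large)
        return relative_buckets

    def compute_bias(self, query_length: int, key_length: int) -> Tensor:
        # Position tensors land directly on self.device; no device parameter needed.
        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
        relative_position = memory_position - context_position  # (query_length, key_length)
        buckets = self._relative_position_bucket(relative_position, num_buckets=self.num_buckets)
        values = self.relative_attention_bias(buckets)  # (query_length, key_length, num_heads)
        return values.permute([2, 0, 1]).unsqueeze(0)  # (1, num_heads, query_length, key_length)


bias = ToyRelativeBias(device=torch.device("cpu"))
print(bias.compute_bias(4, 4).shape)  # torch.Size([1, 8, 4, 4])

With the device fixed at construction, the `if device is None:` fallback removed by the diff is unnecessary, and all tensors in the bias path are guaranteed to agree with the embedding's device.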