This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@pmabbo13 pmabbo13 requested a review from Nayef211 July 13, 2022 18:15
@Nayef211
Contributor

This is still a relatively large PR and it's a bit hard to tell which lines were added by you. It might be worthwhile leaving a comment on the PR highlighting the lines that are added here and are different from the original implementation (like you did in #1812).

    dropout_p = 0.0
else:
    dropout_p = self.dropout

Contributor Author

@pmabbo13 pmabbo13 Jul 13, 2022


Lines 256-274 encapsulate the changes made relative to torch.nn.functional.multi_head_attention_forward to include relative attention bias. position_bias is then also added as a return value in lines 393 and 400.
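
For context, here is a minimal sketch of the pattern being described: the position bias is added to the raw attention scores before the softmax, and it is returned alongside the attention output so subsequent layers can reuse it. The function name and tensor shapes below are illustrative assumptions, not the PR's exact code.

import torch
import torch.nn.functional as F

# Illustrative sketch only (not the PR's exact code).
def attention_with_position_bias(q, k, v, position_bias):
    # q, k, v:       (batch_size, num_heads, seq_len, head_dim)
    # position_bias: (1 or batch_size, num_heads, seq_len, seq_len)
    scores = torch.matmul(q, k.transpose(-2, -1))  # (batch_size, num_heads, seq_len, seq_len)
    scores = scores + position_bias                # fold in the relative attention bias
    attn_weights = F.softmax(scores, dim=-1)
    attn_output = torch.matmul(attn_weights, v)    # (batch_size, num_heads, seq_len, head_dim)
    return attn_output, position_bias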

@pmabbo13
Contributor Author

Description

The forward method for T5MultiheadAttention is a modified version of nn.functional.multi_head_attention_forward, meant to perform multihead attention with relative attention bias on the input query, key, and value tensors. The main modifications are as follows:

  1. Parameters needed to compute relative attention bias are added as input arguments (see the sketch below).
  2. Non-core functionalities are removed, such as add_zero_attn (which adds a new batch of zeros to the key and value sequences at dim=1) and the addition of bias terms to the key and value projections.
  3. The nn.functional.multi_head_attention_forward method reshapes q, k, and v to be 3D. This led to discrepancies with the decoder outputs of the HF implementation when the input decoder sequence had a batch size larger than 1. Reshaping these tensors to 4D (as in the HF implementation) resolved the issue.

These changes are best visible via this commit.
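
As a rough illustration of item 1, the sketch below shows one way a relative attention bias tensor can be produced from an embedding table. The helper name is hypothetical and the bucketing is deliberately simplified (plain clamping rather than T5's logarithmic buckets), so this is a sketch of the general technique rather than the PR's implementation.

import torch
import torch.nn as nn

# Hypothetical helper; the bucketing here is a simplified stand-in.
def compute_position_bias(relative_attention_bias: nn.Embedding,
                          query_length: int, key_length: int) -> torch.Tensor:
    # relative_attention_bias is assumed to be nn.Embedding(num_buckets, num_heads)
    context_position = torch.arange(query_length)[:, None]
    memory_position = torch.arange(key_length)[None, :]
    relative_position = memory_position - context_position        # (query_length, key_length)
    num_buckets = relative_attention_bias.num_embeddings
    # Clamp relative positions into the embedding range (T5 uses log-scale buckets instead).
    buckets = relative_position.clamp(-(num_buckets // 2), num_buckets // 2 - 1) + num_buckets // 2
    values = relative_attention_bias(buckets)                      # (query_length, key_length, num_heads)
    return values.permute(2, 0, 1).unsqueeze(0)                    # (1, num_heads, query_length, key_length)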

@pmabbo13 pmabbo13 requested review from abhinavarora and parmeet July 15, 2022 15:57
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
    warnings.warn("Byte tensor for key_padding_mask is not supported. Using bool tensor instead.")
    key_padding_mask = key_padding_mask.to(torch.bool)

Contributor Author


We were having issues with the original implementation, where the q, k, v tensors were reshaped to (batch_size * num_heads, seq_len, head_dim). This led to outputs that differed from the HF output when the input sequence to the decoder had a batch size larger than 1. The resolution was to reshape these into 4D tensors of shape (batch_size, num_heads, seq_len, head_dim), which is the same reshaping done in the HF implementation.
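
To make the shape difference concrete, here is a small self-contained comparison of the two reshapes. The starting (seq_len, batch_size, embed_dim) layout and the variable names are assumptions for illustration, not the PR's exact code.

import torch

batch_size, num_heads, seq_len, head_dim = 2, 8, 5, 64
q = torch.randn(seq_len, batch_size, num_heads * head_dim)  # assumed (seq_len, batch_size, embed_dim) layout

# 3D reshape as done in torch.nn.functional.multi_head_attention_forward:
q_3d = q.reshape(seq_len, batch_size * num_heads, head_dim).transpose(0, 1)
# -> (batch_size * num_heads, seq_len, head_dim)

# 4D reshape matching the HF-style implementation:
q_4d = q.reshape(seq_len, batch_size, num_heads, head_dim).permute(1, 2, 0, 3)
# -> (batch_size, num_heads, seq_len, head_dim)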

Contributor

@parmeet parmeet left a comment


Overall LGTM!

Contributor

@Nayef211 Nayef211 left a comment


LGTM, thanks for adding a detailed description of the changes w.r.t. the original implementation and for resolving all PR comments!

@pmabbo13 pmabbo13 merged commit 283590d into gh/pmabbo13/13/base Jul 18, 2022
@facebook-github-bot facebook-github-bot deleted the gh/pmabbo13/13/head branch August 18, 2022 14:20