adding forward method for multihead attention #1833
Conversation
This is still a relatively large PR and it's a bit hard to tell which lines were added by you. It might be worthwhile leaving a comment on the PR highlighting the lines that are added here and are different from the original implementation (like you did in #1812).
```python
    dropout_p = 0.0
else:
    dropout_p = self.dropout
```
Lines 256-274 encapsulate the changes made to `torch.nn.functional.multi_head_attention_forward` to include relative attention bias. `position_bias` is then also added as a return value in lines 393 and 400.
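For context, here is a minimal sketch of how a relative position bias can be folded into the attention weights before the softmax and then returned alongside the output; the function and variable names are illustrative assumptions, not the exact code in this PR:

```python
import torch
import torch.nn.functional as F


def attention_with_relative_bias(q, k, v, position_bias, dropout_p=0.0, training=True):
    # q, k, v: (batch_size, num_heads, seq_len, head_dim)
    # position_bias: broadcastable to (batch_size, num_heads, q_len, k_len)
    scores = torch.matmul(q, k.transpose(-2, -1))  # raw attention scores
    scores = scores + position_bias  # relative attention bias folded in before the softmax
    attn_weights = F.softmax(scores, dim=-1)
    attn_weights = F.dropout(attn_weights, p=dropout_p, training=training)
    attn_output = torch.matmul(attn_weights, v)
    # position_bias is returned as well, mirroring the extra return value described above
    return attn_output, attn_weights, position_bias
```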
Description

The forward method for `T5MultiheadAttention` is a modified version of `nn.functional.multi_head_attention_forward`, meant to perform multihead attention with relative attention bias on the input query, key, and value tensors. The main modifications are as follows:

- Relative attention bias is incorporated into the attention weights, and `position_bias` is added to the return values.
- The query, key, and value tensors are kept as 4D tensors of shape `(batch_size, num_heads, seq_len, head_dim)` instead of being reshaped to `(batch_size * num_heads, seq_len, head_dim)`, matching the HF implementation.

These changes are best visible via this commit.
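For readers unfamiliar with relative attention bias, here is a minimal, self-contained sketch of how a T5-style `position_bias` can be produced from an embedding over bucketed relative positions; the names (`RelativeBias`, `relative_position_bucket`) and defaults are illustrative assumptions rather than the API introduced in this PR:

```python
import torch
import torch.nn as nn


def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    # Map signed relative positions to a fixed number of buckets (bidirectional case):
    # exact buckets for small distances, log-spaced buckets for larger ones.
    num_buckets //= 2
    buckets = (relative_position > 0).long() * num_buckets
    relative_position = relative_position.abs()
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / torch.log(torch.tensor(max_distance / max_exact))
        * (num_buckets - max_exact)
    ).long()
    large = torch.min(large, torch.full_like(large, num_buckets - 1))
    return buckets + torch.where(is_small, relative_position, large)


class RelativeBias(nn.Module):
    def __init__(self, num_heads, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bias_emb = nn.Embedding(num_buckets, num_heads)

    def forward(self, q_len, k_len):
        ctx = torch.arange(q_len)[:, None]
        mem = torch.arange(k_len)[None, :]
        buckets = relative_position_bucket(mem - ctx, self.num_buckets, self.max_distance)
        # (q_len, k_len, num_heads) -> (1, num_heads, q_len, k_len), broadcastable over the batch
        return self.bias_emb(buckets).permute(2, 0, 1).unsqueeze(0)
```

The resulting tensor is added to the raw attention scores before the softmax, which is the role `position_bias` plays in the modified forward.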
```python
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
    warnings.warn("Byte tensor for key_padding_mask is not supported. Using bool tensor instead.")
    key_padding_mask = key_padding_mask.to(torch.bool)
```
We were having issues with the original implementation, where the q, k, v tensors were reshaped to `(batch_size * num_heads, seq_len, head_dim)`. This led to outputs that differed from the HF output when the input sequence to the decoder had a batch size larger than 1. The resolution was to shape these into 4D tensors of shape `(batch_size, num_heads, seq_len, head_dim)`, which is the same reshaping done in the HF implementation.
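To make the shape difference concrete, here is a small sketch, assuming the usual `view`/`transpose` reshaping, that contrasts the two layouts; the variable names are illustrative:

```python
import torch

batch_size, num_heads, seq_len, head_dim = 2, 8, 5, 64
q = torch.randn(batch_size, seq_len, num_heads * head_dim)

# Original layout: (batch_size * num_heads, seq_len, head_dim), typically combined with torch.bmm
q_3d = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2).reshape(
    batch_size * num_heads, seq_len, head_dim
)

# Layout used here (and in the HF implementation): (batch_size, num_heads, seq_len, head_dim),
# combined with torch.matmul, which broadcasts over the batch and head dimensions
q_4d = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Same values, different arrangement; the 4D layout keeps the batch dimension explicit,
# which avoids the batched-decoding mismatch described above
assert torch.equal(q_3d.view(batch_size, num_heads, seq_len, head_dim), q_4d)
```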
parmeet left a comment
Overall LGTM!
Nayef211 left a comment
LGTM, thanks for adding a detailed description of the changes w.r.t. the original implementation and for resolving all PR comments!