
Conversation

@pmabbo13 pmabbo13 requested a review from Nayef211 July 13, 2022 18:15
WIP PR to workshop implementation: #1812 

[ghstack-poisoned]
@pmabbo13
Contributor Author

Description

Having computed the relative attention bias term, this method computes the attention scores. The implementation is very similar to nn.functional._scaled_dot_product_attention, except that we pass position_bias in as an input argument so that the relative attention bias can be incorporated into the computation of the attention scores.

Since the input tensors to this function are 4-dimensional, we replace torch.baddbmm and torch.bmm with torch.matmul. Because we no longer use torch.baddbmm, which took attn_mask as an input argument, we instead add attn_mask directly to position_bias to ensure the mask is still applied before the softmax that produces the final attention scores.
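
For context, a minimal sketch of this idea: scaled dot product attention over 4-D tensors where position_bias (with an additive attn_mask folded into it) is added to the scores before the softmax. Function name, argument names, and shapes here are illustrative assumptions, not the actual torchtext code:

```python
import math

import torch
import torch.nn.functional as F


def attention_with_position_bias(
    q: torch.Tensor,                  # (batch, num_heads, q_len, head_dim) -- assumed shapes
    k: torch.Tensor,                  # (batch, num_heads, k_len, head_dim)
    v: torch.Tensor,                  # (batch, num_heads, k_len, head_dim)
    position_bias: torch.Tensor,      # (1 or batch, num_heads, q_len, k_len)
    attn_mask: torch.Tensor = None,   # additive mask, broadcastable to (..., q_len, k_len)
    dropout_p: float = 0.0,
):
    # Inputs are 4-D, so torch.matmul is used instead of torch.baddbmm/torch.bmm.
    # Scaling by sqrt(head_dim) mirrors nn.functional._scaled_dot_product_attention;
    # some relative-bias attention variants omit this scaling.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

    if attn_mask is not None:
        # With torch.baddbmm gone, the mask is added to position_bias so it is
        # still applied before the softmax.
        position_bias = position_bias + attn_mask

    scores = scores + position_bias
    attn_weights = F.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    attn_output = torch.matmul(attn_weights, v)  # (batch, num_heads, q_len, head_dim)
    return attn_output, attn_weights
```

Folding attn_mask into position_bias means the scores receive a single extra addition, keeping the flow close to the reference _scaled_dot_product_attention while supporting the bias term.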

@pmabbo13 pmabbo13 requested review from abhinavarora and parmeet July 15, 2022 15:57
Contributor

@parmeet parmeet left a comment


LGTM!

pmabbo13 added 2 commits July 15, 2022 16:58
WIP PR to workshop implementation: #1812 

[ghstack-poisoned]
WIP PR to workshop implementation: #1812 

[ghstack-poisoned]
@pmabbo13 pmabbo13 merged commit 9ed4314 into gh/pmabbo13/12/base Jul 18, 2022
@facebook-github-bot facebook-github-bot deleted the gh/pmabbo13/12/head branch August 18, 2022 14:20
