MultiheadAttention building blocks in torchtext #720
Conversation
…eadInProject, MultiheadOutProject (force-pushed from 1765b69 to 2b9b68c)
…tch dim of either query or key/value to be 1 (force-pushed from 2b9b68c to 66b71ac)
fmassa left a comment
Did a quick pass on the benchmark scripts, and I think we can still improve them (especially for CUDA).
This could explain why the MHA implementation here seems to be significantly faster than the PyTorch one (which has a number of sync points internally).
```
attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)),
bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))
print(time.monotonic() - t0)
```
If you are benchmarking with CUDA, you need to add a torch.cuda.synchronize() before and after measuring the time, otherwise the timings won't be correct
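For reference, a minimal timing pattern with the extra synchronization (the helper below is only illustrative, not from the benchmark script):

```python
import time
import torch

def timed(fn, *args):
    # CUDA kernels launch asynchronously, so wait for pending work
    # before starting and before stopping the clock.
    torch.cuda.synchronize()
    t0 = time.monotonic()
    out = fn(*args)
    torch.cuda.synchronize()
    return out, time.monotonic() - t0
```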
Thanks. Will add them there.
The reason for this is that calls into CUDA versions of operations are launched asynchronously. Only when you print a Tensor or copy it to the CPU can you be sure all operations have finished. Using synchronize here helps you make sure all the work has indeed finished and you're timing things correctly. Also see torch.cuda.
@zhangguanheng66 Could you share with us how your implementation performs compared to the PyTorch one after you have fixed the timing? Thanks.
```
MHA.out_proj.weight,
MHA.out_proj.bias,
attn_mask=torch_attn_mask)
print(time.monotonic() - t0)
```
Same comment here.
benchmark/mha_block.py (Outdated)
```
print(time.monotonic() - t0)

print("*" * 80)
print("test case GPU with embed_dim, nhead, tgt_len, src_len, bsz:", 768, 12, 128, 128, 72)
```
I believe most of the potential speed benefits from the MHA implemented in PyTorch are only valid when query = key = value (because it computes the projections in a single kernel launch for all three).
Can you add more benchmarks for different sizes in the query = key = value case? A for loop would be helpful there, something like:
```
for embed_dim in [256, 768]:
    for ...
        for ...
            print(...)
            _run_benchmark(...)
```
We've run benchmarks on this and it depends on the size of the inputs as well. For large inputs, as you can probably imagine, it shouldn't make much of a difference since the overhead disappears.
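One possible shape for such a sweep with query = key = value; the sizes and the _run_benchmark helper below are illustrative, and the timing includes the synchronization discussed above:

```python
import time
import torch

def _run_benchmark(mha, embed_dim, nhead, seq_len, bsz, device="cuda"):
    # Illustrative helper: one timed forward pass with query = key = value.
    query = torch.rand(seq_len, bsz, embed_dim, device=device)
    torch.cuda.synchronize()
    t0 = time.monotonic()
    mha(query, query, query)
    torch.cuda.synchronize()
    print("embed_dim, nhead, seq_len, bsz, time:",
          embed_dim, nhead, seq_len, bsz, time.monotonic() - t0)

for embed_dim in [256, 768]:
    for nhead in [4, 8]:
        for seq_len in [128, 256]:
            for bsz in [32, 72]:
                mha = torch.nn.MultiheadAttention(embed_dim, nhead).to("cuda")
                _run_benchmark(mha, embed_dim, nhead, seq_len, bsz)
```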
```
head_dim = v.size(-1) // self.nhead
v = v.reshape(src_len, bsz * self.nhead, head_dim)

attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask,
```
It seems that for this container in particular there are no assumptions made on the dtype of attn_mask. I think we can relax that constraint. It stems from the fact that ScaledDotProduct needs a BoolTensor as a mask, but not for the container.
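One possible way to relax it, as a sketch (the helper below is illustrative, not from the PR): keep the container agnostic to the mask dtype and convert to a bool mask only where ScaledDotProduct needs it.

```python
import torch

def _to_bool_mask(attn_mask):
    # Illustrative: ScaledDotProduct masks positions where the entry is True,
    # so a non-bool (additive/float) mask is converted here by treating
    # nonzero entries as "masked". The container itself passes attn_mask through.
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask != 0
    return attn_mask
```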
```
# Dot product of q, k
attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
if attn_mask is not None:
    attn_output_weights.masked_fill_(attn_mask, float('-inf'),)
```
I believe for some speech use case they needed to use -1e8 instead of -inf to avoid NaN: https://github.com/pytorch/fairseq/blob/928dc47e7e72f3e6ed96e50942e7fb8892cdcf32/fairseq/modules/transformer_layer.py#L108-L112
Does it make sense to have this be user configurable?
Thanks. I think I will follow the convention in fairseq. We could make this user configurable later.
Also, since this is part of ScaledDotProduct we can create variants of ScaledDotProduct that are more flexible for this kind of stuff. I think we'll end up with a small collection of attention functions and maybe we'll come up with some common building blocks there as well.
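As a sketch, such a variant could expose the fill value as a constructor argument (the class and argument names below are hypothetical, not part of this PR):

```python
import torch

class ConfigurableScaledDotProduct(torch.nn.Module):
    # Hypothetical attention variant: the value used to fill masked positions
    # (e.g. float('-inf') or -1e8) is user configurable.
    def __init__(self, dropout=0.0, mask_fill_value=-1e8):
        super().__init__()
        self.dropout = dropout
        self.mask_fill_value = mask_fill_value

    def forward(self, query, key, value, attn_mask=None):
        query = query / (query.size(-1) ** 0.5)
        attn_weights = torch.matmul(query, key.transpose(-2, -1))
        if attn_mask is not None:
            attn_weights.masked_fill_(attn_mask, self.mask_fill_value)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout,
                                                   training=self.training)
        return torch.matmul(attn_weights, value), attn_weights
```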
@zhangguanheng66 In the docstrings, it seems
I have a PR to update the doc.

@zhangguanheng66 There seems to be some discrepancy compared to the PyTorch implementation, which has
In pytorch MHA,

@zhangguanheng66 Thanks. I've successfully built an MHA layer using your implementation. Its outputs numerically match those of the PyTorch implementation (I had to write an auxiliary function to convert the state dict of the latter to the former).
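For reference, a hedged sketch of such a conversion helper; the container-side key names (in_proj_container.query_proj/key_proj/value_proj and out_proj) are assumptions about this implementation:

```python
import torch

def convert_state_dict(torch_mha_state_dict):
    # nn.MultiheadAttention fuses the q/k/v projections into in_proj_weight /
    # in_proj_bias; the MHA container (assumed here) keeps three separate
    # Linear modules, so the fused tensors are split into three chunks.
    sd = torch_mha_state_dict
    q_w, k_w, v_w = sd["in_proj_weight"].chunk(3, dim=0)
    q_b, k_b, v_b = sd["in_proj_bias"].chunk(3, dim=0)
    return {
        "in_proj_container.query_proj.weight": q_w,
        "in_proj_container.query_proj.bias": q_b,
        "in_proj_container.key_proj.weight": k_w,
        "in_proj_container.key_proj.bias": k_b,
        "in_proj_container.value_proj.weight": v_w,
        "in_proj_container.value_proj.bias": v_b,
        "out_proj.weight": sd["out_proj.weight"],
        "out_proj.bias": sd["out_proj.bias"],
    }
```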
```
value = torch.cat([value, bias_v])
if attn_mask is not None:
    _attn_mask = attn_mask
    attn_mask = torch.nn.functional.pad(_attn_mask, (0, 1))
```
@zhangguanheng66 Why not simply attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))?
Updated in the revised_mha PR (link).

@zhangguanheng66 I am trying to understand
Kind of. From the code's point of view, it pads an extra token onto the sequence dimension of key/value.
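As a shape illustration (all sizes below are made up), appending bias_k/bias_v adds one position to the source sequence, so the attention mask needs one extra column:

```python
import torch

src_len, tgt_len, bsz_x_nhead, head_dim = 21, 17, 64, 32
key = torch.rand(src_len, bsz_x_nhead, head_dim)
bias_k = torch.rand(1, bsz_x_nhead, head_dim)
key = torch.cat([key, bias_k])                            # (src_len + 1, bsz * nhead, head_dim)

attn_mask = torch.zeros(tgt_len, src_len, dtype=torch.bool)
attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))    # (tgt_len, src_len + 1)
```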
```
# Dot product of q, k
attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
if attn_mask is not None:
    attn_output_weights.masked_fill_(attn_mask, -1e8,)
```
@zhangguanheng66 To numerically match torch's implementation, this line should change to attn_output_weights.masked_fill_(attn_mask, float('-inf')).
@netw0rkf10w There are some ongoing discussions about NaN output for some special cases. We tried to avoid this when implementing MHA container in torchtext. I believe we will modify this accordingly as pytorch/pytorch#42323 concludes.
@zhangguanheng66 Great. Thanks for the information! I'll join that discussion later.
We propose to refactor the nn.MultiheadAttention module as an MHA container. The objective is to add more flexibility to try different MHA variants. The new MHA container is capable of:

- A drop-in replacement of nn.MultiheadAttention with the MHA container (initializing nn.MultiheadAttention versus the MHA container is sketched after this list). Note that attn_output_weights from the MHA container is output without averaging, so for the drop-in replacement users will need to average the attention output weights in order to get the same results as nn.MultiheadAttention.
- bias_k and bias_v, which are attached to the sequence dim of key/value.
- query/key/value with more than three dimensions. For example, for some CV applications the input tensors have four dimensions (N, H, W, C) (link).
- Custom projection modules. With the SharedQK_Proj class sketched after this list, we can drop a custom in-projection module into the MHA container. Another example is the relative attention implementation introduced in ref; the matrices for relative position distances are added to the attention layer (see Equation 4 in the reference).
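A minimal sketch of the drop-in usage and of a shared query/key in-projection, assuming the building-block names from this PR as exposed in torchtext.nn (MultiheadAttentionContainer, InProjContainer, ScaledDotProduct) and that the container returns per-head attention weights of shape (bsz * nhead, tgt_len, src_len):

```python
import torch
from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead, bsz, seq_len = 768, 12, 8, 16

# nn.MultiheadAttention
torch_mha = torch.nn.MultiheadAttention(embed_dim, nhead)

# MHA container assembled from building blocks
in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                          torch.nn.Linear(embed_dim, embed_dim),
                          torch.nn.Linear(embed_dim, embed_dim))
mha_container = MultiheadAttentionContainer(nhead, in_proj,
                                            ScaledDotProduct(),
                                            torch.nn.Linear(embed_dim, embed_dim))

query = torch.rand(seq_len, bsz, embed_dim)
attn_output, attn_weights = mha_container(query, query, query)
# Per-head weights; average over heads to compare with nn.MultiheadAttention.
avg_weights = attn_weights.view(bsz, nhead, seq_len, seq_len).mean(dim=1)


class SharedQK_Proj(torch.nn.Module):
    """Hypothetical in-projection sharing one Linear layer between query and key."""

    def __init__(self, embed_dim):
        super().__init__()
        self.qk_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        return self.qk_proj(query), self.qk_proj(key), self.v_proj(value)


# The shared projection drops into the container in place of InProjContainer.
shared_mha = MultiheadAttentionContainer(nhead, SharedQK_Proj(embed_dim),
                                         ScaledDotProduct(),
                                         torch.nn.Linear(embed_dim, embed_dim))
```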
Here is another example, adding normalization and dropout in the out-projection layer:
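A minimal sketch of such an out-projection (the class name and dropout value are illustrative); it can be passed as the out-projection argument of the MHA container:

```python
import torch

class NormDropoutOutProj(torch.nn.Module):
    # Illustrative out-projection: Linear followed by dropout and LayerNorm.
    def __init__(self, embed_dim, dropout=0.1):
        super().__init__()
        self.linear = torch.nn.Linear(embed_dim, embed_dim)
        self.dropout = torch.nn.Dropout(dropout)
        self.norm = torch.nn.LayerNorm(embed_dim)

    def forward(self, attn_output):
        return self.norm(self.dropout(self.linear(attn_output)))
```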