This repository was archived by the owner on Sep 10, 2025. It is now read-only.
MultiheadAttention building blocks in torchtext #720
Merged
Commits (73):
a603d7b  add MHA building blocks in torchtext
5d06447  add docs
e9e18cb  combine forward function with functional
36c876a  add models to init
f7f2816  Merge branch 'master' into mha_blocks
bddc782  minor revisions
e665f38  minor change
4a44337  revision
fa36f85  Add unit test
45fed34  update docs
a6a2d94  flake8
b741e1f  add MultiheadAttentionContainer
86a1641  Merge branch 'master' into mha_blocks
ba8cd3a  update models init file
1c35a05  update docs of container
5415568  update MHA test
2055c16  remove in/out projection
9adc723  Switch MultiheadAttentionContainer to accept ScaledDotProduct, Multih…
9e4e0b7  Merge branch 'master' into mha_blocks
f94506a  add JIT support for MHA blocks
f3ed887  standardlize attn_mask
4a38802  update docs
a5bfdee  fix a bug in torchscript test
e81c4b3  add attn_mask in test_multiheadattention and test_torchscript_multihe…
66b71ac  add partial broadcast support for ScaledDotProduct. Only allow the ba…
da1bc7a  add more broadcast tests for scaled dot product model
f681d2d  Merge branch 'master' into mha_blocks
accceeb  add support for incremental decoding
7bd3beb  remove nheads from ScaledDotProduct
14da915  minor fix in jit test
032d749  adjust attn_mask
97bda4e  Merge branch 'master' into mha_blocks
5c1198c  fix jit annotation
c4ccac7  minor
295ab13  refine optional tensor in torchscript
9679022  minor fix in mha test
bc8a75f  remove a few assert statements
3a7d70d  a few changes to for torchscript in python 3
8008798  switch the name from models to modules
6b20f4a  Merge branch 'master' into mha_blocks
e12e131  minor fix
4d3be66  Merge branch 'master' into mha_blocks
659db7a  move reshape to MHA container
a3a21e7  udpdate doc
f7e75d1  minor
11e3027  asserRaises tests in broadcast
5a709a5  fix typo
a90c826  minor fix
0409636  add benchmark case
6c9a7a3  remove bias from test
45d28b5  update benchmark case
4f3b458  add InProjContainer
da4b302  update benchmark
4aeaf5e  minor test
517b921  minor fix
87801e2  flake8
494c1e9  Merge branch 'master' into mha_blocks
3380012  minor docs update
156ee4d  Merge remote-tracking branch 'upstream/master' into mha_blocks
5c51e2c  Merge branch 'master' into mha_blocks
f02073f  add self-attention in the benchmark
9a0a789  update benchmark test with more cases
6e9adf4  Merge remote-tracking branch 'upstream/master' into mha_blocks
941d184  Merge branch 'master' into mha_blocks
c0c3152  Merge branch 'master' into mha_blocks
9f2491a  update attn_mask
8771e3f  add generate_square_subsequent_mask
097f690  Merge branch 'master' into mha_blocks
c958652  Merge branch 'master' into mha_blocks
8b50742  update docs in MHA container
f799ef4  Merge branch 'master' into mha_blocks
496c43b  Merge branch 'master' into mha_blocks
7078c93  add InProjContainer in docs
flake8 configuration:

```diff
@@ -1,4 +1,5 @@
 [flake8]
-ignore = E402,E722,W503,W504,F821
+# E501 is not flexible enough, we're using B950 instead. Consistent with pytorch
+ignore = E402,E722,W503,W504,F821,E501
 max-line-length = 120
 exclude = docs/source,third_party
```
Benchmark script comparing the new container with torch.nn.functional.multi_head_attention_forward:

```python
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
from torch.nn.functional import multi_head_attention_forward as mha_forward
import time


def benchmark_mha_block():

    def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None):
        # Build torchtext MultiheadAttention module
        in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.Linear(embed_dim, embed_dim))
        MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                          ScaledDotProduct(),
                                          torch.nn.Linear(embed_dim, embed_dim)).to(device)

        query = torch.rand((tgt_len, bsz, embed_dim)).to(device)
        if src_len is None:
            key = value = query
            src_len = tgt_len
        else:
            key = value = torch.rand((src_len, bsz, embed_dim)).to(device)
        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device)
        attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
        bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device)
        print("starting torchtext.modules.MultiheadAttentionContainer")
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        t0 = time.monotonic()
        for _ in range(100):
            mha_output, attn_weights = MHA(query, key, value,
                                           attn_mask=attn_mask,
                                           bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
                                           bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        print(time.monotonic() - t0)

        # Use torch.nn.functional.multi_head_attention_forward
        torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float('-inf'))
        print("starting torch.nn.functional.multi_head_attention_forward")
        in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight,
                                    MHA.in_proj_container.key_proj.weight,
                                    MHA.in_proj_container.value_proj.weight])
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        t0 = time.monotonic()
        for _ in range(100):
            torch_mha_output, torch_mha_weights = mha_forward(query, key, value,
                                                              embed_dim, nhead,
                                                              in_proj_weight, None,
                                                              bias_k, bias_v,
                                                              False, 0.0,
                                                              MHA.out_proj.weight,
                                                              MHA.out_proj.bias,
                                                              attn_mask=torch_attn_mask)
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        print(time.monotonic() - t0)

    # GPU test
    device = torch.device("cuda")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("test case GPU with embed_dim, nhead, seq_len, bsz:",
                          embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, seq_len)

    # GPU test for self-attention
    device = torch.device("cuda")
    for embed_dim in [64, 256]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("self-attention test case GPU with embed_dim, nhead, seq_len, bsz:",
                          embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)

    # CPU test for self-attention
    device = torch.device("cpu")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("test case CPU with embed_dim, nhead, seq_len, bsz:",
                          embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)


if __name__ == "__main__":
    benchmark_mha_block()
```

Reviewer comment (Member), on the second timing loop: Same comment here.
Documentation stub for the new module (reStructuredText):

```rst
.. role:: hidden
    :class: hidden-section

torchtext.models.multiheadattention
===================================

.. automodule:: torchtext.models.multiheadattention
.. currentmodule:: torchtext.models.multiheadattention

:hidden:`MultiheadAttentionContainer`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: MultiheadAttentionContainer

:hidden:`InProjContainer`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: InProjContainer

:hidden:`ScaledDotProduct`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ScaledDotProduct
```
TorchScript test for the MultiheadAttention container:

```python
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
from torch.testing import assert_allclose
from ..common.torchtext_test_case import TorchtextTestCase


class TestJIT(TorchtextTestCase):

    def test_torchscript_multiheadattention(self):
        embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
        # Build torchtext MultiheadAttention models
        in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                            torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                            torch.nn.Linear(embed_dim, embed_dim, bias=False))

        MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                          ScaledDotProduct(),
                                          torch.nn.Linear(embed_dim, embed_dim, bias=False))
        query = torch.rand((tgt_len, bsz, embed_dim))
        key = value = torch.rand((src_len, bsz, embed_dim))
        attn_mask = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
        attn_mask = torch.stack([attn_mask] * (bsz * nhead))
        mha_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask)

        ts_MHA = torch.jit.script(MHA)
        ts_mha_output, ts_attn_weights = ts_MHA(query, key, value, attn_mask=attn_mask)
        assert_allclose(mha_output, ts_mha_output)
        assert_allclose(attn_weights, ts_attn_weights)
```
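Since the container passes torch.jit.script, a scripted instance can also be saved and reloaded with the standard TorchScript serialization calls. This is only a sketch building on the test above, not part of the diff; the file name is illustrative:

```python
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead = 10, 5
in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                    torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                    torch.nn.Linear(embed_dim, embed_dim, bias=False))
MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                  ScaledDotProduct(),
                                  torch.nn.Linear(embed_dim, embed_dim, bias=False))

ts_MHA = torch.jit.script(MHA)                   # script the container as in the test above
torch.jit.save(ts_MHA, "mha_container.pt")       # illustrative file name
loaded_MHA = torch.jit.load("mha_container.pt")  # reloads without the Python class definitions
```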
Unit tests for the MultiheadAttention building blocks:

```python
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
from torch.nn.functional import multi_head_attention_forward as mha_forward
from torch.testing import assert_allclose
from ..common.torchtext_test_case import TorchtextTestCase


class TestModels(TorchtextTestCase):

    def test_multiheadattention(self):
        embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
        # Build torchtext MultiheadAttention module
        in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                  torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                  torch.nn.Linear(embed_dim, embed_dim, bias=False))

        MHA = MultiheadAttentionContainer(nhead, in_proj,
                                          ScaledDotProduct(),
                                          torch.nn.Linear(embed_dim, embed_dim, bias=False))

        query = torch.rand((tgt_len, bsz, embed_dim))
        key = value = torch.rand((src_len, bsz, embed_dim))
        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
        bias_k = bias_v = torch.rand((1, 1, embed_dim))
        mha_output, attn_weights = MHA(query, key, value,
                                       attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)),
                                       bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
                                       bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))

        # Use torch.nn.functional.multi_head_attention_forward
        torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf'))
        in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight,
                                    MHA.in_proj_container.key_proj.weight,
                                    MHA.in_proj_container.value_proj.weight])
        torch_mha_output, torch_mha_weights = mha_forward(query, key, value,
                                                          embed_dim, nhead,
                                                          in_proj_weight, None,
                                                          bias_k, bias_v,
                                                          False, 0.0,
                                                          MHA.out_proj.weight, None,
                                                          attn_mask=torch_attn_mask)

        assert_allclose(mha_output, torch_mha_output)
        # With bias_k and bias_v, src_len is increased by 1
        attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len + 1).sum(dim=1) / nhead
        assert_allclose(attn_weights, torch_mha_weights)

    def test_broadcast_scaled_dot_product(self):
        embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
        SDP = ScaledDotProduct()
        query = torch.rand((tgt_len, 1, embed_dim))
        key = value = torch.rand((src_len, 1, embed_dim))
        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)

        sdp_attn_output_full, sdp_attn_weights_full = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
                                                          key.expand(src_len, bsz * nhead, embed_dim),
                                                          value.expand(src_len, bsz * nhead, embed_dim),
                                                          attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))

        # query has a batch size of 1 while key/value have a batch size of bsz * nhead
        sdp_attn_output, sdp_attn_weights = SDP(query, key.expand(src_len, bsz * nhead, embed_dim),
                                                value.expand(src_len, bsz * nhead, embed_dim),
                                                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
        assert_allclose(sdp_attn_output, sdp_attn_output_full)
        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)

        # key/value have a batch size of 1 while query has a batch size of bsz * nhead
        sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
                                                key, value,
                                                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
        assert_allclose(sdp_attn_output, sdp_attn_output_full)
        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)

        # key/value have a size of (3, 3, src_len, bsz * nhead, embed_dim)
        # while query has a size of (tgt_len, 1, embed_dim)
        sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, 1, embed_dim),
                                                key.expand(3, 3, src_len, bsz * nhead, embed_dim),
                                                value.expand(3, 3, src_len, bsz * nhead, embed_dim),
                                                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
        assert list(sdp_attn_output.size()) == [3, 3, tgt_len, bsz * nhead, embed_dim]
        assert list(sdp_attn_weights.size()) == [3, 3, bsz * nhead, tgt_len, embed_dim]
        assert_allclose(sdp_attn_output[2][2], sdp_attn_output_full)
        assert_allclose(sdp_attn_weights[2][2], sdp_attn_weights_full)
        # query's dim -2 is equal to neither key/value's dim -2 nor 1
        with self.assertRaises(RuntimeError):
            SDP(query.expand(tgt_len, 2, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim),
                value.expand(3, 3, src_len, bsz * nhead, embed_dim),
                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))

        # key/value have a size of (src_len, 1, embed_dim)
        # while query has a size of (1, 2, 3, tgt_len, bsz * nhead, embed_dim)
        sdp_attn_output, sdp_attn_weights = SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim),
                                                key.expand(src_len, 1, embed_dim),
                                                value.expand(src_len, 1, embed_dim),
                                                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
        assert list(sdp_attn_output.size()) == [1, 2, 3, tgt_len, bsz * nhead, embed_dim]
        assert list(sdp_attn_weights.size()) == [1, 2, 3, bsz * nhead, tgt_len, embed_dim]
        assert_allclose(sdp_attn_output[0][1][2], sdp_attn_output_full)
        assert_allclose(sdp_attn_weights[0][1][2], sdp_attn_weights_full)
        # key's dim -2 is not equal to value's dim -2
        with self.assertRaisesRegex(AssertionError, "Shape of key, value must match"):
            SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), key.expand(src_len, 2, embed_dim),
                value.expand(src_len, 1, embed_dim),
                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
        # key/value's dim -2 is equal to neither query's dim -2 nor 1
        with self.assertRaises(RuntimeError):
            SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), key.expand(src_len, 2, embed_dim),
                value.expand(src_len, 2, embed_dim),
                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))

        # attn_mask with a size of (1, tgt_len, src_len);
        # a 2D tensor is not supported for attn_mask
        sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
                                                key.expand(src_len, bsz * nhead, embed_dim),
                                                value.expand(src_len, bsz * nhead, embed_dim),
                                                attn_mask=attn_mask_2D.expand(1, tgt_len, src_len))
        assert_allclose(sdp_attn_output, sdp_attn_output_full)
        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
        # attn_mask's dim -3 is equal to neither the batch size nor 1
        with self.assertRaisesRegex(RuntimeError, "The size of the attn_mask is not correct."):
            SDP(query.expand(tgt_len, bsz * nhead, embed_dim), key.expand(src_len, bsz * nhead, embed_dim),
                value.expand(src_len, bsz * nhead, embed_dim),
                attn_mask=attn_mask_2D.expand(2, tgt_len, src_len))
```

Reviewer comment (Contributor), on test_broadcast_scaled_dot_product: I'd add some "self.assertRaises" to exercise some explicit cases where we expect broadcasting to fail definitely.
torchtext package __init__ changes:

```diff
@@ -1,4 +1,5 @@
 from . import data
+from . import modules
 from . import datasets
 from . import utils
 from . import vocab
@@ -11,6 +12,7 @@
     pass

 __all__ = ['data',
+           'modules',
            'datasets',
            'utils',
            'vocab',
```

Reviewer comment (Contributor), on the 'modules' entry: Why modules and not the torch.nn path convention?
The new torchtext.modules package __init__:

```python
from .multiheadattention import InProjContainer, \
    MultiheadAttentionContainer, ScaledDotProduct


__all__ = ['InProjContainer',
           'MultiheadAttentionContainer',
           'ScaledDotProduct']
```
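For orientation, here is a brief usage sketch of how these exported pieces compose; it mirrors the construction used in the tests and the benchmark above, with illustrative dimensions, and adds no API beyond what the diff shows:

```python
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead, bsz, tgt_len, src_len = 16, 4, 8, 5, 7

# Separate linear projections for query, key and value, grouped in an InProjContainer.
in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                    torch.nn.Linear(embed_dim, embed_dim),
                                    torch.nn.Linear(embed_dim, embed_dim))

# The container composes the projections, the attention layer and the output projection.
MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                  ScaledDotProduct(),
                                  torch.nn.Linear(embed_dim, embed_dim))

query = torch.rand((tgt_len, bsz, embed_dim))
key = value = torch.rand((src_len, bsz, embed_dim))
# Boolean attn_mask of shape (bsz * nhead, tgt_len, src_len); True positions are masked,
# matching the masked_fill_(attn_mask_2D, float('-inf')) conversion in the tests above.
attn_mask = torch.zeros((bsz * nhead, tgt_len, src_len), dtype=torch.bool)

attn_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask)
print(attn_output.shape)   # (tgt_len, bsz, embed_dim), per the tests above
print(attn_weights.shape)  # (bsz * nhead, tgt_len, src_len), per the tests above
```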
Review conversation (on the CUDA timing in the benchmark):

Reviewer: If you are benchmarking with CUDA, you need to add a torch.cuda.synchronize() before and after measuring the time, otherwise the timings won't be correct.

Author: Thanks. Will add them there.

Reviewer: The reason for this is that calls into the CUDA versions of operations are launched asynchronously. Only when you print a Tensor or convert it to CPU can you be sure all operations have finished. Using synchronize here helps you make sure all the work has indeed finished and you're timing things correctly. Also see torch.cuda.
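As a side note, the idiom described above can be wrapped in a small helper; this is only a sketch of the timing pattern under discussion, not code from the PR:

```python
import time
import torch


def timed_run(fn, iters=100, device=torch.device("cuda")):
    """Time iters calls of fn, synchronizing around the measurement on CUDA.

    CUDA kernels are launched asynchronously, so without the synchronize calls
    the timer would largely measure kernel launches rather than the actual work.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.monotonic()
    for _ in range(iters):
        fn()
    if device.type == "cuda":
        torch.cuda.synchronize()
    return time.monotonic() - t0


# e.g. timed_run(lambda: MHA(query, key, value, attn_mask=attn_mask))
```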
Reviewer: @zhangguanheng66 Could you share with us how your implementation performs compared to the PyTorch one after you have fixed the timing? Thanks.