This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Commits
73 commits
a603d7b
add MHA building blocks in torchtext
Apr 2, 2020
5d06447
add docs
Apr 2, 2020
e9e18cb
combine forward function with functional
Apr 2, 2020
36c876a
add models to init
Apr 2, 2020
f7f2816
Merge branch 'master' into mha_blocks
Apr 2, 2020
bddc782
minor revisions
Apr 2, 2020
e665f38
minor change
Apr 2, 2020
4a44337
revision
Apr 2, 2020
fa36f85
Add unit test
Apr 2, 2020
45fed34
update docs
Apr 2, 2020
a6a2d94
flake8
Apr 2, 2020
b741e1f
add MultiheadAttentionContainer
Apr 8, 2020
86a1641
Merge branch 'master' into mha_blocks
Apr 8, 2020
ba8cd3a
update models init file
Apr 8, 2020
1c35a05
update docs of container
Apr 8, 2020
5415568
update MHA test
Apr 8, 2020
2055c16
remove in/out projection
Apr 13, 2020
9adc723
Switch MultiheadAttentionContainer to accept ScaledDotProduct, Multih…
Apr 15, 2020
9e4e0b7
Merge branch 'master' into mha_blocks
Apr 15, 2020
f94506a
add JIT support for MHA blocks
Apr 15, 2020
f3ed887
standardize attn_mask
Apr 15, 2020
4a38802
update docs
Apr 15, 2020
a5bfdee
fix a bug in torchscript test
Apr 15, 2020
e81c4b3
add attn_mask in test_multiheadattention and test_torchscript_multihe…
Apr 16, 2020
66b71ac
add partial broadcast support for ScaledDotProduct. Only allow the ba…
Apr 16, 2020
da1bc7a
add more broadcast tests for scaled dot product model
Apr 17, 2020
f681d2d
Merge branch 'master' into mha_blocks
Apr 23, 2020
accceeb
add support for incremental decoding
Apr 23, 2020
7bd3beb
remove nheads from ScaledDotProduct
Apr 23, 2020
14da915
minor fix in jit test
Apr 23, 2020
032d749
adjust attn_mask
Apr 23, 2020
97bda4e
Merge branch 'master' into mha_blocks
Apr 23, 2020
5c1198c
fix jit annotation
Apr 23, 2020
c4ccac7
minor
Apr 23, 2020
295ab13
refine optional tensor in torchscript
Apr 23, 2020
9679022
minor fix in mha test
Apr 23, 2020
bc8a75f
remove a few assert statements
Apr 24, 2020
3a7d70d
a few changes for torchscript in python 3
Apr 24, 2020
8008798
switch the name from models to modules
Apr 24, 2020
6b20f4a
Merge branch 'master' into mha_blocks
Apr 24, 2020
e12e131
minor fix
Apr 24, 2020
4d3be66
Merge branch 'master' into mha_blocks
Apr 27, 2020
659db7a
move reshape to MHA container
Apr 27, 2020
a3a21e7
update doc
Apr 27, 2020
f7e75d1
minor
Apr 27, 2020
11e3027
assertRaises tests in broadcast
Apr 27, 2020
5a709a5
fix typo
Apr 28, 2020
a90c826
minor fix
Apr 28, 2020
0409636
add benchmark case
Apr 28, 2020
6c9a7a3
remove bias from test
Apr 29, 2020
45d28b5
update benchmark case
Apr 29, 2020
4f3b458
add InProjContainer
Apr 29, 2020
da4b302
update benchmark
Apr 29, 2020
4aeaf5e
minor test
Apr 29, 2020
517b921
minor fix
Apr 29, 2020
87801e2
flake8
Apr 29, 2020
494c1e9
Merge branch 'master' into mha_blocks
May 1, 2020
3380012
minor docs update
May 4, 2020
156ee4d
Merge remote-tracking branch 'upstream/master' into mha_blocks
May 4, 2020
5c51e2c
Merge branch 'master' into mha_blocks
May 4, 2020
f02073f
add self-attention in the benchmark
May 5, 2020
9a0a789
update benchmark test with more cases
May 7, 2020
6e9adf4
Merge remote-tracking branch 'upstream/master' into mha_blocks
May 13, 2020
941d184
Merge branch 'master' into mha_blocks
May 15, 2020
c0c3152
Merge branch 'master' into mha_blocks
May 15, 2020
9f2491a
update attn_mask
May 15, 2020
8771e3f
add generate_square_subsequent_mask
May 15, 2020
097f690
Merge branch 'master' into mha_blocks
May 18, 2020
c958652
Merge branch 'master' into mha_blocks
May 19, 2020
8b50742
update docs in MHA container
May 21, 2020
f799ef4
Merge branch 'master' into mha_blocks
Jun 2, 2020
496c43b
Merge branch 'master' into mha_blocks
Jun 5, 2020
7078c93
add InProjContainer in docs
Jun 5, 2020
3 changes: 2 additions & 1 deletion .flake8
@@ -1,4 +1,5 @@
[flake8]
ignore = E402,E722,W503,W504,F821
# E501 is not flexible enough, we're using B950 instead. Consistent with pytorch
ignore = E402,E722,W503,W504,F821,E501
max-line-length = 120
exclude = docs/source,third_party
103 changes: 103 additions & 0 deletions benchmark/mha_block.py
@@ -0,0 +1,103 @@
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
from torch.nn.functional import multi_head_attention_forward as mha_forward
import time


def benchmark_mha_block():

def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None):
# Build torchtext MultiheadAttention module
in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
torch.nn.Linear(embed_dim, embed_dim),
torch.nn.Linear(embed_dim, embed_dim))
MHA = MultiheadAttentionContainer(nhead, in_proj_container,
ScaledDotProduct(),
torch.nn.Linear(embed_dim, embed_dim)).to(device)

query = torch.rand((tgt_len, bsz, embed_dim)).to(device)
if src_len is None:
key = value = query
src_len = tgt_len
else:
key = value = torch.rand((src_len, bsz, embed_dim)).to(device)
attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device)
attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device)
print("starting torchtext.modules.MultiheadAttentionContainer")
if device == torch.device("cuda"):
torch.cuda.synchronize()
t0 = time.monotonic()
for _ in range(100):
mha_output, attn_weights = MHA(query, key, value,
attn_mask=attn_mask,
bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))
if device == torch.device("cuda"):
torch.cuda.synchronize()
print(time.monotonic() - t0)
Member:
If you are benchmarking with CUDA, you need to add a torch.cuda.synchronize() before and after measuring the time, otherwise the timings won't be correct.

Contributor Author:
Thanks. Will add them there.

Contributor:
The reason for this is that calls into CUDA versions of operations are launched asynchronously. Only when you print a Tensor or move it to the CPU can you be sure all operations have finished. Using synchronize here makes sure all the work has indeed finished and you are timing it correctly. Also see torch.cuda.

Contributor:
@zhangguanheng66 Could you share with us how your implementation performs compared to the PyTorch one after you have fixed the timing? Thanks.
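
For reference, the pattern being requested looks roughly like this (a minimal sketch; the helper name time_cuda_op is only for illustration and assumes a CUDA device is available):

import time
import torch

def time_cuda_op(fn, iters=100):
    # CUDA kernels are launched asynchronously, so synchronize before starting
    # the clock and again before reading it; otherwise the measurement covers
    # only the kernel launches, not the actual work.
    torch.cuda.synchronize()
    t0 = time.monotonic()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.monotonic() - t0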


# Use torch.nn.functional.multi_head_attention_forward
torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float('-inf'))
print("starting torch.nn.functional.multi_head_attention_forward")
in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight,
MHA.in_proj_container.key_proj.weight,
MHA.in_proj_container.value_proj.weight])
if device == torch.device("cuda"):
torch.cuda.synchronize()
t0 = time.monotonic()
for _ in range(100):
torch_mha_output, torch_mha_weights = mha_forward(query, key, value,
embed_dim, nhead,
in_proj_weight, None,
bias_k, bias_v,
False, 0.0,
MHA.out_proj.weight,
MHA.out_proj.bias,
attn_mask=torch_attn_mask)
if device == torch.device("cuda"):
torch.cuda.synchronize()
print(time.monotonic() - t0)
Member:
Same comment here.


# GPU test
device = torch.device("cuda")
for embed_dim in [64, 768]:
for nhead in [2, 16]:
for seq_len in [10, 128, 1000]:
for bsz in [2, 72]:
if seq_len == 1000 and bsz == 72:
continue
print("*" * 80)
print("test case GPU with embed_dim, nhead, seq_len, bsz:",
embed_dim, nhead, seq_len, bsz)
_run_benchmark(embed_dim, nhead, bsz, device, seq_len, seq_len)

# GPU test for self-attention
device = torch.device("cuda")
for embed_dim in [64, 256]:
for nhead in [2, 16]:
for seq_len in [10, 128, 1000]:
for bsz in [2, 72]:
if seq_len == 1000 and bsz == 72:
continue
print("*" * 80)
print("self-attention test case GPU with embed_dim, nhead, seq_len, bsz:",
embed_dim, nhead, seq_len, bsz)
_run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)

# CPU test for self-attention
device = torch.device("cpu")
for embed_dim in [64, 768]:
for nhead in [2, 16]:
for seq_len in [10, 128, 1000]:
for bsz in [2, 72]:
if seq_len == 1000 and bsz == 72:
continue
print("*" * 80)
print("test case CPU with embed_dim, nhead, seq_len, bsz:",
embed_dim, nhead, seq_len, bsz)
_run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)


if __name__ == "__main__":
benchmark_mha_block()
23 changes: 23 additions & 0 deletions docs/source/modules.rst
@@ -0,0 +1,23 @@
.. role:: hidden
:class: hidden-section

torchtext.models.multiheadattention
==================================

.. automodule:: torchtext.models.multiheadattention
.. currentmodule:: torchtext.models.multiheadattention

:hidden:`MultiheadAttentionContainer`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: MultiheadAttentionContainer

:hidden:`InProjContainer`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: InProjContainer

:hidden:`ScaledDotProduct`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ScaledDotProduct
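
For readers skimming the docs, a minimal usage sketch of the new blocks, mirroring the shapes and calls used in the tests below (the sizes are illustrative only):

import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                          torch.nn.Linear(embed_dim, embed_dim),
                          torch.nn.Linear(embed_dim, embed_dim))
mha = MultiheadAttentionContainer(nhead, in_proj,
                                  ScaledDotProduct(),
                                  torch.nn.Linear(embed_dim, embed_dim))

query = torch.rand(tgt_len, bsz, embed_dim)
key = value = torch.rand(src_len, bsz, embed_dim)
# Boolean mask, stacked to one slice per (batch * head), as in the tests.
attn_mask = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
attn_mask = torch.stack([attn_mask] * (bsz * nhead))

attn_output, attn_weights = mha(query, key, value, attn_mask=attn_mask)
# attn_output: (tgt_len, bsz, embed_dim); attn_weights: (bsz * nhead, tgt_len, src_len)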
28 changes: 28 additions & 0 deletions test/data/test_jit.py
@@ -0,0 +1,28 @@
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
from torch.testing import assert_allclose
from ..common.torchtext_test_case import TorchtextTestCase


class TestJIT(TorchtextTestCase):

def test_torchscript_multiheadattention(self):
embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
# Build torchtext MultiheadAttention models
in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
torch.nn.Linear(embed_dim, embed_dim, bias=False),
torch.nn.Linear(embed_dim, embed_dim, bias=False))

MHA = MultiheadAttentionContainer(nhead, in_proj_container,
ScaledDotProduct(),
torch.nn.Linear(embed_dim, embed_dim, bias=False))
query = torch.rand((tgt_len, bsz, embed_dim))
key = value = torch.rand((src_len, bsz, embed_dim))
attn_mask = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
attn_mask = torch.stack([attn_mask] * (bsz * nhead))
mha_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask)

ts_MHA = torch.jit.script(MHA)
ts_mha_output, ts_attn_weights = ts_MHA(query, key, value, attn_mask=attn_mask)
assert_allclose(mha_output, ts_mha_output)
assert_allclose(attn_weights, ts_attn_weights)
123 changes: 123 additions & 0 deletions test/data/test_modules.py
@@ -0,0 +1,123 @@
import torch
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
from torch.nn.functional import multi_head_attention_forward as mha_forward
from torch.testing import assert_allclose
from ..common.torchtext_test_case import TorchtextTestCase


class TestModels(TorchtextTestCase):

def test_multiheadattention(self):
embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
# Build torchtext MultiheadAttention module
in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
torch.nn.Linear(embed_dim, embed_dim, bias=False),
torch.nn.Linear(embed_dim, embed_dim, bias=False))

MHA = MultiheadAttentionContainer(nhead, in_proj,
ScaledDotProduct(),
torch.nn.Linear(embed_dim, embed_dim, bias=False))

query = torch.rand((tgt_len, bsz, embed_dim))
key = value = torch.rand((src_len, bsz, embed_dim))
attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
bias_k = bias_v = torch.rand((1, 1, embed_dim))
mha_output, attn_weights = MHA(query, key, value,
attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)),
bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))

# Use torch.nn.functional.multi_head_attention_forward
torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf'))
in_proj_weight = torch.cat([MHA.in_proj_container.query_proj.weight,
MHA.in_proj_container.key_proj.weight,
MHA.in_proj_container.value_proj.weight])
torch_mha_output, torch_mha_weights = mha_forward(query, key, value,
embed_dim, nhead,
in_proj_weight, None,
bias_k, bias_v,
False, 0.0,
MHA.out_proj.weight, None,
attn_mask=torch_attn_mask)

assert_allclose(mha_output, torch_mha_output)
# With bias_k and bias_v, src_len is increased by 1
attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len + 1).sum(dim=1) / nhead
assert_allclose(attn_weights, torch_mha_weights)

def test_broadcast_scaled_dot_product(self):
Contributor:
I'd add some self.assertRaises calls to exercise some explicit cases where we expect broadcasting to definitely fail.

embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
SDP = ScaledDotProduct()
query = torch.rand((tgt_len, 1, embed_dim))
key = value = torch.rand((src_len, 1, embed_dim))
attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)

sdp_attn_output_full, sdp_attn_weights_full = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
key.expand(src_len, bsz * nhead, embed_dim),
value.expand(src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))

# query has a batch size of 1 while key/value have a batch size of bsz * nhead
sdp_attn_output, sdp_attn_weights = SDP(query, key.expand(src_len, bsz * nhead, embed_dim),
value.expand(src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert_allclose(sdp_attn_output, sdp_attn_output_full)
assert_allclose(sdp_attn_weights, sdp_attn_weights_full)

# key/value have a batch size of 1 while query has a batch size of bsz * nhead
sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
key, value,
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert_allclose(sdp_attn_output, sdp_attn_output_full)
assert_allclose(sdp_attn_weights, sdp_attn_weights_full)

# key/value have a size of (3, 3, src_len, bsz * nhead, embed_dim)
# while query has a size of (tgt_len, 1, embed_dim)
sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, 1, embed_dim),
key.expand(3, 3, src_len, bsz * nhead, embed_dim),
value.expand(3, 3, src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert list(sdp_attn_output.size()) == [3, 3, tgt_len, bsz * nhead, embed_dim]
assert list(sdp_attn_weights.size()) == [3, 3, bsz * nhead, tgt_len, embed_dim]
assert_allclose(sdp_attn_output[2][2], sdp_attn_output_full)
assert_allclose(sdp_attn_weights[2][2], sdp_attn_weights_full)
# query's dim -2 is equal to neither key/value's dim -2 nor 1
with self.assertRaises(RuntimeError):
SDP(query.expand(tgt_len, 2, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim),
value.expand(3, 3, src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))

# key/value have a size of (src_len, 1, embed_dim)
# while query has a size of (1, 2, 3, tgt_len, bsz * nhead, embed_dim)
sdp_attn_output, sdp_attn_weights = SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim),
key.expand(src_len, 1, embed_dim),
value.expand(src_len, 1, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert list(sdp_attn_output.size()) == [1, 2, 3, tgt_len, bsz * nhead, embed_dim]
assert list(sdp_attn_weights.size()) == [1, 2, 3, bsz * nhead, tgt_len, embed_dim]
assert_allclose(sdp_attn_output[0][1][2], sdp_attn_output_full)
assert_allclose(sdp_attn_weights[0][1][2], sdp_attn_weights_full)
# key dim -2 is not equal to value dim -2
with self.assertRaisesRegex(AssertionError, "Shape of key, value must match"):
SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), key.expand(src_len, 2, embed_dim),
value.expand(src_len, 1, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
# key/value's dim -2 is equal to neither query's dim -2 nor 1
with self.assertRaises(RuntimeError):
SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), key.expand(src_len, 2, embed_dim),
value.expand(src_len, 2, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))

# attn_mask has a size of (1, tgt_len, src_len)
# 2D tensor is not supported for attn_mask
sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
key.expand(src_len, bsz * nhead, embed_dim),
value.expand(src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(1, tgt_len, src_len))
assert_allclose(sdp_attn_output, sdp_attn_output_full)
assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
# attn_mask's dim -3 is equal to neither the batch size nor 1
with self.assertRaisesRegex(RuntimeError, "The size of the attn_mask is not correct."):
SDP(query.expand(tgt_len, bsz * nhead, embed_dim), key.expand(src_len, bsz * nhead, embed_dim),
value.expand(src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(2, tgt_len, src_len))
2 changes: 2 additions & 0 deletions torchtext/__init__.py
@@ -1,4 +1,5 @@
from . import data
from . import modules
from . import datasets
from . import utils
from . import vocab
@@ -11,6 +12,7 @@
pass

__all__ = ['data',
'modules',
Contributor:
Why modules and not the torch.nn path convention?

'datasets',
'utils',
'vocab',
6 changes: 6 additions & 0 deletions torchtext/modules/__init__.py
@@ -0,0 +1,6 @@
from .multiheadattention import InProjContainer, \
MultiheadAttentionContainer, ScaledDotProduct

__all__ = ['InProjContainer',
'MultiheadAttentionContainer',
'ScaledDotProduct']