This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2b9b68c

Author: Guanheng Zhang (committed)
add partial broadcast support for ScaledDotProduct. Only allow the batch dim of either query or key/value to be 1
1 parent: e81c4b3 · commit: 2b9b68c
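
For context, a minimal usage sketch of the behavior this commit enables: either query or key/value may keep a batch dimension of 1 and be broadcast against the other. The ScaledDotProduct import path and constructor signature below are assumed from the files changed in this commit, not confirmed against a released API.

import torch
from torchtext.models.multiheadattention import ScaledDotProduct  # path assumed from this commit

nhead, embed_dim, tgt_len, src_len, bsz = 5, 10, 6, 10, 64
SDP = ScaledDotProduct(nhead)

# query keeps a batch dim of 1; key/value carry the full bsz * nhead batch dim
query = torch.rand((tgt_len, 1, embed_dim))
key = value = torch.rand((src_len, bsz * nhead, embed_dim))

attn_output, attn_weights = SDP(query, key, value)
print(attn_output.shape)    # expected: torch.Size([6, 320, 10]), i.e. (tgt_len, bsz * nhead, embed_dim)
print(attn_weights.shape)   # expected: torch.Size([320, 6, 10]), i.e. (bsz * nhead, tgt_len, src_len)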

File tree

2 files changed: +40 -9 lines changed


test/data/test_models.py

Lines changed: 26 additions & 0 deletions
@@ -39,3 +39,29 @@ def test_multiheadattention(self):
         assert_allclose(mha_output, torch_mha_output)
         attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len).sum(dim=1) / nhead
         assert_allclose(attn_weights, torch_mha_weights)
+
+    def test_broadcast_scaled_dot_product(self):
+        embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
+        SDP = ScaledDotProduct(nhead)
+        query = torch.rand((tgt_len, 1, embed_dim))
+        key = value = torch.rand((src_len, 1, embed_dim))
+        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
+
+        sdp_attn_output_full, sdp_attn_weights_full = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
+                                                          key.expand(src_len, bsz * nhead, embed_dim),
+                                                          value.expand(src_len, bsz * nhead, embed_dim),
+                                                          attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
+
+        # query has a batch size of 1 while key/value have a batch size of bsz * nhead
+        sdp_attn_output, sdp_attn_weights = SDP(query, key.expand(src_len, bsz * nhead, embed_dim),
+                                                value.expand(src_len, bsz * nhead, embed_dim),
+                                                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
+        assert_allclose(sdp_attn_output, sdp_attn_output_full)
+        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
+
+        # key/value have a batch size of 1 while query has a batch size of bsz * nhead
+        sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
+                                                key, value,
+                                                attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
+        assert_allclose(sdp_attn_output, sdp_attn_output_full)
+        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
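
A note on why the broadcast calls above should match the fully expanded call: Tensor.expand only creates a view over the size-1 dimension without copying data, so both variants see exactly the same values. A small standalone illustration (independent of this repo):

import torch

x = torch.rand(6, 1, 10)                 # batch dim of 1
y = x.expand(6, 320, 10)                 # a view, not a copy: same underlying storage
print(y.shape)                           # torch.Size([6, 320, 10])
print(x.data_ptr() == y.data_ptr())      # True -- expand allocates no new memory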

torchtext/models/multiheadattention.py

Lines changed: 14 additions & 9 deletions
@@ -185,19 +185,24 @@ def forward(self, query, key, value, attn_mask=None):
         where L is the target length, S is the source length, H is the number
         of attention heads, N is the batch size, and E is the embedding dimension.
         """
-        tgt_len, batch_heads, head_dim = query.size()
-        assert query.size(1) == key.size(1) == value.size(1), "Dimension 0 of query, key, value must be equal."
-        assert batch_heads % self.num_heads == 0, "Dimension 0 of query, key, value must be divisible by num_heads"
+        tgt_len, head_dim = query.size(-3), query.size(-1)
+        assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal."
         assert key.size() == value.size(), "Shape of key, value must match"
-        assert query.size(-1) == key.size(-1), "The head dimension of query must be equal to that of key"
-        src_len = key.size(0)
+        src_len = key.size(-3)
+        batch_heads = max(query.size(-2), key.size(-2))

         # Scale query
-        query, key, value = query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)
+        query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
         query = query * (float(head_dim) ** -0.5)
         if attn_mask is not None:
-            if list(attn_mask.size()) != [batch_heads, tgt_len, src_len]:
-                raise RuntimeError('The size of the 3D attn_mask is not correct.')
+            if attn_mask.dim() != 3:
+                raise RuntimeError('attn_mask must be a 3D tensor.')
+            print(attn_mask.size(-1), src_len)
+            print(attn_mask.size(-2), tgt_len)
+            print(attn_mask.size(-3), batch_heads)
+            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
+                    (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
+                raise RuntimeError('The size of the attn_mask is not correct.')
             if attn_mask.dtype != torch.bool:
                 raise RuntimeError('Only bool tensor is supported for attn_mask')

@@ -211,4 +216,4 @@ def forward(self, query, key, value, attn_mask=None):
         attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
         attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training)
         attn_output = torch.matmul(attn_output_weights, value)
-        return attn_output.transpose(0, 1), attn_output_weights
+        return attn_output.transpose(-2, -3), attn_output_weights
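
The broadcasting itself happens in the matmul-based score computation (only partly visible in this hunk): after transpose(-2, -3) the tensors are batch-first, and torch.matmul broadcasts a batch dimension of 1 against batch_heads, which is presumably why batch_heads is taken as max(query.size(-2), key.size(-2)). A short sketch of that broadcasting behavior, with shapes chosen to match the test above:

import torch

q = torch.rand(1, 6, 4)                        # (batch=1, tgt_len, head_dim)
k = torch.rand(320, 10, 4)                     # (batch=320, src_len, head_dim)
scores = torch.matmul(q, k.transpose(-2, -1))  # the size-1 batch dim broadcasts
print(scores.shape)                            # torch.Size([320, 6, 10])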
