Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d6d7f20

Browse files
Update docs for torchtext.nn.InProjContainer (#1083)
1 parent 633548a commit d6d7f20

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

torchtext/nn/modules/multiheadattention.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,14 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
200200

201201
class InProjContainer(torch.nn.Module):
202202
def __init__(self, query_proj, key_proj, value_proj):
203-
r"""An in-proj container to process inputs.
203+
r"""An in-proj container to project query/key/value in MultiheadAttention. This module happens before reshaping
204+
the projected query/key/value into multiple heads. See the linear layers (bottom) of Multi-head Attention in
205+
Fig 2 of Attention Is All You Need paper. Also check the usage example in torchtext.nn.MultiheadAttentionContainer.
204206
205207
Args:
206-
query_proj: a proj layer for query.
207-
key_proj: a proj layer for key.
208-
value_proj: a proj layer for value.
209-
208+
query_proj: a proj layer for query. A typical projection layer is torch.nn.Linear.
209+
key_proj: a proj layer for key. A typical projection layer is torch.nn.Linear.
210+
value_proj: a proj layer for value. A typical projection layer is torch.nn.Linear.
210211
"""
211212

212213
super(InProjContainer, self).__init__()
@@ -218,16 +219,21 @@ def forward(self,
218219
query: torch.Tensor,
219220
key: torch.Tensor,
220221
value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
221-
r"""Projects the input sequences using in-proj layers.
222+
r"""Projects the input sequences using in-proj layers. query/key/value are simply passed to
223+
the forward function of query_proj/key_proj/value_proj, respectively.
222224
223225
Args:
224226
query, key, value (Tensors): sequence to be projected
225227
226-
Shape:
227-
- query, key, value: :math:`(S, N, E)`
228-
- Output: :math:`(S, N, E)`.
229-
230-
Note: S is the sequence length, N is the batch size, and E is the embedding dimension.
228+
Examples::
229+
>>> from torchtext.nn import InProjContainer
230+
>>> embed_dim, bsz = 10, 64
231+
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
232+
torch.nn.Linear(embed_dim, embed_dim),
233+
torch.nn.Linear(embed_dim, embed_dim))
234+
>>> q = torch.rand((5, bsz, embed_dim))
235+
>>> k = v = torch.rand((6, bsz, embed_dim))
236+
>>> q, k, v = in_proj_container(q, k, v)
231237
232238
"""
233239
return self.query_proj(query), self.key_proj(key), self.value_proj(value)

0 commit comments

Comments
 (0)