@@ -200,13 +200,14 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
 
 class InProjContainer(torch.nn.Module):
     def __init__(self, query_proj, key_proj, value_proj):
-        r"""A in-proj container to process inputs.
+        r"""An in-proj container to project query/key/value in MultiheadAttention. This module happens before
+        reshaping the projected query/key/value into multiple heads. See the linear layers (bottom) of Multi-head
+        Attention in Fig. 2 of the Attention Is All You Need paper. Also see the usage example in torchtext.nn.MultiheadAttentionContainer.
 
         Args:
-            query_proj: a proj layer for query.
-            key_proj: a proj layer for key.
-            value_proj: a proj layer for value.
-
+            query_proj: a proj layer for query. A typical projection layer is torch.nn.Linear.
+            key_proj: a proj layer for key. A typical projection layer is torch.nn.Linear.
+            value_proj: a proj layer for value. A typical projection layer is torch.nn.Linear.
         """
 
         super(InProjContainer, self).__init__()
@@ -218,16 +219,21 @@ def forward(self,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        r"""Projects the input sequences using in-proj layers.
+        r"""Projects the input sequences using the in-proj layers. query, key, and value are simply
+        passed to the forward functions of query_proj, key_proj, and value_proj, respectively.
 
         Args:
             query, key, value (Tensors): sequence to be projected
 
-        Shape:
-            - query, key, value: :math:`(S, N, E)`
-            - Output: :math:`(S, N, E)`.
-
-        Note: S is the sequence length, N is the batch size, and E is the embedding dimension.
+        Examples::
+            >>> from torchtext.nn import InProjContainer
+            >>> embed_dim, bsz = 10, 64
+            >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
+            ...                                     torch.nn.Linear(embed_dim, embed_dim),
+            ...                                     torch.nn.Linear(embed_dim, embed_dim))
+            >>> q = torch.rand((5, bsz, embed_dim))
+            >>> k = v = torch.rand((6, bsz, embed_dim))
+            >>> q, k, v = in_proj_container(q, k, v)
 
         """
         return self.query_proj(query), self.key_proj(key), self.value_proj(value)
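
The projected query/key/value returned here are what the enclosing attention container then reshapes into heads. A minimal end-to-end sketch of that flow, using the torchtext.nn.MultiheadAttentionContainer referenced in the docstring; the ScaledDotProduct attention layer and the exact constructor wiring shown here are assumptions based on that usage example, not part of this diff:

```python
# Sketch: InProjContainer as the first stage of a multi-head attention container.
# MultiheadAttentionContainer(nhead, in_proj, attention_layer, out_proj) and
# ScaledDotProduct() are assumed torchtext.nn APIs, per the docstring's reference.
import torch
from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead, bsz = 10, 5, 64
in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                    torch.nn.Linear(embed_dim, embed_dim),
                                    torch.nn.Linear(embed_dim, embed_dim))
MHA = MultiheadAttentionContainer(nhead, in_proj_container,
                                  ScaledDotProduct(),
                                  torch.nn.Linear(embed_dim, embed_dim))
query = torch.rand((21, bsz, embed_dim))        # (S, N, E): sequence, batch, embedding
key = value = torch.rand((16, bsz, embed_dim))  # source sequence may differ in length
attn_output, attn_weights = MHA(query, key, value)
print(attn_output.shape)  # torch.Size([21, 64, 10]), same (S, N, E) layout as query
```

Keeping the three projections as standalone modules is what lets the container swap torch.nn.Linear for other projection layers without touching the attention computation itself.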