
Commit 667321b

merge changes

2 parents 16123f0 + d857e4f

File tree

1 file changed: +1 −2 lines changed


torchtext/models/roberta/modules.py

Lines changed: 1 addition & 2 deletions
@@ -286,7 +286,6 @@ def __init__(
     def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
         if attn_mask is not None:
             torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
-
         padding_mask = tokens.eq(self.padding_idx)
 
         token_embeddings = self.token_embedding(tokens)
@@ -325,4 +324,4 @@ def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
         encoded = self.embedding_layer_norm(encoded)
 
         # states are returned as T x B x C
-        return encoded
+        return encoded
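
For context, a minimal standalone sketch of the two checks the first hunk touches: the attn_mask dtype assertion and the padding-mask computation. The tensor values and padding_idx below are illustrative stand-ins, not values taken from the torchtext module itself.

import torch

# Illustrative inputs (not from torchtext): a batch of token ids and a
# boolean attention mask of shape T x T.
padding_idx = 1
tokens = torch.tensor([[5, 8, 1, 1],
                       [7, 2, 9, 1]])            # B x T token ids
attn_mask = torch.zeros(4, 4, dtype=torch.bool)  # T x T attention mask

# Same dtype check as in forward(): only float or bool masks are accepted.
torch._assert(
    attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
    f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
)

# The padding mask is True wherever a token equals the padding index.
padding_mask = tokens.eq(padding_idx)
print(padding_mask)
# tensor([[False, False,  True,  True],
#         [False, False, False,  True]])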
