
Commit 6860a30

Fix linting issues

1 parent b699de2 · commit 6860a30

File tree (6 files changed: +27 / -33 lines changed)

    test/torchtext_unittest/prototype/test_generate.py
    torchtext/prototype/generate.py
    torchtext/prototype/models/t5/__init__.py
    torchtext/prototype/models/t5/model.py
    torchtext/prototype/models/t5/modules.py
    torchtext/prototype/models/t5/wrapper.py


test/torchtext_unittest/prototype/test_generate.py

Lines changed: 2 additions & 1 deletion

@@ -1,8 +1,9 @@
 from unittest.mock import patch
+
+import torch
 from torchtext.prototype.generate import GenerationUtil
 from torchtext.prototype.models import T5_BASE_GENERATION
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
-import torch
 
 
 class TestGenerationUtil(TorchtextTestCase):
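
This hunk is a standard import-sorting fix: `import torch` moves out of the first-party block into its own third-party group. A minimal sketch of the grouping convention, assuming the repo's linter enforces the usual stdlib / third-party / first-party split (as usort and isort do by default):

# Import grouping assumed to be enforced by the repo's import sorter:
from unittest.mock import patch  # group 1: standard library

import torch  # group 2: third-party packages (external to this repo)

from torchtext.prototype.generate import GenerationUtil  # group 3: first-party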

torchtext/prototype/generate.py

Lines changed: 1 addition & 2 deletions

@@ -1,11 +1,10 @@
+import logging
 from typing import Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-import logging
-
 logger = logging.getLogger(__name__)
 
 
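The only substantive change here hoists `import logging` above the other imports; the module-level logger pattern it feeds is the standard one:

# The hunk above keeps this standard module-level logger pattern intact.
import logging

logger = logging.getLogger(__name__)  # logger named after the defining module


def example_step() -> None:
    logger.info("generation step complete")  # hypothetical call site, for illustration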
torchtext/prototype/models/t5/__init__.py

Lines changed: 11 additions & 11 deletions

@@ -1,19 +1,19 @@
 from .bundler import (
-    T5_BASE_ENCODER,
+    T5_11B,
+    T5_11B_ENCODER,
+    T5_11B_GENERATION,
+    T5_3B,
+    T5_3B_ENCODER,
+    T5_3B_GENERATION,
     T5_BASE,
+    T5_BASE_ENCODER,
     T5_BASE_GENERATION,
-    T5_SMALL_ENCODER,
-    T5_SMALL,
-    T5_SMALL_GENERATION,
-    T5_LARGE_ENCODER,
     T5_LARGE,
+    T5_LARGE_ENCODER,
     T5_LARGE_GENERATION,
-    T5_3B_ENCODER,
-    T5_3B,
-    T5_3B_GENERATION,
-    T5_11B_ENCODER,
-    T5_11B,
-    T5_11B_GENERATION,
+    T5_SMALL,
+    T5_SMALL_ENCODER,
+    T5_SMALL_GENERATION,
     T5Bundle,
 )
 from .model import T5Conf, T5Model
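
The new ordering can look wrong at a glance (T5_11B before T5_3B, T5Bundle last), but it is a plain case-insensitive lexicographic sort of the identifiers, which usort/isort apply by default; that this matches the repo's lint config is an assumption. A quick check:

# Reproduces the ordering in the hunk above with a case-insensitive sort.
names = [
    "T5_BASE_ENCODER", "T5_BASE", "T5_BASE_GENERATION",
    "T5_SMALL_ENCODER", "T5_SMALL", "T5_SMALL_GENERATION",
    "T5_LARGE_ENCODER", "T5_LARGE", "T5_LARGE_GENERATION",
    "T5_3B_ENCODER", "T5_3B", "T5_3B_GENERATION",
    "T5_11B_ENCODER", "T5_11B", "T5_11B_GENERATION",
    "T5Bundle",
]
print(sorted(names, key=str.casefold))
# ['T5_11B', 'T5_11B_ENCODER', 'T5_11B_GENERATION', 'T5_3B', 'T5_3B_ENCODER',
#  'T5_3B_GENERATION', 'T5_BASE', 'T5_BASE_ENCODER', 'T5_BASE_GENERATION',
#  'T5_LARGE', 'T5_LARGE_ENCODER', 'T5_LARGE_GENERATION', 'T5_SMALL',
#  'T5_SMALL_ENCODER', 'T5_SMALL_GENERATION', 'T5Bundle']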

torchtext/prototype/models/t5/model.py

Lines changed: 5 additions & 7 deletions

@@ -1,14 +1,13 @@
+# logging library is not automatically supported by Torchscript
+import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, Callable
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn as nn
 from torch import Tensor
 
-from .modules import T5Encoder, T5Decoder
-
-# logging library is not automatically supported by Torchscript
-import warnings
+from .modules import T5Decoder, T5Encoder
 
 
 @dataclass(frozen=True)
@@ -194,7 +193,6 @@ def forward(
             )
 
         if not self.encoder_only:
-
             assert self.decoder is not None
             assert encoder_outputs is not None
 
@@ -238,7 +236,7 @@ def forward(
                 # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
                 # same word embeddings, which is always the case in our t5 implementation.
                 # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
-                decoder_output = decoder_output * (self.embedding_dim**-0.5)
+                decoder_output = decoder_output * (self.embedding_dim ** -0.5)
                 decoder_output = self.lm_head(decoder_output)
                 decoder_outputs["decoder_output"] = decoder_output
 
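The `**` spacing change above touches the logit-rescaling line; here is a minimal runnable sketch of what that line does, with the model dimension and vocabulary size chosen for illustration (T5-Base values, an assumption here):

# Sketch of the rescaling applied before the LM head when the encoder and
# decoder share word embeddings, per the comment in the hunk above.
import torch

embedding_dim = 768  # T5-Base model dimension (illustrative)
vocab_size = 32128   # T5 vocabulary size (illustrative)
decoder_output = torch.randn(2, 5, embedding_dim)  # (batch, seq, d_model)

# Scale by d_model**-0.5 so logit variance matches the tied-embedding setup.
decoder_output = decoder_output * (embedding_dim ** -0.5)

lm_head = torch.nn.Linear(embedding_dim, vocab_size, bias=False)
logits = lm_head(decoder_output)
print(logits.shape)  # torch.Size([2, 5, 32128])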
torchtext/prototype/models/t5/modules.py

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
 
 import math
 import warnings
-from typing import Dict, List, Optional, Tuple, Union, Callable
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -843,7 +843,7 @@ def forward(
         embedded_tgt: Optional[Tensor] = None,
     ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]:
         r"""Pass the input (and masks) through the stack of encoder layers.
-
+
         Args:
             tgt (Optional[Tensor]): Tokenized input sequence to the encoder.
                 Must be batch first with shape (B, Ne) where B is the batch size and Ne is the
@@ -857,7 +857,7 @@ def forward(
                 length, and E is the model dimension.
                 *Note*: If you do not provide this `embedded_tgt`, you must have provided a `token_embedding` layer \
                 in the initialization of the T5Encoder.
-
+
         Returns:
             Tuple of last hidden layer, all hidden layers, position bias, and self-attention scores
         """

(The paired -/+ blank lines in the docstring hunks are whitespace-only changes, presumably trailing-whitespace removal, which does not render visibly in the diff.)

torchtext/prototype/models/t5/wrapper.py

Lines changed: 5 additions & 9 deletions

@@ -1,22 +1,21 @@
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torchtext.prototype.models import (
+    T5_11B_GENERATION,
+    T5_3B_GENERATION,
     T5_BASE_GENERATION,
-    T5_SMALL_GENERATION,
     T5_LARGE_GENERATION,
-    T5_3B_GENERATION,
-    T5_11B_GENERATION,
+    T5_SMALL_GENERATION,
+    T5Bundle,
     T5Conf,
     T5Transform,
-    T5Bundle,
 )
 
-import warnings
-
 
 BUNDLERS = {
     "base": T5_BASE_GENERATION,
@@ -139,7 +138,6 @@ def beam_search(
         return new_decoder_tokens, new_scores, new_incomplete_sentences
 
     def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max_seq_len: int = 512) -> Tensor:
-
         # pass tokens through encoder
         bsz = encoder_tokens.size(0)
         encoder = self.model.get_encoder()
@@ -155,7 +153,6 @@ def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max
 
         # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
         for step in range(max_seq_len):
-
             if step == 1:
                 # duplicate and order encoder output so that each beam is treated as its own independent sequence
                 encoder_output = encoder_outputs.get("encoder_output")
@@ -189,7 +186,6 @@ def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max
         return decoder_tokens
 
     def forward(self, input_text: List[str], beam_size: int, max_seq_len: int) -> Union[List[str], str]:
-
         model_input = self.transform(input_text)
         model_output_tensor = self.generate(encoder_tokens=model_input, beam_size=beam_size, max_seq_len=max_seq_len)
         model_output_list = torch.jit.annotate(List[List[int]], model_output_tensor.tolist())
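
The non-import hunks here only delete stray blank lines at the top of `generate` and `forward`. For orientation, a minimal sketch of the decode-loop shape those methods implement; names follow the diff (`generate`, `encoder_tokens`, `eos_idx`, `max_seq_len`), while the greedy `step_fn` is an illustrative stand-in for the wrapper's beam search, not torchtext's actual code:

import torch
from torch import Tensor


def generate(encoder_tokens: Tensor, step_fn, eos_idx: int = 1, max_seq_len: int = 512) -> Tensor:
    # Encode once, then iteratively extend the decoder tokens until every
    # sequence in the batch has produced the end-of-sequence token.
    bsz = encoder_tokens.size(0)
    decoder_tokens = torch.zeros(bsz, 1, dtype=torch.long)  # assume id 0 starts decoding
    for _ in range(max_seq_len):
        next_token = step_fn(encoder_tokens, decoder_tokens)  # (bsz, 1)
        decoder_tokens = torch.cat([decoder_tokens, next_token], dim=1)
        if bool((decoder_tokens == eos_idx).any(dim=1).all()):
            break  # all sequences have emitted eos_idx
    return decoder_tokens


# Toy usage: a step function that always emits EOS, so the loop stops after one step.
always_eos = lambda enc, dec: torch.full((dec.size(0), 1), 1, dtype=torch.long)
print(generate(torch.ones(2, 4, dtype=torch.long), always_eos))  # tensor([[0, 1], [0, 1]])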
