1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from typing import Optional
14+ from typing import cast , Dict , Iterator , List , Optional , Tuple , Union
1515
1616import torch
1717import torch .nn as nn
1818import torch .nn .functional as F
19+ from torch import Tensor
20+ from torch .optim import Optimizer
21+ from torch .optim .lr_scheduler import _LRScheduler
1922from torch .utils .data import DataLoader , Dataset , IterableDataset , Subset
2023
2124from pytorch_lightning import LightningDataModule , LightningModule
25+ from pytorch_lightning .core .optimizer import LightningOptimizer
26+ from pytorch_lightning .utilities .types import EPOCH_OUTPUT , STEP_OUTPUT
2227
2328
2429class RandomDictDataset (Dataset ):
2530 def __init__ (self , size : int , length : int ):
2631 self .len = length
2732 self .data = torch .randn (length , size )
2833
29- def __getitem__ (self , index ) :
34+ def __getitem__ (self , index : int ) -> Dict [ str , Tensor ] :
3035 a = self .data [index ]
3136 b = a + 2
3237 return {"a" : a , "b" : b }
@@ -40,7 +45,7 @@ def __init__(self, size: int, length: int):
4045 self .len = length
4146 self .data = torch .randn (length , size )
4247
43- def __getitem__ (self , index ) :
48+ def __getitem__ (self , index : int ) -> Tensor :
4449 return self .data [index ]
4550
4651 def __len__ (self ) -> int :
@@ -52,7 +57,7 @@ def __init__(self, size: int, count: int):
5257 self .count = count
5358 self .size = size
5459
55- def __iter__ (self ):
60+ def __iter__ (self ) -> Iterator [ Tensor ] :
5661 for _ in range (self .count ):
5762 yield torch .randn (self .size )
5863
@@ -62,16 +67,16 @@ def __init__(self, size: int, count: int):
6267 self .count = count
6368 self .size = size
6469
65- def __iter__ (self ):
70+ def __iter__ (self ) -> Iterator [ Tensor ] :
6671 for _ in range (len (self )):
6772 yield torch .randn (self .size )
6873
69- def __len__ (self ):
74+ def __len__ (self ) -> int :
7075 return self .count
7176
7277
7378class BoringModel (LightningModule ):
74- def __init__ (self ):
79+ def __init__ (self ) -> None :
7580 """Testing PL Module.
7681
7782 Use as follows:
@@ -90,60 +95,63 @@ def training_step(...):
9095 super ().__init__ ()
9196 self .layer = torch .nn .Linear (32 , 2 )
9297
93- def forward (self , x ):
98+ def forward (self , x : Tensor ) -> Tensor : # type: ignore[override]
9499 return self .layer (x )
95100
96- def loss (self , batch , preds ) :
101+ def loss (self , batch : Tensor , preds : Tensor ) -> Tensor :
97102 # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
98103 return torch .nn .functional .mse_loss (preds , torch .ones_like (preds ))
99104
100- def step (self , x ) :
105+ def step (self , x : Tensor ) -> Tensor :
101106 x = self (x )
102107 out = torch .nn .functional .mse_loss (x , torch .ones_like (x ))
103108 return out
104109
105- def training_step (self , batch , batch_idx ):
110+ def training_step (self , batch : Tensor , batch_idx : int ) -> STEP_OUTPUT : # type: ignore[override]
106111 output = self (batch )
107112 loss = self .loss (batch , output )
108113 return {"loss" : loss }
109114
110- def training_step_end (self , training_step_outputs ) :
115+ def training_step_end (self , training_step_outputs : STEP_OUTPUT ) -> STEP_OUTPUT :
111116 return training_step_outputs
112117
113- def training_epoch_end (self , outputs ) -> None :
118+ def training_epoch_end (self , outputs : EPOCH_OUTPUT ) -> None :
119+ outputs = cast (List [Dict [str , Tensor ]], outputs )
114120 torch .stack ([x ["loss" ] for x in outputs ]).mean ()
115121
116- def validation_step (self , batch , batch_idx ):
122+ def validation_step (self , batch : Tensor , batch_idx : int ) -> Optional [ STEP_OUTPUT ]: # type: ignore[override]
117123 output = self (batch )
118124 loss = self .loss (batch , output )
119125 return {"x" : loss }
120126
121- def validation_epoch_end (self , outputs ) -> None :
127+ def validation_epoch_end (self , outputs : Union [EPOCH_OUTPUT , List [EPOCH_OUTPUT ]]) -> None :
128+ outputs = cast (List [Dict [str , Tensor ]], outputs )
122129 torch .stack ([x ["x" ] for x in outputs ]).mean ()
123130
124- def test_step (self , batch , batch_idx ):
131+ def test_step (self , batch : Tensor , batch_idx : int ) -> Optional [ STEP_OUTPUT ]: # type: ignore[override]
125132 output = self (batch )
126133 loss = self .loss (batch , output )
127134 return {"y" : loss }
128135
129- def test_epoch_end (self , outputs ) -> None :
136+ def test_epoch_end (self , outputs : Union [EPOCH_OUTPUT , List [EPOCH_OUTPUT ]]) -> None :
137+ outputs = cast (List [Dict [str , Tensor ]], outputs )
130138 torch .stack ([x ["y" ] for x in outputs ]).mean ()
131139
132- def configure_optimizers (self ):
140+ def configure_optimizers (self ) -> Tuple [ List [ torch . optim . Optimizer ], List [ _LRScheduler ]] :
133141 optimizer = torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
134142 lr_scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = 1 )
135143 return [optimizer ], [lr_scheduler ]
136144
137- def train_dataloader (self ):
145+ def train_dataloader (self ) -> DataLoader :
138146 return DataLoader (RandomDataset (32 , 64 ))
139147
140- def val_dataloader (self ):
148+ def val_dataloader (self ) -> DataLoader :
141149 return DataLoader (RandomDataset (32 , 64 ))
142150
143- def test_dataloader (self ):
151+ def test_dataloader (self ) -> DataLoader :
144152 return DataLoader (RandomDataset (32 , 64 ))
145153
146- def predict_dataloader (self ):
154+ def predict_dataloader (self ) -> DataLoader :
147155 return DataLoader (RandomDataset (32 , 64 ))
148156
149157
@@ -155,7 +163,7 @@ def __init__(self, data_dir: str = "./"):
155163 self .checkpoint_state : Optional [str ] = None
156164 self .random_full = RandomDataset (32 , 64 * 4 )
157165
158- def setup (self , stage : Optional [str ] = None ):
166+ def setup (self , stage : Optional [str ] = None ) -> None :
159167 if stage == "fit" or stage is None :
160168 self .random_train = Subset (self .random_full , indices = range (64 ))
161169
@@ -168,26 +176,27 @@ def setup(self, stage: Optional[str] = None):
168176 if stage == "predict" or stage is None :
169177 self .random_predict = Subset (self .random_full , indices = range (64 * 3 , 64 * 4 ))
170178
171- def train_dataloader (self ):
179+ def train_dataloader (self ) -> DataLoader :
172180 return DataLoader (self .random_train )
173181
174- def val_dataloader (self ):
182+ def val_dataloader (self ) -> DataLoader :
175183 return DataLoader (self .random_val )
176184
177- def test_dataloader (self ):
185+ def test_dataloader (self ) -> DataLoader :
178186 return DataLoader (self .random_test )
179187
180- def predict_dataloader (self ):
188+ def predict_dataloader (self ) -> DataLoader :
181189 return DataLoader (self .random_predict )
182190
183191
184192class ManualOptimBoringModel (BoringModel ):
185- def __init__ (self ):
193+ def __init__ (self ) -> None :
186194 super ().__init__ ()
187195 self .automatic_optimization = False
188196
189- def training_step (self , batch , batch_idx ):
197+ def training_step (self , batch : Tensor , batch_idx : int ) -> STEP_OUTPUT : # type: ignore[override]
190198 opt = self .optimizers ()
199+ assert isinstance (opt , (Optimizer , LightningOptimizer ))
191200 output = self (batch )
192201 loss = self .loss (batch , output )
193202 opt .zero_grad ()
@@ -202,21 +211,21 @@ def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
202211 self .l1 = torch .nn .Linear (32 , out_dim )
203212 self .learning_rate = learning_rate
204213
205- def forward (self , x ):
214+ def forward (self , x : Tensor ) -> Tensor : # type: ignore[override]
206215 return torch .relu (self .l1 (x .view (x .size (0 ), - 1 )))
207216
208- def training_step (self , batch , batch_nb ):
217+ def training_step (self , batch : Tensor , batch_nb : int ) -> STEP_OUTPUT : # type: ignore[override]
209218 x = batch
210219 x = self (x )
211220 loss = x .sum ()
212221 return loss
213222
214- def configure_optimizers (self ):
223+ def configure_optimizers (self ) -> torch . optim . Optimizer :
215224 return torch .optim .Adam (self .parameters (), lr = self .learning_rate )
216225
217226
218227class Net (nn .Module ):
219- def __init__ (self ):
228+ def __init__ (self ) -> None :
220229 super ().__init__ ()
221230 self .conv1 = nn .Conv2d (1 , 32 , 3 , 1 )
222231 self .conv2 = nn .Conv2d (32 , 64 , 3 , 1 )
@@ -225,7 +234,7 @@ def __init__(self):
225234 self .fc1 = nn .Linear (9216 , 128 )
226235 self .fc2 = nn .Linear (128 , 10 )
227236
228- def forward (self , x ) :
237+ def forward (self , x : Tensor ) -> Tensor :
229238 x = self .conv1 (x )
230239 x = F .relu (x )
231240 x = self .conv2 (x )
0 commit comments