1- # Copyright The PyTorch Lightning team.
2- #
3- # Licensed under the Apache License, Version 2.0 (the "License");
4- # you may not use this file except in compliance with the License.
5- # You may obtain a copy of the License at
6- #
7- # http://www.apache.org/licenses/LICENSE-2.0
8- #
9- # Unless required by applicable law or agreed to in writing, software
10- # distributed under the License is distributed on an "AS IS" BASIS,
11- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12- # See the License for the specific language governing permissions and
13- # limitations under the License.
14-
15- # --------------------------------------------
16- # --------------------------------------------
17- # --------------------------------------------
18- # USE THIS MODEL TO REPRODUCE A BUG YOU REPORT
19- # --------------------------------------------
20- # --------------------------------------------
21- # --------------------------------------------
221import os
232
243import torch
254from torch .utils .data import Dataset
265
27- from pl_examples import cli_lightning_logo
28- from pytorch_lightning import LightningModule , Trainer
6+ from pytorch_lightning import LightningModule , Trainer , seed_everything
297
8+ import numpy as np
9+ from torch .utils .data import Dataset
3010
31- class RandomDataset (Dataset ):
32- """
33- >>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
34- <...bug_report_model.RandomDataset object at ...>
35- """
36-
37- def __init__ (self , size , length ):
38- self .len = length
39- self .data = torch .randn (length , size )
4011
12+ class RandomDataset (Dataset ):
4113 def __getitem__ (self , index ):
42- return self . data [ index ]
14+ return np . random . randint ( 0 , 10 , 3 )
4315
4416 def __len__ (self ):
45- return self . len
17+ return 16
4618
4719
4820class BoringModel (LightningModule ):
49- """
50- >>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
51- BoringModel(
52- (layer): Linear(...)
53- )
54- """
5521
5622 def __init__ (self ):
57- """
58- Testing PL Module
59-
60- Use as follows:
61- - subclass
62- - modify the behavior for what you want
63-
64- class TestModel(BaseTestModel):
65- def training_step(...):
66- # do your own thing
67-
68- or:
69-
70- model = BaseTestModel()
71- model.training_epoch_end = None
72-
73- """
7423 super ().__init__ ()
7524 self .layer = torch .nn .Linear (32 , 2 )
7625
@@ -87,71 +36,32 @@ def step(self, x):
8736 return out
8837
8938 def training_step (self , batch , batch_idx ):
90- output = self .layer (batch )
91- loss = self .loss (batch , output )
92- return {"loss" : loss }
93-
94- def training_step_end (self , training_step_outputs ):
95- return training_step_outputs
96-
97- def training_epoch_end (self , outputs ) -> None :
98- torch .stack ([x ["loss" ] for x in outputs ]).mean ()
99-
100- def validation_step (self , batch , batch_idx ):
101- output = self .layer (batch )
102- loss = self .loss (batch , output )
103- return {"x" : loss }
104-
105- def validation_epoch_end (self , outputs ) -> None :
106- torch .stack ([x ['x' ] for x in outputs ]).mean ()
107-
108- def test_step (self , batch , batch_idx ):
109- output = self .layer (batch )
110- loss = self .loss (batch , output )
111- return {"y" : loss }
112-
113- def test_epoch_end (self , outputs ) -> None :
114- torch .stack ([x ["y" ] for x in outputs ]).mean ()
39+ print (batch )
40+ return None
11541
11642 def configure_optimizers (self ):
11743 optimizer = torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
11844 lr_scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = 1 )
11945 return [optimizer ], [lr_scheduler ]
12046
12147
122- # NOTE: If you are using a cmd line to run your script,
123- # provide the cmd line as below.
124- # opt = "--max_epochs 1 --limit_train_batches 1".split(" ")
125- # parser = ArgumentParser()
126- # args = parser.parse_args(opt)
127-
128-
129- class TestModel (BoringModel ):
130-
131- def on_train_epoch_start (self ) -> None :
132- print ('override any method to prove your bug' )
133-
134-
135- def test_run ():
48+ def run ():
13649
13750 # fake data
138- train_data = torch .utils .data .DataLoader (RandomDataset (32 , 64 ))
139- val_data = torch .utils .data .DataLoader (RandomDataset (32 , 64 ))
140- test_data = torch .utils .data .DataLoader (RandomDataset (32 , 64 ))
51+ train_data = torch .utils .data .DataLoader (RandomDataset (), batch_size = 2 , num_workers = 4 )
14152
14253 # model
143- model = TestModel ()
54+ model = BoringModel ()
14455 trainer = Trainer (
14556 default_root_dir = os .getcwd (),
146- limit_train_batches = 1 ,
147- limit_val_batches = 1 ,
57+ limit_train_batches = 4 ,
14858 max_epochs = 1 ,
14959 weights_summary = None ,
60+ progress_bar_refresh_rate = 0 ,
15061 )
151- trainer .fit (model , train_data , val_data )
152- trainer .test (test_dataloaders = test_data )
62+ trainer .fit (model , train_data )
15363
15464
15565if __name__ == '__main__' :
156- cli_lightning_logo ( )
157- test_run ()
66+ seed_everything ( 1 )
67+ run ()
0 commit comments