Skip to content

Commit 18e61ee

Browse files
committed
example
1 parent fe0d088 commit 18e61ee

File tree

1 file changed

+16
-106
lines changed

1 file changed

+16
-106
lines changed

pl_examples/bug_report_model.py

Lines changed: 16 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,25 @@
1-
# Copyright The PyTorch Lightning team.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
15-
# --------------------------------------------
16-
# --------------------------------------------
17-
# --------------------------------------------
18-
# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT
19-
# --------------------------------------------
20-
# --------------------------------------------
21-
# --------------------------------------------
221
import os
232

243
import torch
254
from torch.utils.data import Dataset
265

27-
from pl_examples import cli_lightning_logo
28-
from pytorch_lightning import LightningModule, Trainer
6+
from pytorch_lightning import LightningModule, Trainer, seed_everything
297

8+
import numpy as np
9+
from torch.utils.data import Dataset
3010

31-
class RandomDataset(Dataset):
32-
"""
33-
>>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
34-
<...bug_report_model.RandomDataset object at ...>
35-
"""
36-
37-
def __init__(self, size, length):
38-
self.len = length
39-
self.data = torch.randn(length, size)
4011

12+
class RandomDataset(Dataset):
4113
def __getitem__(self, index):
42-
return self.data[index]
14+
return np.random.randint(0, 10, 3)
4315

4416
def __len__(self):
45-
return self.len
17+
return 16
4618

4719

4820
class BoringModel(LightningModule):
49-
"""
50-
>>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
51-
BoringModel(
52-
(layer): Linear(...)
53-
)
54-
"""
5521

5622
def __init__(self):
57-
"""
58-
Testing PL Module
59-
60-
Use as follows:
61-
- subclass
62-
- modify the behavior for what you want
63-
64-
class TestModel(BaseTestModel):
65-
def training_step(...):
66-
# do your own thing
67-
68-
or:
69-
70-
model = BaseTestModel()
71-
model.training_epoch_end = None
72-
73-
"""
7423
super().__init__()
7524
self.layer = torch.nn.Linear(32, 2)
7625

@@ -87,71 +36,32 @@ def step(self, x):
8736
return out
8837

8938
def training_step(self, batch, batch_idx):
90-
output = self.layer(batch)
91-
loss = self.loss(batch, output)
92-
return {"loss": loss}
93-
94-
def training_step_end(self, training_step_outputs):
95-
return training_step_outputs
96-
97-
def training_epoch_end(self, outputs) -> None:
98-
torch.stack([x["loss"] for x in outputs]).mean()
99-
100-
def validation_step(self, batch, batch_idx):
101-
output = self.layer(batch)
102-
loss = self.loss(batch, output)
103-
return {"x": loss}
104-
105-
def validation_epoch_end(self, outputs) -> None:
106-
torch.stack([x['x'] for x in outputs]).mean()
107-
108-
def test_step(self, batch, batch_idx):
109-
output = self.layer(batch)
110-
loss = self.loss(batch, output)
111-
return {"y": loss}
112-
113-
def test_epoch_end(self, outputs) -> None:
114-
torch.stack([x["y"] for x in outputs]).mean()
39+
print(batch)
40+
return None
11541

11642
def configure_optimizers(self):
11743
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
11844
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
11945
return [optimizer], [lr_scheduler]
12046

12147

122-
# NOTE: If you are using a cmd line to run your script,
123-
# provide the cmd line as below.
124-
# opt = "--max_epochs 1 --limit_train_batches 1".split(" ")
125-
# parser = ArgumentParser()
126-
# args = parser.parse_args(opt)
127-
128-
129-
class TestModel(BoringModel):
130-
131-
def on_train_epoch_start(self) -> None:
132-
print('override any method to prove your bug')
133-
134-
135-
def test_run():
48+
def run():
13649

13750
# fake data
138-
train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
139-
val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
140-
test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
51+
train_data = torch.utils.data.DataLoader(RandomDataset(), batch_size=2, num_workers=4)
14152

14253
# model
143-
model = TestModel()
54+
model = BoringModel()
14455
trainer = Trainer(
14556
default_root_dir=os.getcwd(),
146-
limit_train_batches=1,
147-
limit_val_batches=1,
57+
limit_train_batches=4,
14858
max_epochs=1,
14959
weights_summary=None,
60+
progress_bar_refresh_rate=0,
15061
)
151-
trainer.fit(model, train_data, val_data)
152-
trainer.test(test_dataloaders=test_data)
62+
trainer.fit(model, train_data)
15363

15464

15565
if __name__ == '__main__':
156-
cli_lightning_logo()
157-
test_run()
66+
seed_everything(1)
67+
run()

0 commit comments

Comments
 (0)