
Commit aaacedb

Merge 60a479e into c7f30a2
2 parents c7f30a2 + 60a479e commit aaacedb

16 files changed: +208 −166 lines changed


CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


+- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+
 ### Deprecated


@@ -107,6 +110,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


+- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+
+- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+
 - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
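To make the changelog entries concrete, here is a minimal, hypothetical sketch (not part of the commit) of a LightningModule using the widened stage values; the class name and layer sizes are made up for illustration:

    from typing import Optional

    import torch.nn as nn
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        # hypothetical module, shown only to illustrate the new stage values
        def setup(self, stage: Optional[str] = None) -> None:
            # stage is now any of 'fit', 'validate', 'test', 'predict' (or None)
            if stage == 'fit' or stage is None:
                self.head = nn.Linear(128, 10)

        def teardown(self, stage: Optional[str] = None) -> None:
            # receives the same stage value; per the fix above, teardown(stage='fit')
            # is no longer triggered by trainer.test()
            pass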

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -371,6 +371,7 @@ def package_list_from_file(file):
 doctest_global_setup = """
 import importlib
 import os
+from typing import Optional
 import torch
 from torch import nn
 import pytorch_lightning as pl

docs/source/extensions/datamodules.rst

Lines changed: 6 additions & 6 deletions
@@ -80,7 +80,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusable
         self.data_dir = data_dir
         self.batch_size = batch_size

-    def setup(self, stage=None):
+    def setup(self, stage: Optional[str] = None):
         self.mnist_test = MNIST(self.data_dir, train=False)
         mnist_full = MNIST(self.data_dir, train=True)
         self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
@@ -138,7 +138,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable the
         MNIST(self.data_dir, train=True, download=True)
         MNIST(self.data_dir, train=False, download=True)

-    def setup(self, stage=None):
+    def setup(self, stage: Optional[str] = None):

         # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
@@ -382,12 +382,12 @@ still ensures the method runs on the correct devices)

     dm = MNISTDataModule()
     dm.prepare_data()
-    dm.setup('fit')
+    dm.setup(stage='fit')

     model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
     trainer.fit(model, dm)

-    dm.setup('test')
+    dm.setup(stage='test')
     trainer.test(datamodule=dm)

 ----------------
@@ -403,7 +403,7 @@ You can of course use DataModules in plain PyTorch code as well.
     dm.prepare_data()

     # splits/transforms
-    dm.setup('fit')
+    dm.setup(stage='fit')

     # use data
     for batch in dm.train_dataloader():
@@ -412,7 +412,7 @@ You can of course use DataModules in plain PyTorch code as well.
         ...

     # lazy load test data
-    dm.setup('test')
+    dm.setup(stage='test')
     for batch in dm.test_dataloader():
         ...
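As a usage aside (not taken from the diff), the keyword form also reads naturally when a DataModule is driven by hand; this assumes the `MNISTDataModule` defined earlier on that docs page, whose `setup()` assigns the validation split when `stage` is `'fit'` or `None`:

    dm = MNISTDataModule()
    dm.prepare_data()

    dm.setup(stage=None)   # no stage: sets up 'fit', 'validate', and 'test' in one call
    for batch in dm.val_dataloader():
        ...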

docs/source/starter/introduction_guide.rst

Lines changed: 4 additions & 4 deletions
@@ -240,7 +240,7 @@ In this case, it's better to group the full definition of a dataset into a `DataModule`
         tokenize()
         build_vocab()

-    def setup(self):
+    def setup(self, stage: Optional[str] = None):
         # called on every GPU
         vocab = load_vocab()
         self.vocab_size = len(vocab)
@@ -310,8 +310,8 @@ An alternative to using a DataModule is to defer initialization of the models mo
         download_data()
         tokenize()

-    def setup(self, step):
-        # step is either 'fit' or 'test' 90% of the time not relevant
+    def setup(self, stage: Optional[str] = None):
+        # step is either 'fit', 'validate', 'test', or 'predict'. 90% of the time not relevant
         data = load_data()
         num_classes = data.classes
         self.l1 = nn.Linear(..., num_classes)
@@ -598,7 +598,7 @@ In this method we do all the preparation we need to do once (instead of on every
         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

-    def setup(self, stage):
+    def setup(self, stage: Optional[str] = None):
         # transform
         transform=transforms.Compose([transforms.ToTensor()])
         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)

docs/source/starter/new-project.rst

Lines changed: 1 addition & 1 deletion
@@ -651,7 +651,7 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
         MNIST(os.getcwd(), train=False, download=True)

     # OPTIONAL, called for every GPU/machine (assigning state is OK)
-    def setup(self, stage):
+    def setup(self, stage: Optional[str] = None):
         # transforms
         transform=transforms.Compose([
             transforms.ToTensor(),

pytorch_lightning/callbacks/base.py

Lines changed: 5 additions & 5 deletions
@@ -17,7 +17,7 @@
 """

 import abc
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 from pytorch_lightning.core.lightning import LightningModule

@@ -33,12 +33,12 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule
         """Called before accelerator is being setup"""
         pass

-    def setup(self, trainer, pl_module: LightningModule, stage: str) -> None:
-        """Called when fit or test begins"""
+    def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+        """Called when fit, validate, test, predict, or tune begins"""
         pass

-    def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None:
-        """Called when fit or test ends"""
+    def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+        """Called when fit, validate, test, predict, or tune ends"""
         pass

     def on_init_start(self, trainer) -> None:
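A minimal sketch of a user callback written against the new signature (hypothetical, not part of the commit); giving `stage` a default of `None` keeps it valid for every entry point:

    from typing import Optional

    from pytorch_lightning.callbacks import Callback

    class StageLogger(Callback):
        # hypothetical callback that just reports which stage is starting and ending
        def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None:
            print(f"setup: stage={stage}")

        def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
            print(f"teardown: stage={stage}")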

pytorch_lightning/core/datamodule.py

Lines changed: 39 additions & 19 deletions
@@ -55,10 +55,10 @@ def __call__(cls, *args, **kwargs):
 def track_data_hook_calls(fn):
     """A decorator that checks if prepare_data/setup have been called.

-    - When dm.prepare_data() is called, dm.has_prepared_data gets set to True
-    - When dm.setup('fit') is called, dm.has_setup_fit gets set to True
-    - When dm.setup('test') is called, dm.has_setup_test gets set to True
-    - When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True
+    - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
+    - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
+    - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
+      Its corresponding `dm_has_setup_{stage}` attribute gets set to True

     Args:
         fn (function): Function that will be tracked to see if it has been called.
@@ -77,15 +77,15 @@ def wrapped_fn(*args, **kwargs):
         if fn.__name__ == "setup":

             # Get stage either by grabbing from args or checking kwargs.
-            # If not provided, set call status of 'fit' and 'test' to True.
+            # If not provided, set call status of 'fit', 'validate', and 'test' to True.
             # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
             stage = args[1] if len(args) > 1 else kwargs.get("stage", None)

-            if stage == "fit" or stage is None:
-                obj._has_setup_fit = True
-
-            if stage == "test" or stage is None:
-                obj._has_setup_test = True
+            if stage is None:
+                for s in ("fit", "validate", "test"):
+                    setattr(obj, f"_has_setup_{s}", True)
+            else:
+                setattr(obj, f"_has_setup_{stage}", True)

         if fn.__name__ == "prepare_data":
             obj._has_prepared_data = True
@@ -156,7 +156,9 @@ def __init__(
         # Private attrs to keep track of whether or not data hooks have been called yet
         self._has_prepared_data = False
         self._has_setup_fit = False
+        self._has_setup_validate = False
         self._has_setup_test = False
+        self._has_setup_predict = False

     @property
     def train_transforms(self):
@@ -214,32 +216,50 @@ def size(self, dim=None) -> Union[Tuple, int]:
         return self.dims

     @property
-    def has_prepared_data(self):
-        """Return bool letting you know if datamodule.prepare_data() has been called or not.
+    def has_prepared_data(self) -> bool:
+        """Return bool letting you know if ``datamodule.prepare_data()`` has been called or not.

         Returns:
-            bool: True if datamodule.prepare_data() has been called. False by default.
+            bool: True if ``datamodule.prepare_data()`` has been called. False by default.
         """
         return self._has_prepared_data

     @property
-    def has_setup_fit(self):
-        """Return bool letting you know if datamodule.setup('fit') has been called or not.
+    def has_setup_fit(self) -> bool:
+        """Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not.

         Returns:
-            bool: True if datamodule.setup('fit') has been called. False by default.
+            bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default.
         """
         return self._has_setup_fit

     @property
-    def has_setup_test(self):
-        """Return bool letting you know if datamodule.setup('test') has been called or not.
+    def has_setup_validate(self) -> bool:
+        """Return bool letting you know if ``datamodule.setup(stage='validate')`` has been called or not.
+
+        Returns:
+            bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default.
+        """
+        return self._has_setup_validate
+
+    @property
+    def has_setup_test(self) -> bool:
+        """Return bool letting you know if ``datamodule.setup(stage='test')`` has been called or not.

         Returns:
-            bool: True if datamodule.setup('test') has been called. False by default.
+            bool: True if ``datamodule.setup(stage='test')`` has been called. False by default.
         """
         return self._has_setup_test

+    @property
+    def has_setup_predict(self) -> bool:
+        """Return bool letting you know if ``datamodule.setup(stage='predict')`` has been called or not.
+
+        Returns:
+            bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default.
+        """
+        return self._has_setup_predict
+
     @abstractmethod
     def prepare_data(self, *args, **kwargs):
         pass
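A quick, hedged illustration of the per-stage tracking added here; the trivial subclass below is hypothetical, and the assertions mirror the decorator logic shown above (a bare `setup()` marks fit, validate, and test, while an explicit stage marks only its own flag):

    import pytorch_lightning as pl

    class EmptyDataModule(pl.LightningDataModule):
        # trivial, hypothetical subclass used only to exercise the flags
        def prepare_data(self):
            pass

        def setup(self, stage=None):
            pass

        def train_dataloader(self):
            pass

        def val_dataloader(self):
            pass

        def test_dataloader(self):
            pass

    dm = EmptyDataModule()
    assert not dm.has_setup_validate

    dm.setup(stage='validate')   # only the matching flag flips
    assert dm.has_setup_validate and not dm.has_setup_test

    dm.setup()                   # no stage: 'fit', 'validate', and 'test' are all marked
    assert dm.has_setup_fit and dm.has_setup_test and not dm.has_setup_predict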

pytorch_lightning/core/hooks.py

Lines changed: 6 additions & 6 deletions
@@ -25,14 +25,14 @@
 class ModelHooks:
     """Hooks to be used in LightningModule."""

-    def setup(self, stage: str) -> None:
+    def setup(self, stage: Optional[str] = None) -> None:
         """
-        Called at the beginning of fit and test.
+        Called at the beginning of fit (train + validate), validate, test, predict, or tune.
         This is a good hook when you need to build models dynamically or adjust something about them.
         This hook is called on every process when using DDP.

         Args:
-            stage: either 'fit' or 'test'
+            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

         Example::

@@ -53,12 +53,12 @@ def setup(stage):

         """

-    def teardown(self, stage: str) -> None:
+    def teardown(self, stage: Optional[str] = None) -> None:
         """
-        Called at the end of fit and test.
+        Called at the end of fit (train + validate), validate, test, predict, or tune.

         Args:
-            stage: either 'fit' or 'test'
+            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
         """

     def on_fit_start(self) -> None:
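For illustration only (not from the commit), the per-stage pairing the docstrings describe lets stage-specific resources be released safely; the file name and hook bodies below are made up:

    from typing import Optional

    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        # hypothetical model; setup/teardown are paired per stage
        def setup(self, stage: Optional[str] = None) -> None:
            if stage == 'test':
                self.predictions_file = open('test_predictions.txt', 'w')

        def teardown(self, stage: Optional[str] = None) -> None:
            # with the fix in this commit, a 'test' run no longer triggers
            # teardown(stage='fit'), so cleanup can key off the stage value
            if stage == 'test':
                self.predictions_file.close()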

pytorch_lightning/trainer/callback_hook.py

Lines changed: 13 additions & 13 deletions
@@ -15,7 +15,7 @@
 from abc import ABC
 from copy import deepcopy
 from inspect import signature
-from typing import Any, Callable, Dict, List, Type
+from typing import Any, Callable, Dict, List, Type, Optional

 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
@@ -29,18 +29,18 @@ class TrainerCallbackHookMixin(ABC):
     callbacks: List[Callback] = []
     lightning_module: LightningModule

-    def on_before_accelerator_backend_setup(self, model):
-        """Called in the beginning of fit and test"""
+    def on_before_accelerator_backend_setup(self, model: LightningModule) -> None:
+        """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
             callback.on_before_accelerator_backend_setup(self, model)

-    def setup(self, model, stage: str):
-        """Called in the beginning of fit and test"""
+    def setup(self, model: LightningModule, stage: Optional[str]) -> None:
+        """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
             callback.setup(self, model, stage)

-    def teardown(self, stage: str):
-        """Called at the end of fit and test"""
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """Called at the end of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
             callback.teardown(self, self.lightning_module, stage)

@@ -124,15 +124,15 @@ def on_train_end(self):
         for callback in self.callbacks:
             callback.on_train_end(self, self.lightning_module)

-    def on_pretrain_routine_start(self, model):
-        """Called when the train begins."""
+    def on_pretrain_routine_start(self) -> None:
+        """Called when the pre-train routine begins."""
         for callback in self.callbacks:
-            callback.on_pretrain_routine_start(self, model)
+            callback.on_pretrain_routine_start(self, self.lightning_module)

-    def on_pretrain_routine_end(self, model):
-        """Called when the train ends."""
+    def on_pretrain_routine_end(self) -> None:
+        """Called when the pre-train routine ends."""
         for callback in self.callbacks:
-            callback.on_pretrain_routine_end(self, model)
+            callback.on_pretrain_routine_end(self, self.lightning_module)

     def on_batch_start(self):
         """Called when the training batch begins."""

pytorch_lightning/trainer/model_hooks.py

Lines changed: 4 additions & 2 deletions
@@ -14,6 +14,7 @@

 import inspect
 from abc import ABC
+from typing import Optional

 from pytorch_lightning.core.lightning import LightningModule

@@ -22,13 +23,14 @@ class TrainerModelHooksMixin(ABC):

     lightning_module: LightningModule

-    def is_function_implemented(self, f_name, model=None):
+    def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool:
+        # note: currently unused - kept as it is public
         if model is None:
             model = self.lightning_module
         f_op = getattr(model, f_name, None)
         return callable(f_op)

-    def has_arg(self, f_name, arg_name):
+    def has_arg(self, f_name: str, arg_name: str) -> bool:
         model = self.lightning_module
         f_op = getattr(model, f_name, None)
         return arg_name in inspect.signature(f_op).parameters
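For context (a hedged aside, not in the commit), `has_arg` simply inspects the user override's signature; the standalone check below reproduces what it would report for a module whose `setup` accepts `stage` (the class here is a hypothetical stand-in, not a real LightningModule):

    import inspect
    from typing import Optional

    class MyModule:
        # hypothetical stand-in for a user's hook override
        def setup(self, stage: Optional[str] = None) -> None:
            pass

    # equivalent of trainer.has_arg('setup', 'stage') for this module
    has_stage = 'stage' in inspect.signature(MyModule().setup).parameters
    assert has_stage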
