
Commit 969c25e

Author: SeanNaren (committed)
Merge branch 'master' into fix/sharded_clip_val
2 parents: 3f1aeb6 + b9cf122

31 files changed: +485 −173 lines

.github/workflows/ci_test-conda.yml

Lines changed: 6 additions & 7 deletions
@@ -30,13 +30,6 @@ jobs:
         pip install --requirement requirements/devel.txt --upgrade-strategy only-if-needed
         pip list

-    - name: Cache datasets
-      # todo this probably does not work with docker images, rather cache dockers
-      uses: actions/cache@v2
-      with:
-        path: Datasets
-        key: pl-dataset
-
     - name: Pull checkpoints from S3
       # todo: consider adding coma caching, but ATM all models have less than 100KB
       run: |
@@ -46,6 +39,12 @@ jobs:
         unzip -o checkpoints.zip
         ls -l checkpoints/

+    # todo: require proper fix in docker image
+    - name: Hotfix dependency
+      run: |
+        pip install torchtext==0.6.0 -U
+      shell: bash
+
     - name: Tests
       run: |
         # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003

.github/workflows/ci_test-full.yml

Lines changed: 12 additions & 1 deletion
@@ -112,6 +112,12 @@ jobs:
         pip list
       shell: bash

+    # todo: require proper fix in docker image
+    - name: Hotfix dependency
+      run: |
+        pip install torchtext==0.6.0 -U
+      shell: bash
+
     - name: Reinstall Horovod if necessary
       if: runner.os != 'windows'
       env:
@@ -135,7 +141,12 @@ jobs:
     - name: Tests
       run: |
         # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
-        coverage run --source pytorch_lightning -m pytest pytorch_lightning tests pl_examples -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
+        coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
+
+    # todo: put this back just when TorchVision can download datasets
+    #- name: Examples
+    #  run: |
+    #    python -m pytest pl_examples -v --durations=10

     - name: Upload pytest test results
       uses: actions/upload-artifact@v2

azure-pipelines.yml

Lines changed: 15 additions & 8 deletions
@@ -71,6 +71,11 @@ jobs:
       python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'"
     displayName: 'Env details'

+  # todo: require proper fix in docker image
+  - bash: |
+      pip install torchtext==0.7 -U
+    displayName: 'HotFix'
+
   - bash: |
       wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/
       unzip -o legacy/checkpoints.zip -d legacy/
@@ -92,11 +97,13 @@ jobs:
     displayName: 'Statistics'

   - bash: |
-      python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0
-    displayName: 'Testing: extended'
-
-  - bash: |
-      python setup.py install --user --quiet
-      bash pl_examples/run_ddp-example.sh
-      pip uninstall -y pytorch-lightning
-    displayName: 'Examples'
+      python -m pytest benchmarks -v --maxfail=2 --durations=0
+    displayName: 'Testing: benchmarks'
+
+  # todo: put this back just when TorchVision can download datasets
+  #- bash: |
+  #    python -m pytest pl_examples -v --maxfail=2 --durations=0
+  #    python setup.py install --user --quiet
+  #    bash pl_examples/run_ddp-example.sh
+  #    pip uninstall -y pytorch-lightning
+  #  displayName: 'Examples'

docs/source/advanced/multiple_loaders.rst

Lines changed: 23 additions & 0 deletions
@@ -16,6 +16,8 @@ Lightning supports multiple dataloaders in a few ways.

 ----------

+.. _multiple-training-dataloaders:
+
 Multiple training dataloaders
 -----------------------------
 For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
@@ -86,6 +88,27 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer

         return loaders

+Furthermore, Lightning also supports returning nested lists and dicts (or a combination).
+
+.. testcode::
+
+    class LitModel(LightningModule):
+
+        def train_dataloader(self):
+
+            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
+            loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
+            loader_c = torch.utils.data.DataLoader(range(32), batch_size=4)
+            loader_d = torch.utils.data.DataLoader(range(64), batch_size=4)
+
+            # pass loaders as a nested dict. This will create batches like this:
+            # {'loaders_a_b': {'a': batch from loader_a, 'b': batch from loader_b},
+            #  'loaders_c_d': {'c': batch from loader_c, 'd': batch from loader_d}}
+            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b},
+                       'loaders_c_d': {'c': loader_c, 'd': loader_d}}
+            return loaders
+
 ----------

 Test/Val dataloaders
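
For context on the docs addition above: with a nested dict of loaders like the one in the example, each batch handed to ``training_step`` mirrors the same nesting. Below is a minimal sketch of unpacking such a batch; it assumes the key names from the example, and the loss arithmetic is purely illustrative rather than part of the commit.

    import torch
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            # batch mirrors the dict returned by train_dataloader:
            # {'loaders_a_b': {'a': ..., 'b': ...}, 'loaders_c_d': {'c': ..., 'd': ...}}
            batch_a = batch['loaders_a_b']['a']
            batch_b = batch['loaders_a_b']['b']
            batch_c = batch['loaders_c_d']['c']
            batch_d = batch['loaders_c_d']['d']

            # illustrative only: reduce each sub-batch to a scalar and sum
            return (batch_a.float().mean() + batch_b.float().mean()
                    + batch_c.float().mean() + batch_d.float().mean())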

docs/source/governance.rst

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ Core Maintainers
 - Nicki Skafte (`skaftenicki <https://github.com/SkafteNicki>`_)
 - Peter Yu (`yukw777 <https://github.com/yukw777>`_)
 - Rohit Gupta (`rohitgr7 <https://github.com/rohitgr7>`_)
-- Lezwon Castelino (`lezwon <https://github.com/lezwon>`_)
 - Jeff Yang (`ydcjeff <https://github.com/ydcjeff>`_)
 - Roger Shieh (`s-rog <https://github.com/s-rog>`_)
 - Carlos Mocholí (`carmocca <https://github.com/carmocca>`_)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)

     def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
-        return checkpoint
+        return self.training_type_plugin.on_save(checkpoint)

     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
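
For context on the one-line change above: routing the checkpoint dict through ``training_type_plugin.on_save`` lets a plugin rewrite the checkpoint before it is written, which is what sharded-training plugins need. A minimal sketch of such a hook follows; it assumes only that the plugin's ``on_save`` receives and returns the checkpoint dict (as the delegation implies), and the mixin name and CPU-offload logic are illustrative, not part of this commit.

    import torch


    def _to_cpu(obj):
        # recursively move tensors to CPU; leave all other values untouched
        if isinstance(obj, torch.Tensor):
            return obj.cpu()
        if isinstance(obj, dict):
            return {k: _to_cpu(v) for k, v in obj.items()}
        return obj


    class CPUOffloadOnSaveMixin:
        """Illustrative plugin mixin: offload checkpoint tensors to CPU before writing."""

        def on_save(self, checkpoint):
            # invoked through Accelerator.on_save(), which now delegates to the plugin
            return _to_cpu(checkpoint)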

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 7 deletions
@@ -239,7 +239,7 @@ def save_checkpoint(self, trainer, pl_module):
             self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
+        self._save_last_checkpoint(trainer, monitor_candidates)

     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -291,8 +291,7 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")

-    def _save_model(self, filepath: str, trainer, pl_module):
-        # Todo: required argument `pl_module` is not used
+    def _save_model(self, filepath: str, trainer):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

@@ -481,7 +480,7 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates

-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+    def _save_last_checkpoint(self, trainer, ckpt_name_metrics):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
@@ -505,9 +504,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):

         if trainer.training_type_plugin.rpc_enabled:
             # RPCPlugin manages saving all model states
-            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
+            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer)
         else:
-            self._save_model(last_filepath, trainer, pl_module)
+            self._save_model(last_filepath, trainer)
         if (
             self.last_model_path and self.last_model_path != last_filepath
             and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero
@@ -574,7 +573,7 @@ def _update_best_and_save(
                 f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
                 f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
-            self._save_model(filepath, trainer, pl_module)
+            self._save_model(filepath, trainer)

             if del_filepath is not None and filepath != del_filepath:
                 self._del_model(del_filepath)

pytorch_lightning/core/hooks.py

Lines changed: 32 additions & 3 deletions
@@ -383,12 +383,14 @@ def prepare_data(self):
             model.test_dataloader()
         """

-    def train_dataloader(self) -> DataLoader:
+    def train_dataloader(self) -> Any:
         """
-        Implement a PyTorch DataLoader for training.
+        Implement one or more PyTorch DataLoaders for training.

         Return:
-            Single PyTorch :class:`~torch.utils.data.DataLoader`.
+            Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
+            (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
+            this :ref:`page <multiple-training-dataloaders>`

         The dataloader you return will not be called every epoch unless you set
         :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -414,6 +416,7 @@ def train_dataloader(self) -> DataLoader:

         Example::

+            # single dataloader
             def train_dataloader(self):
                 transform = transforms.Compose([transforms.ToTensor(),
                                                 transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +429,32 @@ def train_dataloader(self):
                 )
                 return loader

+            # multiple dataloaders, return as list
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
+                # each batch will be a list of tensors: [batch_mnist, batch_cifar]
+                return [mnist_loader, cifar_loader]
+
+            # multiple dataloaders, return as dict
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
+                # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
+                return {'mnist': mnist_loader, 'cifar': cifar_loader}
+
         """
         rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")
pytorch_lightning/core/lightning.py

Lines changed: 72 additions & 42 deletions
@@ -19,12 +19,13 @@
 import logging
 import os
 import tempfile
+import types
 import uuid
 from abc import ABC
 from argparse import Namespace
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

 import torch
 from torch import ScriptModule, Tensor
@@ -1591,55 +1592,84 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
             parents_arguments.update(args)
         return self_arguments, parents_arguments

-    def save_hyperparameters(self, *args, frame=None) -> None:
-        """Save all model arguments.
+    def save_hyperparameters(
+        self,
+        *args,
+        ignore: Optional[Union[Sequence[str], str]] = None,
+        frame: Optional[types.FrameType] = None
+    ) -> None:
+        """Save model arguments to ``hparams`` attribute.

         Args:
             args: single object of `dict`, `NameSpace` or `OmegaConf`
-                or string names or arguments from class `__init__`
-
-            >>> class ManuallyArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # manually assign arguments
-            ...         self.save_hyperparameters('arg1', 'arg3')
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg3": 3.14
-
-            >>> class AutomaticArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # equivalent automatic
-            ...         self.save_hyperparameters()
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg2": abc
-            "arg3": 3.14
-
-            >>> class SingleArgModel(LightningModule):
-            ...     def __init__(self, params):
-            ...         super().__init__()
-            ...         # manually assign single argument
-            ...         self.save_hyperparameters(params)
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
-            >>> model.hparams
-            "p1": 1
-            "p2": abc
-            "p3": 3.14
+                or string names or arguments from class ``__init__``
+            ignore: an argument name or a list of argument names from
+                class ``__init__`` to be ignored
+            frame: a frame object. Default is None
+
+        Example::
+            >>> class ManuallyArgsModel(LightningModule):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # manually assign arguments
+            ...         self.save_hyperparameters('arg1', 'arg3')
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg3": 3.14
+
+            >>> class AutomaticArgsModel(LightningModule):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # equivalent automatic
+            ...         self.save_hyperparameters()
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg2": abc
+            "arg3": 3.14
+
+            >>> class SingleArgModel(LightningModule):
+            ...     def __init__(self, params):
+            ...         super().__init__()
+            ...         # manually assign single argument
+            ...         self.save_hyperparameters(params)
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
+            >>> model.hparams
+            "p1": 1
+            "p2": abc
+            "p3": 3.14
+
+            >>> class ManuallyArgsModel(LightningModule):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # pass argument(s) to ignore as a string or in a list
+            ...         self.save_hyperparameters(ignore='arg2')
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg3": 3.14
         """
         if not frame:
             frame = inspect.currentframe().f_back
         init_args = get_init_args(frame)
         assert init_args, "failed to inspect the self init"
+
+        if ignore is not None:
+            if isinstance(ignore, str):
+                ignore = [ignore]
+            if isinstance(ignore, (list, tuple)):
+                ignore = [arg for arg in ignore if isinstance(arg, str)]
+            init_args = {k: v for k, v in init_args.items() if k not in ignore}
+
         if not args:
             # take all arguments
             hp = init_args