
Commit 0894eae

Merge branch 'master' into fix/4237-auc-unstable-reorder
2 parents: ec6b0e3 + f6efb71

File tree: 6 files changed, +59 −18 lines

  CHANGELOG.md
  pytorch_lightning/core/saving.py
  pytorch_lightning/trainer/__init__.py
  pytorch_lightning/trainer/data_loading.py
  pytorch_lightning/utilities/cloud_io.py
  tests/trainer/test_dataloaders.py

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))

 - Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
+- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))

 ### Deprecated

@@ -124,7 +125,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed `current_epoch` property update to reflect true epoch number inside `LightningDataModule`, when `reload_dataloaders_every_epoch=True`. ([#3974](https://github.com/PyTorchLightning/pytorch-lightning/pull/3974))
-- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
+- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
 - Fixed mismatch between docstring and code regarding when `on_load_checkpoint` hook is called ([#3996](https://github.com/PyTorchLightning/pytorch-lightning/pull/3996))


@@ -469,7 +470,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))
 - Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
 - Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
-- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
+- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
 - Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
 - Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))

pytorch_lightning/core/saving.py

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@
 import inspect
 import os
 from argparse import Namespace
-from typing import Union, Dict, Any, Optional, Callable, MutableMapping
+from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
 from warnings import warn

 import fsspec
@@ -52,7 +52,7 @@ class ModelIO(object):
     @classmethod
     def load_from_checkpoint(
         cls,
-        checkpoint_path: str,
+        checkpoint_path: Union[str, IO],
         map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
         hparams_file: Optional[str] = None,
         strict: bool = True,
@@ -65,7 +65,7 @@ def load_from_checkpoint(
         Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.

         Args:
-            checkpoint_path: Path to checkpoint. This can also be a URL.
+            checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
             map_location:
                 If your checkpoint saved a GPU model and you now load on CPUs
                 or a different number of GPUs, use this to map to the new setup.
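For context, a minimal sketch of the new call pattern; `LitModel`, its layer sizes, and the `example.ckpt` path are placeholders, not part of this commit:

import io

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    """Placeholder model; exists only so load_from_checkpoint has a class to restore."""

    def __init__(self, hidden_dim: int = 16):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, hidden_dim)


# Read a previously saved checkpoint into memory and hand the buffer over directly:
# with this change, checkpoint_path may be a filesystem path, a URL, or a file-like object.
with open("example.ckpt", "rb") as f:  # hypothetical checkpoint path
    buffer = io.BytesIO(f.read())

model = LitModel.load_from_checkpoint(buffer)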

pytorch_lightning/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1278,6 +1278,8 @@ def world_size(self):
 Enables auto adding of distributed sampler. By default it will add ``shuffle=True``
 for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize
 it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
+If ``replace_sampler_ddp=True`` and a distributed sampler was already added,
+Lightning will not replace the existing one.

 .. testcode::
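To illustrate the documented behaviour, a minimal sketch assuming a two-process DDP job; the dataset, batch size, and rank arguments are placeholders, not taken from this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 32))

# Sampler attached by the user; num_replicas/rank are passed explicitly here only
# so the example builds without an initialized process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
train_loader = DataLoader(dataset, batch_size=8, sampler=sampler)

# With replace_sampler_ddp=True (the default), Lightning now detects the existing
# DistributedSampler on this loader and leaves it in place instead of injecting a second one.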

pytorch_lightning/trainer/data_loading.py

Lines changed: 9 additions & 7 deletions
@@ -103,7 +103,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                 f' (try {num_cpus} which is the number of cpus on this machine)'
                 ' in the `DataLoader` init to improve performance.')

-    def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
+    def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
@@ -112,8 +112,9 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

        if not is_dataloader or is_iterable_ds:
            return dataloader
-       need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

+       is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
+       need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
        if self.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
@@ -123,7 +124,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

            # replace with distributed sampler
-           sampler = self._get_distributed_sampler(dataloader, train)
+           sampler = self._get_distributed_sampler(dataloader, shuffle)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader
@@ -136,10 +137,11 @@ def replace_sampler(self, dataloader, sampler):
        }

        dl_args['sampler'] = sampler
+       dl_args['shuffle'] = False
        dataloader = type(dataloader)(**dl_args)
        return dataloader

-    def _get_distributed_sampler(self, dataloader, train):
+    def _get_distributed_sampler(self, dataloader, shuffle):
        if self.use_tpu:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.use_horovod:
@@ -154,7 +156,7 @@ def _get_distributed_sampler(self, dataloader, train):
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

-       kwargs['shuffle'] = train and not self.overfit_batches
+       kwargs['shuffle'] = shuffle and not self.overfit_batches
        sampler = DistributedSampler(dataloader.dataset, **kwargs)
        return sampler

@@ -179,7 +181,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
        self.num_training_batches = 0

        # automatically add samplers
-       self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
+       self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
        self._worker_check(self.train_dataloader, 'train dataloader')
@@ -267,7 +269,7 @@ def _reset_eval_dataloader(
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")

        # add samplers
-       dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]
+       dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]

        loader_num_batches = []
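The heart of the change is the extra isinstance check on `dataloader.sampler`. A standalone sketch of that decision (illustrative only, not the Trainer's actual code path), assuming a two-replica setup:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def needs_dist_sampler(dataloader: DataLoader, in_distributed_mode: bool) -> bool:
    # Mirrors the new condition: replace the sampler only when running distributed
    # AND the user has not already attached a DistributedSampler.
    return in_distributed_mode and not isinstance(dataloader.sampler, DistributedSampler)


dataset = TensorDataset(torch.arange(100))

default_loader = DataLoader(dataset, batch_size=10)  # plain SequentialSampler by default
custom_loader = DataLoader(
    dataset,
    batch_size=10,
    sampler=DistributedSampler(dataset, num_replicas=2, rank=0),
)

assert needs_dist_sampler(default_loader, in_distributed_mode=True)      # would be replaced
assert not needs_dist_sampler(custom_loader, in_distributed_mode=True)   # kept as-is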

pytorch_lightning/utilities/cloud_io.py

Lines changed: 5 additions & 2 deletions
@@ -14,14 +14,17 @@

 import io
 from distutils.version import LooseVersion
-from typing import Union
+from typing import Union, IO
 from pathlib import Path
 from urllib.parse import urlparse
 import torch
 import fsspec


-def load(path_or_url: str, map_location=None):
+def load(path_or_url: Union[str, IO, Path], map_location=None):
+    if not isinstance(path_or_url, (str, Path)):
+        # any sort of BytesIO or similiar
+        return torch.load(path_or_url, map_location=map_location)
     if path_or_url.startswith("http"):
         return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
     fs = get_filesystem(path_or_url)
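A small round-trip sketch of the updated `load` helper; the `pl_load` alias and the toy state dict are only for illustration:

import io

import torch
from pytorch_lightning.utilities.cloud_io import load as pl_load

# Serialize a tiny state dict into an in-memory buffer...
buffer = io.BytesIO()
torch.save({"weights": torch.ones(3)}, buffer)
buffer.seek(0)

# ...and read it back: file-like objects now go straight to torch.load
# instead of being treated as a path or URL.
state = pl_load(buffer, map_location="cpu")
assert torch.equal(state["weights"], torch.ones(3))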

tests/trainer/test_dataloaders.py

Lines changed: 37 additions & 4 deletions
@@ -686,17 +686,17 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
     class CustomDummyObj:
         sampler = None

-    result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
+    result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
     assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

     dataset = list(range(1000))
-    result = trainer.auto_add_sampler(CustomDataLoader(dataset), train=True)
+    result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
     assert isinstance(result, torch.utils.data.DataLoader)
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')

     # Shuffled DataLoader should also work
-    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True)
     assert isinstance(result, torch.utils.data.DataLoader)
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')
@@ -707,7 +707,7 @@ class CustomSampler(torch.utils.data.Sampler):
     # Should raise an error if existing sampler is being replaced
     with pytest.raises(MisconfigurationException, match='DistributedSampler'):
         trainer.auto_add_sampler(
-            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
+            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True)


 class DistribSamplerCallback(Callback):
@@ -746,6 +746,39 @@ def test_dataloader_distributed_sampler(tmpdir):
     trainer.test(ckpt_path=None)


+class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
+
+    def train_dataloader(self):
+        dataloader = super().train_dataloader()
+        dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True)
+        return DataLoader(
+            dataloader.dataset,
+            batch_size=self.batch_size,
+            drop_last=False,
+            sampler=dist_sampler,
+            shuffle=False
+        )
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_distributed_sampler_already_attached(tmpdir):
+    """ Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader """
+
+    model = ModelWithDataLoaderDistributedSampler()
+    trainer = Trainer(
+        gpus=[0, 1],
+        num_nodes=1,
+        distributed_backend='ddp_spawn',
+        default_root_dir=tmpdir,
+        max_steps=100,
+        callbacks=[DistribSamplerCallback()],
+        replace_sampler_ddp=True,
+    )
+    result = trainer.fit(model)
+    assert result == 1, "DDP Training failed"
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus(tmpdir):
     # we need at least 3 gpus for this test
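`DistribSamplerCallback` is defined earlier in this test module and is not part of the diff; purely as an illustration, a hypothetical callback of that kind might check the attached sampler like this (assuming `trainer.train_dataloader` is already populated when training starts):

from pytorch_lightning import Callback
from torch.utils.data.distributed import DistributedSampler


class SamplerCheckCallback(Callback):
    """Illustrative only: confirm the user-attached DistributedSampler survived."""

    def on_train_start(self, trainer, pl_module):
        sampler = trainer.train_dataloader.sampler
        assert isinstance(sampler, DistributedSampler)
        assert sampler.shuffle, "shuffle=True set in train_dataloader() should be preserved"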
