def __call__(self, *data, **kwargs):
    r"""
    Move the call arguments onto the model's device when that device is unambiguous,
    then delegate to ``torch.nn.Module.__call__``.

    Lightning warns you whenever it moves data on your behalf.

    Args:
        *data: positional arguments forwarded to ``torch.nn.Module.__call__``;
            these are typically the input data
        **kwargs: keyword arguments forwarded to ``torch.nn.Module.__call__``

    Example:

        .. code-block:: python

            model = model.cuda(0)
            model.prepare_data()
            loader = model.train_dataloader()
            for x, y in loader:
                output = model(x)  # Lightning will automove data here and warn you of it

    """
    param_devices = {param.device for param in self.parameters()}
    # Data is only moved when every parameter lives on one and the same device;
    # in any other situation this behaves exactly like a plain nn.Module call.
    if len(param_devices) == 1:
        (target,) = param_devices
        data = transfer_data_to_device(
            data, target.type, target.index, warn_on_transfer=True
        )
        kwargs = transfer_data_to_device(
            kwargs, target.type, target.index, warn_on_transfer=True
        )
    return super(LightningModule, self).__call__(*data, **kwargs)
def transfer_batch_to_tpu(self, batch):
    """Delegate to ``transfer_data_to_device`` with the ``'tpu'`` device type."""
    moved = transfer_data_to_device(batch, device_type='tpu')
    return moved

def transfer_batch_to_gpu(self, batch, gpu_id):
    """Delegate to ``transfer_data_to_device`` for the CUDA device with index ``gpu_id``."""
    moved = transfer_data_to_device(batch, device_type='cuda', idx=gpu_id)
    return moved
def transfer_data_to_device(batch, device_type, idx=None, warn_on_transfer=False):
    """
    Copy ``batch`` to the requested device and return the result.

    Works for any form of nested lists, tuples (including namedtuples) or dictionaries
    containing tensors, or any other object exposing a callable ``cuda``/``to``.
    Values that cannot be moved are returned unchanged.

    Args:
        batch: a tensor, an object with a callable ``to`` (or ``cuda``), or a nested
            list/tuple/dict of those
        device_type: ``'tpu'``, or a device type accepted by ``torch.device``
            (e.g. ``'cpu'`` or ``'cuda'``)
        idx: device index; not used for TPUs, whose device comes from ``xm.xla_device()``
        warn_on_transfer: if ``True``, ``rank_zero_warn`` is called for every object moved

    Returns:
        The transferred data. Lists and dicts are updated in place and returned;
        tuples are rebuilt as tuples of the same type (namedtuples stay namedtuples).
    """
    if device_type == 'tpu':
        # 'tpu' is not a valid ``torch.device`` type, so this branch must never fall
        # through to the ``torch.device`` construction used for cpu/cuda.
        if XLA_AVAILABLE and callable(getattr(batch, 'to', None)):
            xla_device = xm.xla_device()
            if warn_on_transfer:
                rank_zero_warn('Auto transferred data to device {}'.format(xla_device))
            return batch.to(xla_device)
    else:
        if idx is None:
            device = torch.device(device_type)
        else:
            device = torch.device(device_type, idx)

        # base case: nothing to do
        if torch.is_tensor(batch) and batch.device == device:
            return batch

        # object can be directly moved using `cuda` or `to`
        if device_type == 'cuda' and callable(getattr(batch, 'cuda', None)):
            if warn_on_transfer:
                rank_zero_warn('Auto transferred data to device {}'.format(device))
            return batch.cuda(device=device)

        if callable(getattr(batch, 'to', None)):
            if warn_on_transfer:
                rank_zero_warn('Auto transferred data to device {}'.format(device))
            return batch.to(device=device)

    # when list: update in place so callers holding a reference see the moved items
    if isinstance(batch, list):
        for i, item in enumerate(batch):
            batch[i] = transfer_data_to_device(item, device_type, idx, warn_on_transfer)
        return batch

    # when tuple: rebuild with the same type instead of degrading to a list
    if isinstance(batch, tuple):
        moved = [
            transfer_data_to_device(item, device_type, idx, warn_on_transfer)
            for item in batch
        ]
        if hasattr(batch, '_fields'):
            # namedtuple constructors take the fields as separate positional arguments
            return type(batch)(*moved)
        return tuple(moved)

    # when dict: only values are reassigned, so iterating while assigning is safe
    if isinstance(batch, dict):
        for key, value in batch.items():
            batch[key] = transfer_data_to_device(value, device_type, idx, warn_on_transfer)
        return batch

    # nothing matches, return the value as is without transform
    return batch
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_auto_move_data(tmpdir):
    """Check that a model on ``cuda:0`` returns ``cuda:0`` outputs for both CPU and GPU inputs."""

    tutils.reset_seed()
    tutils.set_random_master_port()

    model, _ = tutils.get_default_model()
    model = model.cuda(0)
    model.prepare_data()

    expected_device = torch.device('cuda:0')
    for x, _ in model.train_dataloader():
        flat = x.view(x.size(0), -1)
        # input still on its original device
        assert model(flat).device == expected_device, "Automoving data to same device as model failed"
        # input explicitly placed on the GPU first
        flat = flat.cuda(0)
        assert model(flat).device == expected_device, "Automoving data to same device as model failed"