diff --git a/README.md b/README.md
index 4be0bfad5634b..67d61c8c31e03 100644
--- a/README.md
+++ b/README.md
@@ -141,7 +141,7 @@ To use lightning do 2 things:
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
 
     @pl.data_loader
-    def val_dataloader(self):
+    def valid_dataloader(self):
         # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
 
diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py
index eb1fe9f149503..36b2e1bfb0ad6 100644
--- a/pl_examples/basic_examples/lightning_module_template.py
+++ b/pl_examples/basic_examples/lightning_module_template.py
@@ -219,7 +219,7 @@ def train_dataloader(self):
         return self.__dataloader(train=True)
 
     @pl.data_loader
-    def val_dataloader(self):
+    def valid_dataloader(self):
         logging.info('val data loader called')
         return self.__dataloader(train=False)
 
diff --git a/pl_examples/full_examples/imagenet/imagenet_example.py b/pl_examples/full_examples/imagenet/imagenet_example.py
index a277705c3296f..87c61b95d5f32 100644
--- a/pl_examples/full_examples/imagenet/imagenet_example.py
+++ b/pl_examples/full_examples/imagenet/imagenet_example.py
@@ -32,7 +32,7 @@ class ImageNetLightningModel(pl.LightningModule):
 
     def __init__(self, hparams):
-        super(ImageNetLightningModel, self).__init__()
+        super(ImageNetLightningModel, self).__init__()
         self.hparams = hparams
         self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)
 
@@ -159,7 +159,7 @@ def train_dataloader(self):
         return train_loader
 
     @pl.data_loader
-    def val_dataloader(self):
+    def valid_dataloader(self):
         normalize = transforms.Normalize(
             mean=[0.485, 0.456, 0.406],
             std=[0.229, 0.224, 0.225],
diff --git a/pytorch_lightning/core/__init__.py b/pytorch_lightning/core/__init__.py
index c2694eabf5758..39dfd888bb982 100644
--- a/pytorch_lightning/core/__init__.py
+++ b/pytorch_lightning/core/__init__.py
@@ -72,7 +72,7 @@ def train_dataloader(self):
                                   transform=transforms.ToTensor()), batch_size=32)
 
         @pl.data_loader
-        def val_dataloader(self):
+        def valid_dataloader(self):
             # OPTIONAL
             # can also return a list of val dataloaders
             return DataLoader(MNIST(os.getcwd(), train=True, download=True,
diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index 0a87e00f57fc7..e1e1a006b767c 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -19,7 +19,7 @@ def _get_data_loader(self):
             if (
                 value is not None and
                 not isinstance(value, list) and
-                fn.__name__ in ['test_dataloader', 'val_dataloader']
+                fn.__name__ in ['test_dataloader', 'valid_dataloader']
             ):
                 value = [value]
         except AttributeError as e:
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index d2ba6c8a0b7d3..7a14953c22250 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -310,7 +310,7 @@ def validation_step(self, batch, batch_idx):
             def validation_step(self, batch, batch_idx, dataset_idx):
                 # dataset_idx tells you which dataset this is.
 
-        The `dataset_idx` corresponds to the order of datasets returned in `val_dataloader`.
+        The `dataset_idx` corresponds to the order of datasets returned in `valid_dataloader`.
         """
         pass
 
@@ -850,13 +850,10 @@ def tng_dataloader(self):
 
         .. warning:: Deprecated in v0.5.0. use train_dataloader instead.
""" - try: - output = self.tng_dataloader() - warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0" - " and will be removed in v0.8.0", DeprecationWarning) - return output - except NotImplementedError: - raise NotImplementedError + output = self.train_dataloader() + warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0" + " and will be removed in v0.8.0", DeprecationWarning) + return output @data_loader def test_dataloader(self): @@ -891,7 +888,7 @@ def test_dataloader(self): return None @data_loader - def val_dataloader(self): + def valid_dataloader(self): """Implement a PyTorch DataLoader. :return: PyTorch DataLoader or list of PyTorch Dataloaders. @@ -908,7 +905,7 @@ def val_dataloader(self): .. code-block:: python @pl.data_loader - def val_dataloader(self): + def valid_dataloader(self): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True) loader = torch.utils.data.DataLoader( @@ -921,14 +918,25 @@ def val_dataloader(self): # can also return multiple dataloaders @pl.data_loader - def val_dataloader(self): + def valid_dataloader(self): return [loader_a, loader_b, ..., loader_n] - In the case where you return multiple `val_dataloaders`, the `validation_step` + In the case where you return multiple `valid_dataloaders`, the `validation_step` will have an arguement `dataset_idx` which matches the order here. """ return None + @data_loader + def val_dataloader(self): + """Implement a PyTorch DataLoader. + + .. warning:: Deprecated in v0.5.0. use valid_dataloader instead. + """ + output = self.valid_dataloader() + warnings.warn("`val_dataloader` has been renamed to `valid_dataloader` since v0.5.0" + " and will be removed in v0.8.0", DeprecationWarning) + return output + @classmethod def load_from_metrics(cls, weights_path, tags_csv, map_location=None): """Primary way of loading model from csv weights path. 
diff --git a/pytorch_lightning/testing/model_mixins.py b/pytorch_lightning/testing/model_mixins.py
index b568676685c23..a3c340f989037 100644
--- a/pytorch_lightning/testing/model_mixins.py
+++ b/pytorch_lightning/testing/model_mixins.py
@@ -7,12 +7,12 @@
 
 class LightningValidationStepMixin:
     """
-    Add val_dataloader and validation_step methods for the case
-    when val_dataloader returns a single dataloader
+    Add valid_dataloader and validation_step methods for the case
+    when valid_dataloader returns a single dataloader
     """
 
     @data_loader
-    def val_dataloader(self):
+    def valid_dataloader(self):
         return self._dataloader(train=False)
 
     def validation_step(self, batch, batch_idx):
@@ -61,8 +61,8 @@ def validation_step(self, batch, batch_idx):
 
 class LightningValidationMixin(LightningValidationStepMixin):
     """
-    Add val_dataloader, validation_step, and validation_end methods for the case
-    when val_dataloader returns a single dataloader
+    Add valid_dataloader, validation_step, and validation_end methods for the case
+    when valid_dataloader returns a single dataloader
     """
 
     def validation_end(self, outputs):
@@ -101,12 +101,12 @@ def validation_end(self, outputs):
 
 class LightningValidationStepMultipleDataloadersMixin:
     """
-    Add val_dataloader and validation_step methods for the case
-    when val_dataloader returns multiple dataloaders
+    Add valid_dataloader and validation_step methods for the case
+    when valid_dataloader returns multiple dataloaders
     """
 
     @data_loader
-    def val_dataloader(self):
+    def valid_dataloader(self):
         return [self._dataloader(train=False), self._dataloader(train=False)]
 
     def validation_step(self, batch, batch_idx, dataloader_idx):
@@ -161,8 +161,8 @@ def validation_step(self, batch, batch_idx, dataloader_idx):
 
 class LightningValidationMultipleDataloadersMixin(LightningValidationStepMultipleDataloadersMixin):
     """
-    Add val_dataloader, validation_step, and validation_end methods for the case
-    when val_dataloader returns multiple dataloaders
+    Add valid_dataloader, validation_step, and validation_end methods for the case
+    when valid_dataloader returns multiple dataloaders
     """
 
     def validation_end(self, outputs):
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 84413697948d5..0c270abc3353f 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -82,27 +82,27 @@ def init_train_dataloader(self, model):
                 self.shown_warnings.add(msg)
                 warnings.warn(msg)
 
-    def init_val_dataloader(self, model):
+    def init_valid_dataloader(self, model):
         """
         Dataloaders are provided by the model
         :param model:
         :return:
         """
-        self.get_val_dataloaders = model.val_dataloader
+        self.get_valid_dataloaders = model.valid_dataloader
 
         # determine number of validation batches
         # val datasets could be none, 1 or 2+
-        if self.get_val_dataloaders() is not None:
-            self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
+        if self.get_valid_dataloaders() is not None:
+            self.num_val_batches = sum(len(dataloader) for dataloader in self.get_valid_dataloaders())
             self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
             self.num_val_batches = max(1, self.num_val_batches)
 
         on_ddp = self.use_ddp or self.use_ddp2
-        if on_ddp and self.get_val_dataloaders() is not None:
-            for dataloader in self.get_val_dataloaders():
+        if on_ddp and self.get_valid_dataloaders() is not None:
+            for dataloader in self.get_valid_dataloaders():
                 if not isinstance(dataloader.sampler, DistributedSampler):
                     msg = """
-                    Your val_dataloader(s) don't use DistributedSampler.
+                    Your valid_dataloader(s) don't use DistributedSampler.
                     You're using multiple gpus and multiple nodes without using a DistributedSampler
                     to assign a subset of your data to each process.
@@ -177,7 +177,7 @@ def get_dataloaders(self, model):
 
         self.init_train_dataloader(model)
         self.init_test_dataloader(model)
-        self.init_val_dataloader(model)
+        self.init_valid_dataloader(model)
 
         if self.use_ddp or self.use_ddp2:
             # wait for all processes to catch up
@@ -186,7 +186,7 @@ def get_dataloaders(self, model):
             # load each dataloader
             self.get_train_dataloader()
             self.get_test_dataloaders()
-            self.get_val_dataloaders()
+            self.get_valid_dataloaders()
 
         # support IterableDataset for train data
         self.is_iterable_train_dataloader = (
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 0382477768cbf..441f0f553ef98 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -157,7 +157,7 @@ def __init__(self):
         self.current_epoch = None
         self.callback_metrics = None
         self.get_test_dataloaders = None
-        self.get_val_dataloaders = None
+        self.get_valid_dataloaders = None
 
     @abstractmethod
     def copy_trainer_model_properties(self, model):
@@ -290,7 +290,7 @@ def run_evaluation(self, test=False):
             max_batches = self.num_test_batches
         else:
             # val
-            dataloaders = self.get_val_dataloaders()
+            dataloaders = self.get_valid_dataloaders()
             max_batches = self.num_val_batches
 
         # cap max batches to 1 when using fast_dev_run
@@ -349,7 +349,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False
         if test and len(self.get_test_dataloaders()) > 1:
             args.append(dataloader_idx)
 
-        elif not test and len(self.get_val_dataloaders()) > 1:
+        elif not test and len(self.get_valid_dataloaders()) > 1:
             args.append(dataloader_idx)
 
         # handle DP, DDP forward
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 50282866e2b93..810f7e143a4b9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -207,7 +207,7 @@ def __init__(
         self.num_test_batches = 0
         self.get_train_dataloader = None
         self.get_test_dataloaders = None
-        self.get_val_dataloaders = None
+        self.get_valid_dataloaders = None
         self.is_iterable_train_dataloader = False
 
         # training state
@@ -490,17 +490,17 @@ def run_pretrain_routine(self, model):
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        ref_model.on_train_start()
-        if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
+        if self.get_valid_dataloaders() is not None and self.num_sanity_val_steps > 0:
            # init progress bars for validation sanity check
            pbar = tqdm.tqdm(desc='Validation sanity check',
-                             total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
+                             total=self.num_sanity_val_steps * len(self.get_valid_dataloaders()),
                             leave=False, position=2 * self.process_position,
                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
            self.main_progress_bar = pbar
            # dummy validation progress bar
            self.val_progress_bar = tqdm.tqdm(disable=True)
 
-            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+            self.evaluate(model, self.get_valid_dataloaders(), self.num_sanity_val_steps, self.testing)
 
            # close progress bars
            self.main_progress_bar.close()
diff --git a/tests/debug.py b/tests/debug.py
index aa0614a9767bd..8893a0eb1c177 100644
--- a/tests/debug.py
+++ b/tests/debug.py
@@ -46,7 +46,7 @@ def train_dataloader(self):
         return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
 
     @pl.data_loader
-    def val_dataloader(self):
+    def valid_dataloader(self):
         return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
 
     @pl.data_loader
diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py
index b04bb07df7557..fab94b30f3e89 100644
--- a/tests/test_restore_models.py
+++ b/tests/test_restore_models.py
@@ -308,7 +308,7 @@ def assert_good_acc():
         # if model and state loaded correctly, predictions will be good even though we
         # haven't trained with the new loaded model
         trainer.model.eval()
-        for dataloader in trainer.get_val_dataloaders():
+        for dataloader in trainer.get_valid_dataloaders():
             tutils.run_prediction(dataloader, trainer.model)
 
     model.on_train_start = assert_good_acc
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 65d5e4be575d2..56c16ea4be6f7 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -338,8 +338,8 @@ def test_model_freeze_unfreeze():
     model.unfreeze()
 
 
-def test_multiple_val_dataloader(tmpdir):
-    """Verify multiple val_dataloader."""
+def test_multiple_valid_dataloader(tmpdir):
+    """Verify multiple valid_dataloader."""
     tutils.reset_seed()
 
     class CurrentTestModel(
@@ -367,11 +367,11 @@ class CurrentTestModel(
     assert result == 1
 
     # verify there are 2 val loaders
-    assert len(trainer.get_val_dataloaders()) == 2, \
-        'Multiple val_dataloaders not initiated properly'
+    assert len(trainer.get_valid_dataloaders()) == 2, \
+        'Multiple valid_dataloaders not initiated properly'
 
     # make sure predictions are good for each val set
-    for dataloader in trainer.get_val_dataloaders():
+    for dataloader in trainer.get_valid_dataloaders():
        tutils.run_prediction(dataloader, trainer.model)
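
For downstream code, the practical effect of this patch is that LightningModule subclasses should define `valid_dataloader` rather than `val_dataloader`; the old hook name remains only as a deprecation shim in `pytorch_lightning/core/lightning.py` that forwards to the new one and emits a `DeprecationWarning` until removal in v0.8.0. A minimal sketch of an updated module follows, assuming this patch is applied; the model, tensors, and batch size are hypothetical placeholders, not code from this diff.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Hypothetical module; only the dataloader hooks are shown here."""

    def __init__(self):
        super(ToyModel, self).__init__()
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    @pl.data_loader
    def train_dataloader(self):
        # unchanged by this patch
        dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=32)

    @pl.data_loader
    def valid_dataloader(self):
        # renamed hook (previously `val_dataloader`); a single DataLoader or a
        # list of DataLoaders can be returned, exactly as before
        dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=32)

Note that the shim only covers callers of the old name: as the diff stands, a module that overrides `val_dataloader` without also defining `valid_dataloader` would no longer have its loader picked up by the trainer, since `init_valid_dataloader` now reads `model.valid_dataloader`.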