README.md (2 changes: 1 addition & 1 deletion)
@@ -141,7 +141,7 @@ To use lightning do 2 things:
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

@pl.data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
# OPTIONAL
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

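For orientation, here is a minimal sketch of a module after this rename, assuming the v0.5-era API shown in the README excerpt above; the class name, `training_step`, `validation_step`, and optimizer choice are illustrative and not part of this diff.

```python
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class CoolModel(pl.LightningModule):
    def __init__(self):
        super(CoolModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': F.cross_entropy(self.forward(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': F.cross_entropy(self.forward(x), y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def valid_dataloader(self):
        # OPTIONAL -- renamed from `val_dataloader` by this PR
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)
```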
pl_examples/basic_examples/lightning_module_template.py (2 changes: 1 addition & 1 deletion)
@@ -219,7 +219,7 @@ def train_dataloader(self):
return self.__dataloader(train=True)

@pl.data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
logging.info('val data loader called')
return self.__dataloader(train=False)

pl_examples/full_examples/imagenet/imagenet_example.py (4 changes: 2 additions & 2 deletions)
@@ -32,7 +32,7 @@
class ImageNetLightningModel(pl.LightningModule):

def __init__(self, hparams):
super(ImageNetLightningModel, self).__init__()
self.hparams = hparams
self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)

@@ -159,7 +159,7 @@ def train_dataloader(self):
return train_loader

@pl.data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
pytorch_lightning/core/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -72,7 +72,7 @@ def train_dataloader(self):
transform=transforms.ToTensor()), batch_size=32)

@pl.data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
# OPTIONAL
# can also return a list of val dataloaders
return DataLoader(MNIST(os.getcwd(), train=True, download=True,
pytorch_lightning/core/decorators.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ def _get_data_loader(self):
if (
value is not None and
not isinstance(value, list) and
- fn.__name__ in ['test_dataloader', 'val_dataloader']
+ fn.__name__ in ['test_dataloader', 'valid_dataloader']
):
value = [value]
except AttributeError as e:
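To make the effect of the renamed check visible outside the library, here is a self-contained sketch of the list-normalization behaviour this hunk touches; `data_loader_sketch` is a stand-in for the real `pl.data_loader` decorator, which additionally handles lazy loading and caching.

```python
import functools


def data_loader_sketch(fn):
    # Simplified stand-in for pl.data_loader: wrap a single returned loader
    # from `valid_dataloader` or `test_dataloader` in a list, so downstream
    # code can always iterate over a list of loaders.
    @functools.wraps(fn)
    def _get_data_loader(self):
        value = fn(self)
        if (
            value is not None and
            not isinstance(value, list) and
            fn.__name__ in ['test_dataloader', 'valid_dataloader']  # renamed here
        ):
            value = [value]
        return value
    return _get_data_loader


class Toy:
    @data_loader_sketch
    def valid_dataloader(self):
        return 'single_loader'  # stand-in for a torch DataLoader


print(Toy().valid_dataloader())  # ['single_loader'] -- wrapped in a list
```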
pytorch_lightning/core/lightning.py (32 changes: 20 additions & 12 deletions)
@@ -310,7 +310,7 @@ def validation_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx, dataset_idx):
# dataset_idx tells you which dataset this is.

- The `dataset_idx` corresponds to the order of datasets returned in `val_dataloader`.
+ The `dataset_idx` corresponds to the order of datasets returned in `valid_dataloader`.
"""
pass

@@ -850,13 +850,10 @@ def tng_dataloader(self):

.. warning:: Deprecated in v0.5.0. use train_dataloader instead.
"""
- try:
- output = self.tng_dataloader()
- warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0"
- " and will be removed in v0.8.0", DeprecationWarning)
- return output
- except NotImplementedError:
- raise NotImplementedError
+ output = self.train_dataloader()
+ warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0"
+ " and will be removed in v0.8.0", DeprecationWarning)
+ return output

@data_loader
def test_dataloader(self):
@@ -891,7 +888,7 @@ def test_dataloader(self):
return None

@data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
"""Implement a PyTorch DataLoader.

:return: PyTorch DataLoader or list of PyTorch Dataloaders.
@@ -908,7 +905,7 @@ def val_dataloader(self):
.. code-block:: python

@pl.data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
loader = torch.utils.data.DataLoader(
@@ -921,14 +918,25 @@ def val_dataloader(self):

# can also return multiple dataloaders
@pl.data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
return [loader_a, loader_b, ..., loader_n]

- In the case where you return multiple `val_dataloaders`, the `validation_step`
+ In the case where you return multiple `valid_dataloaders`, the `validation_step`
will have an argument `dataset_idx` which matches the order here.
"""
return None

+ @data_loader
+ def val_dataloader(self):
+ """Implement a PyTorch DataLoader.
+
+ .. warning:: Deprecated in v0.5.0. use valid_dataloader instead.
+ """
+ output = self.valid_dataloader()
+ warnings.warn("`val_dataloader` has been renamed to `valid_dataloader` since v0.5.0"
+ " and will be removed in v0.8.0", DeprecationWarning)
+ return output

@classmethod
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
"""Primary way of loading model from csv weights path.
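The new `val_dataloader` shim added above mirrors the `tng_dataloader` pattern: the old hook name forwards to the new one and emits a `DeprecationWarning`. A standalone sketch of that forwarding pattern, outside `LightningModule` and without the `@data_loader` wrapper, for illustration only:

```python
import warnings


class ModuleSketch:
    def valid_dataloader(self):
        # new-style hook; returning None means "no validation data"
        return None

    def val_dataloader(self):
        # deprecated name forwards to the new hook and warns once per call
        output = self.valid_dataloader()
        warnings.warn("`val_dataloader` has been renamed to `valid_dataloader` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        return output


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ModuleSketch().val_dataloader()
    print(caught[0].category.__name__)  # DeprecationWarning
```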
pytorch_lightning/testing/model_mixins.py (20 changes: 10 additions & 10 deletions)
@@ -7,12 +7,12 @@

class LightningValidationStepMixin:
"""
- Add val_dataloader and validation_step methods for the case
- when val_dataloader returns a single dataloader
+ Add valid_dataloader and validation_step methods for the case
+ when valid_dataloader returns a single dataloader
"""

@data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
return self._dataloader(train=False)

def validation_step(self, batch, batch_idx):
@@ -61,8 +61,8 @@ def validation_step(self, batch, batch_idx):

class LightningValidationMixin(LightningValidationStepMixin):
"""
- Add val_dataloader, validation_step, and validation_end methods for the case
- when val_dataloader returns a single dataloader
+ Add valid_dataloader, validation_step, and validation_end methods for the case
+ when valid_dataloader returns a single dataloader
"""

def validation_end(self, outputs):
@@ -101,12 +101,12 @@ def validation_end(self, outputs):

class LightningValidationStepMultipleDataloadersMixin:
"""
- Add val_dataloader and validation_step methods for the case
- when val_dataloader returns multiple dataloaders
+ Add valid_dataloader and validation_step methods for the case
+ when valid_dataloader returns multiple dataloaders
"""

@data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
return [self._dataloader(train=False), self._dataloader(train=False)]

def validation_step(self, batch, batch_idx, dataloader_idx):
@@ -161,8 +161,8 @@ def validation_step(self, batch, batch_idx, dataloader_idx):

class LightningValidationMultipleDataloadersMixin(LightningValidationStepMultipleDataloadersMixin):
"""
- Add val_dataloader, validation_step, and validation_end methods for the case
- when val_dataloader returns multiple dataloaders
+ Add valid_dataloader, validation_step, and validation_end methods for the case
+ when valid_dataloader returns multiple dataloaders
"""

def validation_end(self, outputs):
pytorch_lightning/trainer/data_loading.py (18 changes: 9 additions & 9 deletions)
@@ -82,27 +82,27 @@ def init_train_dataloader(self, model):
self.shown_warnings.add(msg)
warnings.warn(msg)

- def init_val_dataloader(self, model):
+ def init_valid_dataloader(self, model):
"""
Dataloaders are provided by the model
:param model:
:return:
"""
- self.get_val_dataloaders = model.val_dataloader
+ self.get_valid_dataloaders = model.valid_dataloader

# determine number of validation batches
# val datasets could be none, 1 or 2+
- if self.get_val_dataloaders() is not None:
- self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
+ if self.get_valid_dataloaders() is not None:
+ self.num_val_batches = sum(len(dataloader) for dataloader in self.get_valid_dataloaders())
self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
self.num_val_batches = max(1, self.num_val_batches)

on_ddp = self.use_ddp or self.use_ddp2
- if on_ddp and self.get_val_dataloaders() is not None:
- for dataloader in self.get_val_dataloaders():
+ if on_ddp and self.get_valid_dataloaders() is not None:
+ for dataloader in self.get_valid_dataloaders():
if not isinstance(dataloader.sampler, DistributedSampler):
msg = """
- Your val_dataloader(s) don't use DistributedSampler.
+ Your valid_dataloader(s) don't use DistributedSampler.

You're using multiple gpus and multiple nodes without using a
DistributedSampler to assign a subset of your data to each process.
@@ -177,7 +177,7 @@ def get_dataloaders(self, model):

self.init_train_dataloader(model)
self.init_test_dataloader(model)
- self.init_val_dataloader(model)
+ self.init_valid_dataloader(model)

if self.use_ddp or self.use_ddp2:
# wait for all processes to catch up
@@ -186,7 +186,7 @@ def get_dataloaders(self, model):
# load each dataloader
self.get_train_dataloader()
self.get_test_dataloaders()
- self.get_val_dataloaders()
+ self.get_valid_dataloaders()

# support IterableDataset for train data
self.is_iterable_train_dataloader = (
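The batch budget computed in `init_valid_dataloader` above follows a simple rule: sum the lengths of all valid dataloaders, scale by `val_percent_check`, and never go below one batch. A small sketch of that arithmetic with made-up loader lengths and percentage:

```python
def num_valid_batches(loader_lengths, val_percent_check):
    total = sum(loader_lengths)               # batches across all valid dataloaders
    capped = int(total * val_percent_check)   # keep only the requested fraction
    return max(1, capped)                     # but always run at least one batch


print(num_valid_batches([100, 60], 0.25))  # 40
print(num_valid_batches([3], 0.1))         # 1 (floored)
```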
pytorch_lightning/trainer/evaluation_loop.py (6 changes: 3 additions & 3 deletions)
@@ -157,7 +157,7 @@ def __init__(self):
self.current_epoch = None
self.callback_metrics = None
self.get_test_dataloaders = None
- self.get_val_dataloaders = None
+ self.get_valid_dataloaders = None

@abstractmethod
def copy_trainer_model_properties(self, model):
@@ -290,7 +290,7 @@ def run_evaluation(self, test=False):
max_batches = self.num_test_batches
else:
# val
- dataloaders = self.get_val_dataloaders()
+ dataloaders = self.get_valid_dataloaders()
max_batches = self.num_val_batches

# cap max batches to 1 when using fast_dev_run
@@ -349,7 +349,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False
if test and len(self.get_test_dataloaders()) > 1:
args.append(dataloader_idx)

- elif not test and len(self.get_val_dataloaders()) > 1:
+ elif not test and len(self.get_valid_dataloaders()) > 1:
args.append(dataloader_idx)

# handle DP, DDP forward
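On the user side, the `dataloader_idx` that `evaluation_forward` appends is only present when `valid_dataloader` registers more than one loader. A hedged fragment showing the matching signature; the tensors are dummies and the other required hooks are omitted:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl

# two dummy loaders; index 0 and 1 in validation_step below
loader_a = DataLoader(TensorDataset(torch.zeros(8, 1)), batch_size=4)
loader_b = DataLoader(TensorDataset(torch.ones(8, 1)), batch_size=4)


class MultiValSketch(pl.LightningModule):
    # training hooks omitted; this fragment only illustrates the validation side

    @pl.data_loader
    def valid_dataloader(self):
        # returning a list registers two loaders, indexed 0 and 1
        return [loader_a, loader_b]

    def validation_step(self, batch, batch_idx, dataloader_idx):
        # dataloader_idx is appended by evaluation_forward because the trainer
        # sees more than one valid dataloader
        (x,) = batch
        return {'val_loss': x.mean()}
```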
pytorch_lightning/trainer/trainer.py (8 changes: 4 additions & 4 deletions)
@@ -207,7 +207,7 @@ def __init__(
self.num_test_batches = 0
self.get_train_dataloader = None
self.get_test_dataloaders = None
- self.get_val_dataloaders = None
+ self.get_valid_dataloaders = None
self.is_iterable_train_dataloader = False

# training state
@@ -490,17 +490,17 @@ def run_pretrain_routine(self, model):
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
ref_model.on_train_start()
- if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
+ if self.get_valid_dataloaders() is not None and self.num_sanity_val_steps > 0:
# init progress bars for validation sanity check
pbar = tqdm.tqdm(desc='Validation sanity check',
- total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
+ total=self.num_sanity_val_steps * len(self.get_valid_dataloaders()),
leave=False, position=2 * self.process_position,
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
self.main_progress_bar = pbar
# dummy validation progress bar
self.val_progress_bar = tqdm.tqdm(disable=True)

- self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+ self.evaluate(model, self.get_valid_dataloaders(), self.num_sanity_val_steps, self.testing)

# close progress bars
self.main_progress_bar.close()
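The sanity-check progress bar sized in `run_pretrain_routine` above simply multiplies the per-loader step budget by the number of valid dataloaders. A tiny sketch of that arithmetic with illustrative values:

```python
def sanity_check_total(num_sanity_val_steps, num_valid_dataloaders):
    # total bar length: per-loader sanity steps times number of valid dataloaders
    return num_sanity_val_steps * num_valid_dataloaders


print(sanity_check_total(5, 1))  # 5   (single loader)
print(sanity_check_total(5, 2))  # 10  (two loaders returned by valid_dataloader())
```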
tests/debug.py (2 changes: 1 addition & 1 deletion)
@@ -46,7 +46,7 @@ def train_dataloader(self):
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

@pl.data_loader
- def val_dataloader(self):
+ def valid_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

@pl.data_loader
tests/test_restore_models.py (2 changes: 1 addition & 1 deletion)
@@ -308,7 +308,7 @@ def assert_good_acc():
# if model and state loaded correctly, predictions will be good even though we
# haven't trained with the new loaded model
trainer.model.eval()
- for dataloader in trainer.get_val_dataloaders():
+ for dataloader in trainer.get_valid_dataloaders():
tutils.run_prediction(dataloader, trainer.model)

model.on_train_start = assert_good_acc
tests/test_trainer.py (10 changes: 5 additions & 5 deletions)
@@ -338,8 +338,8 @@ def test_model_freeze_unfreeze():
model.unfreeze()


- def test_multiple_val_dataloader(tmpdir):
- """Verify multiple val_dataloader."""
+ def test_multiple_valid_dataloader(tmpdir):
+ """Verify multiple valid_dataloader."""
tutils.reset_seed()

class CurrentTestModel(
@@ -367,11 +367,11 @@ class CurrentTestModel(
assert result == 1

# verify there are 2 val loaders
- assert len(trainer.get_val_dataloaders()) == 2, \
- 'Multiple val_dataloaders not initiated properly'
+ assert len(trainer.get_valid_dataloaders()) == 2, \
+ 'Multiple valid_dataloaders not initiated properly'

# make sure predictions are good for each val set
- for dataloader in trainer.get_val_dataloaders():
+ for dataloader in trainer.get_valid_dataloaders():
tutils.run_prediction(dataloader, trainer.model)

