75 changes: 38 additions & 37 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -20,10 +20,10 @@

 class ConfigValidator:

-    def __init__(self, trainer: 'pl.Trainer') -> None:
+    def __init__(self, trainer: "pl.Trainer") -> None:
         self.trainer = trainer

-    def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
+    def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
         r"""
         Checks that the model is configured correctly before the run is started.

@@ -33,82 +33,83 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
         """
         if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
             self.__verify_train_loop_configuration(model)
-            self.__verify_eval_loop_configuration(model, 'val')
+            self.__verify_eval_loop_configuration(model, "val")
         elif self.trainer.state.fn == TrainerFn.VALIDATING:
-            self.__verify_eval_loop_configuration(model, 'val')
+            self.__verify_eval_loop_configuration(model, "val")
         elif self.trainer.state.fn == TrainerFn.TESTING:
-            self.__verify_eval_loop_configuration(model, 'test')
+            self.__verify_eval_loop_configuration(model, "test")
         elif self.trainer.state.fn == TrainerFn.PREDICTING:
             self.__verify_predict_loop_configuration(model)
         self.__verify_dp_batch_transfer_support(model)

-    def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None:
+    def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None:
         # -----------------------------------
         # verify model has a training step
         # -----------------------------------
-        has_training_step = is_overridden('training_step', model)
+        has_training_step = is_overridden("training_step", model)
         if not has_training_step:
             raise MisconfigurationException(
-                'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
-                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
+                "No `training_step()` method defined. Lightning `Trainer` expects as minimum a"
+                " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
             )

         # -----------------------------------
         # verify model has a train dataloader
         # -----------------------------------
-        has_train_dataloader = is_overridden('train_dataloader', model)
-        if not has_train_dataloader:
-            raise MisconfigurationException(
-                'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
-                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
-            )
+        self.__verify_train_dataloader(model)

         # -----------------------------------
         # verify model has optimizer
         # -----------------------------------
-        has_optimizers = is_overridden('configure_optimizers', model)
+        has_optimizers = is_overridden("configure_optimizers", model)
         if not has_optimizers:
             raise MisconfigurationException(
-                'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
-                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
+                "No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a"
+                " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
             )

         trainer = self.trainer

-        trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
-        trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)
+        trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
+        trainer.overriden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model)
         automatic_optimization = model.automatic_optimization
-        going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()
+        going_to_accumulate_grad_batches = (trainer.accumulation_scheduler.going_to_accumulate_grad_batches())

-        has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
-        if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
+        has_overriden_optimization_functions = (
+            trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
+        )
+        if (has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization):
             raise MisconfigurationException(
-                'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,'
-                ' `accumulate_grad_batches` in `Trainer` should be 1.'
-                ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
+                "When overriding `LightningModule` optimizer_step or optimizer_zero_grad,"
+                " `accumulate_grad_batches` in `Trainer` should be 1."
+                " It ensures optimizer_step or optimizer_zero_grad are called on every batch."
             )

-    def __verify_eval_loop_configuration(self, model: 'pl.LightningModule', stage: str) -> None:
-        loader_name = f'{stage}_dataloader'
-        step_name = 'validation_step' if stage == 'val' else 'test_step'
+    def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None:
+        loader_name = f"{stage}_dataloader"
+        step_name = "validation_step" if stage == "val" else "test_step"

         has_loader = is_overridden(loader_name, model)
         has_step = is_overridden(step_name, model)

         if has_loader and not has_step:
-            rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop')
+            rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")
         if has_step and not has_loader:
-            rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop')
+            rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")

-    def __verify_predict_loop_configuration(self, model: 'pl.LightningModule') -> None:
-        has_predict_dataloader = is_overridden('predict_dataloader', model)
+    def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None:
+        has_predict_dataloader = is_overridden("predict_dataloader", model)
         if not has_predict_dataloader:
-            raise MisconfigurationException('Dataloader not found for `Trainer.predict`')
+            raise MisconfigurationException("Dataloader not found for `Trainer.predict`")

-    def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None:
+    def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None:
         """Raise Misconfiguration exception since these hooks are not supported in DP mode"""
         # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
-        batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
+        batch_transfer_hooks = (
+            "on_before_batch_transfer",
+            "transfer_batch_to_device",
+            "on_after_batch_transfer",
+        )
         for hook in batch_transfer_hooks:
             if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
-                raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
+                raise MisconfigurationException(f"Overriding `{hook}` is not supported in DP mode.")
35 changes: 25 additions & 10 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -28,7 +28,10 @@ def __init__(self, trainer):
         self.trainer = trainer

     def on_trainer_init(
-        self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
+        self,
+        check_val_every_n_epoch: int,
+        reload_dataloaders_every_epoch: bool,
+        prepare_data_per_node: bool,
     ) -> None:
         self.trainer.datamodule = None
         self.trainer.prepare_data_per_node = prepare_data_per_node
@@ -59,22 +62,22 @@ def prepare_data(self, model):

     def can_prepare_data(self):
         should_call_dm_prepare_data = True
-        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
+        if self.trainer.datamodule is not None and is_overridden("prepare_data", self.trainer.datamodule):
             should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data

         if self.trainer.prepare_data_per_node:
             return self.trainer.local_rank == 0 and should_call_dm_prepare_data
         else:
-            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data
+            return (self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data)

     def attach_data(
         self,
-        model: 'pl.LightningModule',
+        model: "pl.LightningModule",
         train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
-        datamodule: Optional['pl.LightningDataModule'] = None
+        datamodule: Optional["pl.LightningDataModule"] = None,
     ) -> None:
         # set up the passed in dataloaders (if needed)
         self.attach_dataloaders(
@@ -90,7 +93,7 @@ def attach_data(

     def attach_dataloaders(
         self,
-        model: 'pl.LightningModule',
+        model: "pl.LightningModule",
         train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
@@ -111,22 +114,34 @@ def attach_dataloaders(
             model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

     def attach_datamodule(
-        self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None
+        self,
+        model: "pl.LightningModule",
+        datamodule: Optional["pl.LightningDataModule"] = None,
     ) -> None:
         # We use datamodule if it's been provided, otherwise we check model for it
-        datamodule = datamodule or getattr(model, 'datamodule', None)
+        datamodule = datamodule or getattr(model, "datamodule", None)

         # If we have a datamodule, attach necessary hooks + dataloaders
         if datamodule:

             # TODO: should't override user code
             # Override loader hooks
-            dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
+            dl_methods = (
+                "train_dataloader",
+                "val_dataloader",
+                "test_dataloader",
+                "predict_dataloader",
+            )
             for method in dl_methods:
                 if is_overridden(method, datamodule):
                     setattr(model, method, getattr(datamodule, method))

             # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
-            batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
+            batch_transfer_hooks = (
+                "on_before_batch_transfer",
+                "transfer_batch_to_device",
+                "on_after_batch_transfer",
+            )
             for hook in batch_transfer_hooks:
                 if is_overridden(hook, datamodule):
                     setattr(model, hook, getattr(datamodule, hook))
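
For context, a minimal sketch of the datamodule side of attach_datamodule above. The class name and toy data are hypothetical, not part of this PR; because train_dataloader is overridden here, the connector copies it onto the model via setattr, so Trainer.fit(model, datamodule=dm) pulls batches from the datamodule.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyDataModule(pl.LightningDataModule):  # hypothetical example, not from this PR
    def train_dataloader(self):
        # Overriding this hook is what makes attach_datamodule bind it to the model.
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=8)


# Usage sketch (assuming the hypothetical MinimalModel from the earlier example):
# trainer = pl.Trainer(max_epochs=1)
# trainer.fit(MinimalModel(), datamodule=ToyDataModule())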