@@ -29,16 +29,16 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
2929
3030 """
3131 if trainer .state .fn in (TrainerFn .FITTING , TrainerFn .TUNING ):
32- __verify_train_loop_configuration (trainer , model )
33- __verify_eval_loop_configuration (model , "val" )
32+ __verify_train_val_loop_configuration (trainer , model )
3433 __verify_manual_optimization_support (trainer , model )
3534 __check_training_step_requires_dataloader_iter (model )
3635 elif trainer .state .fn == TrainerFn .VALIDATING :
37- __verify_eval_loop_configuration (model , "val" )
36+ __verify_eval_loop_configuration (trainer , model , "val" )
3837 elif trainer .state .fn == TrainerFn .TESTING :
39- __verify_eval_loop_configuration (model , "test" )
38+ __verify_eval_loop_configuration (trainer , model , "test" )
4039 elif trainer .state .fn == TrainerFn .PREDICTING :
41- __verify_predict_loop_configuration (trainer , model )
40+ __verify_eval_loop_configuration (trainer , model , "predict" )
41+
4242 __verify_dp_batch_transfer_support (trainer , model )
4343 _check_add_get_queue (model )
4444 # TODO(@daniellepintz): Delete _check_progress_bar in v1.7
@@ -51,7 +51,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
5151 _check_dl_idx_in_on_train_batch_hooks (trainer , model )
5252
5353
54- def __verify_train_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
54+ def __verify_train_val_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
5555 # -----------------------------------
5656 # verify model has a training step
5757 # -----------------------------------
@@ -83,24 +83,15 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
8383 )
8484
8585 # ----------------------------------------------
86- # verify model does not have
87- # - on_train_dataloader
88- # - on_val_dataloader
86+ # verify model does not have on_train_dataloader
8987 # ----------------------------------------------
9088 has_on_train_dataloader = is_overridden ("on_train_dataloader" , model )
9189 if has_on_train_dataloader :
9290 rank_zero_deprecation (
93- "Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
91+ "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
9492 " Please use `train_dataloader()` directly."
9593 )
9694
97- has_on_val_dataloader = is_overridden ("on_val_dataloader" , model )
98- if has_on_val_dataloader :
99- rank_zero_deprecation (
100- "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
101- " Please use `val_dataloader()` directly."
102- )
103-
10495 trainer .overriden_optimizer_step = is_overridden ("optimizer_step" , model )
10596 trainer .overriden_optimizer_zero_grad = is_overridden ("optimizer_zero_grad" , model )
10697 automatic_optimization = model .automatic_optimization
@@ -110,8 +101,30 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
110101 if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization :
111102 rank_zero_warn (
112103 "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
113- "`LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
114- "(rather, they are called on every optimization step)."
104+ " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
105+ " (rather, they are called on every optimization step)."
106+ )
107+
108+ # -----------------------------------
109+ # verify model for val loop
110+ # -----------------------------------
111+
112+ has_val_loader = trainer ._data_connector ._val_dataloader_source .is_defined ()
113+ has_val_step = is_overridden ("validation_step" , model )
114+
115+ if has_val_loader and not has_val_step :
116+ rank_zero_warn ("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop." )
117+ if has_val_step and not has_val_loader :
118+ rank_zero_warn ("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop." )
119+
120+ # ----------------------------------------------
121+ # verify model does not have on_val_dataloader
122+ # ----------------------------------------------
123+ has_on_val_dataloader = is_overridden ("on_val_dataloader" , model )
124+ if has_on_val_dataloader :
125+ rank_zero_deprecation (
126+ "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
127+ " Please use `val_dataloader()` directly."
115128 )
116129
117130
@@ -143,52 +156,43 @@ def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
143156 )
144157
145158
146- def __verify_eval_loop_configuration (model : "pl.LightningModule" , stage : str ) -> None :
159+ def __verify_eval_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" , stage : str ) -> None :
147160 loader_name = f"{ stage } _dataloader"
148- step_name = "validation_step" if stage == "val" else "test_step"
161+ step_name = "validation_step" if stage == "val" else f"{ stage } _step"
162+ trainer_method = "validate" if stage == "val" else stage
163+ on_eval_hook = f"on_{ loader_name } "
149164
150- has_loader = is_overridden ( loader_name , model )
165+ has_loader = getattr ( trainer . _data_connector , f"_ { stage } _dataloader_source" ). is_defined ( )
151166 has_step = is_overridden (step_name , model )
152-
153- if has_loader and not has_step :
154- rank_zero_warn (f"you passed in a { loader_name } but have no { step_name } . Skipping { stage } loop" )
155- if has_step and not has_loader :
156- rank_zero_warn (f"you defined a { step_name } but have no { loader_name } . Skipping { stage } loop" )
167+ has_on_eval_dataloader = is_overridden (on_eval_hook , model )
157168
158169 # ----------------------------------------------
159- # verify model does not have
160- # - on_val_dataloader
161- # - on_test_dataloader
170+ # verify model does not have on_eval_dataloader
162171 # ----------------------------------------------
163- has_on_val_dataloader = is_overridden ("on_val_dataloader" , model )
164- if has_on_val_dataloader :
172+ if has_on_eval_dataloader :
165173 rank_zero_deprecation (
166- "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0. "
167- " Please use `val_dataloader ()` directly."
174+ f "Method `{ on_eval_hook } ` is deprecated in v1.5.0 and will "
175+ f" be removed in v1.7.0. Please use `{ loader_name } ()` directly."
168176 )
169177
170- has_on_test_dataloader = is_overridden ("on_test_dataloader" , model )
171- if has_on_test_dataloader :
172- rank_zero_deprecation (
173- "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
174- " Please use `test_dataloader()` directly."
175- )
176-
177-
178- def __verify_predict_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
179- has_predict_dataloader = trainer ._data_connector ._predict_dataloader_source .is_defined ()
180- if not has_predict_dataloader :
181- raise MisconfigurationException ("Dataloader not found for `Trainer.predict`" )
182- # ----------------------------------------------
183- # verify model does not have
184- # - on_predict_dataloader
185- # ----------------------------------------------
186- has_on_predict_dataloader = is_overridden ("on_predict_dataloader" , model )
187- if has_on_predict_dataloader :
188- rank_zero_deprecation (
189- "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
190- " Please use `predict_dataloader()` directly."
191- )
178+ # -----------------------------------
179+ # verify model has an eval_dataloader
180+ # -----------------------------------
181+ if not has_loader :
182+ raise MisconfigurationException (f"No `{ loader_name } ()` method defined to run `Trainer.{ trainer_method } `." )
183+
184+ # predict_step is not required to be overridden
185+ if stage == "predict" :
186+ if model .predict_step is None :
187+ raise MisconfigurationException ("`predict_step` cannot be None to run `Trainer.predict`" )
188+ elif not has_step and not is_overridden ("forward" , model ):
189+ raise MisconfigurationException ("`Trainer.predict` requires `forward` method to run." )
190+ else :
191+ # -----------------------------------
192+ # verify model has an eval_step
193+ # -----------------------------------
194+ if not has_step :
195+ raise MisconfigurationException (f"No `{ step_name } ()` method defined to run `Trainer.{ trainer_method } `." )
192196
193197
194198def __verify_dp_batch_transfer_support (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
0 commit comments