 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import pandas as pd
+from pytorch_lightning.core.datamodule import LightningDataModule
 import stopit
 import torch.multiprocessing
 from azureml._restclient.constants import RunStatus
@@ -120,19 +121,16 @@ def download_dataset(azure_dataset_id: str,
     return expected_dataset_path
 
 
-def log_metrics(val_metrics: Optional[InferenceMetricsForSegmentation],
-                test_metrics: Optional[InferenceMetricsForSegmentation],
-                train_metrics: Optional[InferenceMetricsForSegmentation],
+def log_metrics(metrics: Dict[ModelExecutionMode, InferenceMetrics],
                 run_context: Run) -> None:
     """
     Log metrics for each split to the provided run, or to the current run context if None is provided.
-    :param val_metrics: Inference results for the validation split
-    :param test_metrics: Inference results for the test split
-    :param train_metrics: Inference results for the train split
+    :param metrics: Dictionary of inference results for each split.
     :param run_context: Run to log the metrics to; use the current run context if None is provided.
     """
-    for split in [x for x in [val_metrics, test_metrics, train_metrics] if x]:
-        split.log_metrics(run_context)
+    for split in metrics.values():
+        if isinstance(split, InferenceMetricsForSegmentation):
+            split.log_metrics(run_context)
 
 
 class MLRunner:
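For orientation, a minimal sketch of how a caller might use the reworked `log_metrics`; the metric objects below are placeholders for whatever `model_test` returns, not names from this PR:

# Hypothetical usage of the new dict-based signature (names are illustrative).
metrics = {
    ModelExecutionMode.TEST: test_split_metrics,
    ModelExecutionMode.VAL: val_split_metrics,
}
log_metrics(metrics=metrics, run_context=RUN_CONTEXT)
# Values that are not InferenceMetricsForSegmentation are skipped by the isinstance check.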
@@ -390,7 +388,7 @@ def run(self) -> None:
 
         # If this is a cross-validation run, and the present run is child run 0, then wait for the sibling
         # runs, build the ensemble model, and write a report for that.
-        if self.container.number_of_cross_validation_splits > 0:
+        if self.container.perform_cross_validation:
             should_wait_for_other_child_runs = (not self.is_offline_run) and \
                                                self.container.cross_validation_split_index == 0
             if should_wait_for_other_child_runs:
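The new `perform_cross_validation` flag encapsulates the split-count check that was previously inlined. Its definition is not shown in this diff; a plausible sketch, assuming it simply wraps the old condition:

# Assumed shape of the container property; not part of this diff.
@property
def perform_cross_validation(self) -> bool:
    # Mirrors the inlined check that this PR replaces.
    return self.number_of_cross_validation_splits > 0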
@@ -420,10 +418,24 @@ def is_normal_run_or_crossval_child_0(self) -> bool:
420418 """
421419 Returns True if the present run is a non-crossvalidation run, or child run 0 of a crossvalidation run.
422420 """
423- if self .container .number_of_cross_validation_splits > 0 :
421+ if self .container .perform_cross_validation :
424422 return self .container .cross_validation_split_index == 0
425423 return True
426424
+    @staticmethod
+    def lightning_data_module_dataloaders(data: LightningDataModule) -> Dict[ModelExecutionMode, Callable]:
+        """
+        Given a Lightning data module, return a dictionary with the dataloader factory for each execution mode.
+
+        :param data: Lightning data module.
+        :return: Dataloader factory for each model execution mode.
+        """
+        return {
+            ModelExecutionMode.TEST: data.test_dataloader,
+            ModelExecutionMode.VAL: data.val_dataloader,
+            ModelExecutionMode.TRAIN: data.train_dataloader
+        }
+
     def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
         """
         Run inference on the test set for all models that are specified via a LightningContainer.
@@ -439,11 +451,10 @@ def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
         # Read the data modules before changing the working directory, in case the code relies on relative paths
         data = self.container.get_inference_data_module()
         dataloaders: List[Tuple[DataLoader, ModelExecutionMode]] = []
-        if self.container.perform_validation_and_test_set_inference:
-            dataloaders.append((data.test_dataloader(), ModelExecutionMode.TEST))  # type: ignore
-            dataloaders.append((data.val_dataloader(), ModelExecutionMode.VAL))  # type: ignore
-        if self.container.perform_training_set_inference:
-            dataloaders.append((data.train_dataloader(), ModelExecutionMode.TRAIN))  # type: ignore
+        data_dataloaders = MLRunner.lightning_data_module_dataloaders(data)
+        for data_split, dataloader in data_dataloaders.items():
+            if self.container.inference_on_set(ModelProcessing.DEFAULT, data_split):
+                dataloaders.append((dataloader(), data_split))
         checkpoint = load_checkpoint(checkpoint_paths[0], use_gpu=self.container.use_gpu)
         lightning_model.load_state_dict(checkpoint['state_dict'])
         lightning_model.eval()
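Because `lightning_data_module_dataloaders` stores the bound `*_dataloader` methods rather than calling them, a loader is only materialized for splits that pass the `inference_on_set` check. A standalone sketch of that laziness, using a hypothetical data module:

# MyDataModule is a hypothetical LightningDataModule subclass.
data = MyDataModule()
factories = MLRunner.lightning_data_module_dataloaders(data)
# Nothing has been built yet: the dict values are bound methods.
test_loader = factories[ModelExecutionMode.TEST]()  # constructed only on this call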
@@ -491,8 +502,8 @@ def run_inference(self, checkpoint_handler: CheckpointHandler,
491502 """
492503
493504 # run full image inference on existing or newly trained model on the training, and testing set
-        test_metrics, val_metrics, _ = self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
-                                                                           model_proc=model_proc)
+        self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
+                                            model_proc=model_proc)
 
         self.try_compare_scores_against_baselines(model_proc)
 
@@ -752,37 +763,25 @@ def copy_file(source: Path, destination_file: str) -> None:
     def model_inference_train_and_test(self,
                                        checkpoint_handler: CheckpointHandler,
                                        model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> \
-            Tuple[Optional[InferenceMetrics], Optional[InferenceMetrics], Optional[InferenceMetrics]]:
-        train_metrics = None
-        val_metrics = None
-        test_metrics = None
+            Dict[ModelExecutionMode, InferenceMetrics]:
+        metrics: Dict[ModelExecutionMode, InferenceMetrics] = {}
 
         config = self.innereye_config
 
-        def run_model_test(data_split: ModelExecutionMode) -> Optional[InferenceMetrics]:
-            return model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,  # type: ignore
-                              model_proc=model_proc)
-
-        if config.perform_validation_and_test_set_inference:
-            # perform inference on test set
-            test_metrics = run_model_test(ModelExecutionMode.TEST)
-            # perform inference on validation set (not for ensemble as current val is in the training fold
-            # for at least one of the models).
-            if model_proc != ModelProcessing.ENSEMBLE_CREATION:
-                val_metrics = run_model_test(ModelExecutionMode.VAL)
-
-        if config.perform_training_set_inference:
-            # perform inference on training set if required
-            train_metrics = run_model_test(ModelExecutionMode.TRAIN)
+        for data_split in ModelExecutionMode:
+            if self.container.inference_on_set(model_proc, data_split):
+                opt_metrics = model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,
+                                         model_proc=model_proc)
+                if opt_metrics is not None:
+                    metrics[data_split] = opt_metrics
 
         # log the metrics to the AzureML experiment if possible. When doing ensemble runs, log to the Hyperdrive
         # parent run, so that we get the metrics of child run 0 and the ensemble separated.
         if config.is_segmentation_model and not self.is_offline_run:
             run_for_logging = PARENT_RUN_CONTEXT if model_proc == ModelProcessing.ENSEMBLE_CREATION else RUN_CONTEXT
-        log_metrics(val_metrics=val_metrics, test_metrics=test_metrics,  # type: ignore
-                    train_metrics=train_metrics, run_context=run_for_logging)  # type: ignore
+        log_metrics(metrics=metrics, run_context=run_for_logging)  # type: ignore
 
-        return test_metrics, val_metrics, train_metrics
+        return metrics
 
     @stopit.threading_timeoutable()
     def wait_for_runs_to_finish(self, delay: int = 60) -> None:
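The loop in `model_inference_train_and_test` replaces the two per-split flags with a single `inference_on_set` gate whose definition is not part of this diff. A sketch of one implementation that would reproduce the previous behaviour, assuming the legacy flags are still available on the container:

# Sketch only: a plausible inference_on_set reproducing the old per-split logic.
def inference_on_set(self, model_proc: ModelProcessing,
                     data_split: ModelExecutionMode) -> bool:
    if data_split == ModelExecutionMode.TRAIN:
        return self.perform_training_set_inference
    if data_split == ModelExecutionMode.VAL:
        # Skip the val split for ensembles: it lies in the training fold of
        # at least one of the child models.
        return (self.perform_validation_and_test_set_inference
                and model_proc != ModelProcessing.ENSEMBLE_CREATION)
    return self.perform_validation_and_test_set_inference  # TEST split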