Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit 8eae655

Browse files
Better Ensemble Child Inference Defaults (#533)
1 parent f3446c8 commit 8eae655

File tree

5 files changed

+124
-31
lines changed

5 files changed

+124
-31
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ module on test data with partial ground truth files. (Also [522](https://github.
2020
jobs that run in AzureML.
2121

2222
### Changed
23+
- ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
2324
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) Renamed command line option 'perform_training_set_inference' to 'inference_on_train_set'. Replaced command line option 'perform_validation_and_test_set_inference' with the pair of options 'inference_on_val_set' and 'inference_on_test_set'.
2425
- ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
2526
- ([#497](https://github.com/microsoft/InnerEye-DeepLearning/pull/497)) Reducing the size of the code snapshot that

InnerEye/ML/configs/segmentation/BasicModel2Epochs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def __init__(self, **kwargs: Any) -> None:
4444
use_mixed_precision=True,
4545
azure_dataset_id=AZURE_DATASET_ID,
4646
comparison_blob_storage_paths=comparison_blob_storage_paths,
47+
inference_on_test_set=True,
48+
inference_on_val_set=True,
4749
dataset_mountpoint="/tmp/innereye",
4850
# Use an LR scheduler with a pronounced and clearly visible decay, to be able to easily see if that
4951
# is applied correctly in run recovery.

InnerEye/ML/deep_learning_config.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,20 +263,42 @@ def validate(self) -> None:
263263
f"found number_of_cross_validation_splits = {self.number_of_cross_validation_splits} "
264264
f"and cross_validation_split_index={self.cross_validation_split_index}")
265265

266-
""" Defaults for when to run inference in the absence of any command line switches. """
267-
INFERENCE_DEFAULTS: Dict[ModelProcessing, Dict[ModelExecutionMode, bool]] = {
266+
"""
267+
Defaults for when to run inference in the absence of any command line switches.
268+
This depends on ModelProcessing, perform_cross_validation, and ModelExecutionMode.
269+
If the current combination of these three parameters is not in this data structure,
270+
then default to False.
271+
"""
272+
INFERENCE_DEFAULTS: Dict[ModelProcessing, Dict[bool, Dict[ModelExecutionMode, bool]]] = {
268273
ModelProcessing.DEFAULT: {
269-
ModelExecutionMode.TRAIN: False,
270-
ModelExecutionMode.TEST: True,
271-
ModelExecutionMode.VAL: True,
274+
False: {
275+
ModelExecutionMode.TRAIN: False,
276+
ModelExecutionMode.TEST: True,
277+
ModelExecutionMode.VAL: True
278+
}
272279
},
273280
ModelProcessing.ENSEMBLE_CREATION: {
274-
ModelExecutionMode.TRAIN: False,
275-
ModelExecutionMode.TEST: True,
276-
ModelExecutionMode.VAL: False,
281+
True: {
282+
ModelExecutionMode.TRAIN: False,
283+
ModelExecutionMode.TEST: True,
284+
ModelExecutionMode.VAL: False
285+
}
277286
}
278287
}
279288

289+
def inference_defaults(self, model_proc: ModelProcessing, data_split: ModelExecutionMode) -> bool:
290+
"""
291+
Returns True if inference is required by default for this model_proc and data_split.
292+
293+
:param model_proc: Whether we are testing an ensemble or single model.
294+
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
295+
:return: True if inference required by default.
296+
"""
297+
try:
298+
return WorkflowParams.INFERENCE_DEFAULTS[model_proc][self.perform_cross_validation][data_split]
299+
except KeyError:
300+
return False
301+
280302
def inference_options(self) -> Dict[ModelProcessing, Dict[ModelExecutionMode, Optional[bool]]]:
281303
"""
282304
Return a mapping from ModelProcessing and ModelExecutionMode to command line switch.
@@ -308,7 +330,7 @@ def inference_on_set(self, model_proc: ModelProcessing, data_split: ModelExecuti
308330
if inference_option is not None:
309331
return inference_option
310332

311-
return WorkflowParams.INFERENCE_DEFAULTS[model_proc][data_split]
333+
return self.inference_defaults(model_proc, data_split)
312334

313335
@property
314336
def is_offline_run(self) -> bool:

Tests/ML/models/test_scalar_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def test_run_ml_with_classification_model(test_output_dirs: OutputFolderForTests
279279
azure_config = get_default_azure_config()
280280
azure_config.train = True
281281
config: ScalarModelBase = ModelConfigLoader().create_model_config_from_name(model_name)
282+
config.inference_on_test_set = True
282283
config.number_of_cross_validation_splits = number_of_offline_cross_validation_splits
283284
config.set_output_to(test_output_dirs.root_dir)
284285
# Trying to run DDP from the test suite hangs, hence restrict to single GPU.

Tests/ML/runners/test_runner.py

Lines changed: 89 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import time
77
from pathlib import Path
8-
from typing import Tuple
8+
from typing import Optional, Tuple
99
from unittest import mock
1010
from unittest.mock import Mock
1111

@@ -99,49 +99,102 @@ def create_train_and_test_data_small_dataset(image_size: TupleInt3,
9999
return target_dir
100100

101101

102+
@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows")
103+
@pytest.mark.parametrize("perform_cross_validation", [True, False])
104+
def test_model_inference_train_and_test_default(test_output_dirs: OutputFolderForTests,
105+
perform_cross_validation: bool) -> None:
106+
"""
107+
Test inference defaults with ModelProcessing.DEFAULT.
108+
109+
:param test_output_dirs: Test output directories.
110+
:param perform_cross_validation: Whether to test with cross validation.
111+
:return: None.
112+
"""
113+
run_model_inference_train_and_test(test_output_dirs,
114+
perform_cross_validation,
115+
model_proc=ModelProcessing.DEFAULT)
116+
117+
102118
@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows")
103119
@pytest.mark.parametrize("perform_cross_validation", [True, False])
104120
@pytest.mark.parametrize("inference_on_set", [(True, False, False), (False, True, False), (False, False, True)])
105121
def test_model_inference_train_and_test(test_output_dirs: OutputFolderForTests,
106122
perform_cross_validation: bool,
107123
inference_on_set: Tuple[bool, bool, bool]) -> None:
124+
"""
125+
Test inference overrides with ModelProcessing.DEFAULT.
126+
127+
:param test_output_dirs: Test output directories.
128+
:param perform_cross_validation: Whether to test with cross validation.
129+
:param inference_on_set: Overrides for inference on data sets.
130+
:return: None.
131+
"""
108132
(inference_on_train_set, inference_on_val_set, inference_on_test_set) = inference_on_set
109133
run_model_inference_train_and_test(test_output_dirs,
110134
perform_cross_validation,
111-
inference_on_train_set,
112-
inference_on_val_set,
113-
inference_on_test_set,
114-
False,
115-
False,
116-
False,
117-
ModelProcessing.DEFAULT)
135+
inference_on_train_set=inference_on_train_set,
136+
inference_on_val_set=inference_on_val_set,
137+
inference_on_test_set=inference_on_test_set,
138+
model_proc=ModelProcessing.DEFAULT)
139+
140+
141+
@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows")
142+
def test_ensemble_model_inference_train_and_test_default(test_output_dirs: OutputFolderForTests) -> None:
143+
"""
144+
Test inference defaults with ModelProcessing.ENSEMBLE_CREATION.
145+
146+
:param test_output_dirs: Test output directories.
147+
:return: None.
148+
"""
149+
run_model_inference_train_and_test(test_output_dirs,
150+
True,
151+
model_proc=ModelProcessing.ENSEMBLE_CREATION)
118152

119153

120154
@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows")
121155
@pytest.mark.parametrize("ensemble_inference_on_set", [(True, False, False), (False, True, False), (False, False, True)])
122156
def test_ensemble_model_inference_train_and_test(test_output_dirs: OutputFolderForTests,
123157
ensemble_inference_on_set: Tuple[bool, bool, bool]) -> None:
158+
"""
159+
Test inference overrides with ModelProcessing.ENSEMBLE_CREATION.
160+
161+
:param test_output_dirs: Test output directories.
162+
163+
:param ensemble_inference_on_set: Overrides for inference on data sets.
164+
:return: None.
165+
"""
124166
(ensemble_inference_on_train_set, ensemble_inference_on_val_set, ensemble_inference_on_test_set) = ensemble_inference_on_set
125167
run_model_inference_train_and_test(test_output_dirs,
126168
True,
127-
False,
128-
False,
129-
False,
130-
ensemble_inference_on_train_set,
131-
ensemble_inference_on_val_set,
132-
ensemble_inference_on_test_set,
133-
ModelProcessing.ENSEMBLE_CREATION)
169+
ensemble_inference_on_train_set=ensemble_inference_on_train_set,
170+
ensemble_inference_on_val_set=ensemble_inference_on_val_set,
171+
ensemble_inference_on_test_set=ensemble_inference_on_test_set,
172+
model_proc=ModelProcessing.ENSEMBLE_CREATION)
134173

135174

136175
def run_model_inference_train_and_test(test_output_dirs: OutputFolderForTests,
137176
perform_cross_validation: bool,
138-
inference_on_train_set: bool,
139-
inference_on_val_set: bool,
140-
inference_on_test_set: bool,
141-
ensemble_inference_on_train_set: bool,
142-
ensemble_inference_on_val_set: bool,
143-
ensemble_inference_on_test_set: bool,
144-
model_proc: ModelProcessing) -> None:
177+
inference_on_train_set: Optional[bool] = None,
178+
inference_on_val_set: Optional[bool] = None,
179+
inference_on_test_set: Optional[bool] = None,
180+
ensemble_inference_on_train_set: Optional[bool] = None,
181+
ensemble_inference_on_val_set: Optional[bool] = None,
182+
ensemble_inference_on_test_set: Optional[bool] = None,
183+
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> None:
184+
"""
185+
Test running inference produces expected output metrics, files, folders and calls to upload_folder.
186+
187+
:param test_output_dirs: Test output directories.
188+
:param perform_cross_validation: Whether to test with cross validation.
189+
:param inference_on_train_set: Override for inference on train data sets.
190+
:param inference_on_val_set: Override for inference on validation data sets.
191+
:param inference_on_test_set: Override for inference on test data sets.
192+
:param ensemble_inference_on_train_set: Override for ensemble inference on train data sets.
193+
:param ensemble_inference_on_val_set: Override for ensemble inference on validation data sets.
194+
:param ensemble_inference_on_test_set: Override for ensemble inference on test data sets.
195+
:param model_proc: Model processing to test.
196+
:return: None.
197+
"""
145198
dummy_model = DummyModel()
146199

147200
config = PassThroughModel()
@@ -202,6 +255,20 @@ def run_model_inference_train_and_test(test_output_dirs: OutputFolderForTests,
202255
if mode in metrics:
203256
metric = metrics[mode]
204257
assert isinstance(metric, InferenceMetricsForSegmentation)
258+
259+
if flag is None:
260+
# No override supplied, calculate the expected default:
261+
if model_proc == ModelProcessing.DEFAULT:
262+
if not perform_cross_validation:
263+
# If a "normal" run then default to val or test.
264+
flag = mode in (ModelExecutionMode.VAL, ModelExecutionMode.TEST)
265+
else:
266+
# If an ensemble child then default to never.
267+
flag = False
268+
else:
269+
# If an ensemble then default to test only.
270+
flag = mode == ModelExecutionMode.TEST
271+
205272
if mode in metrics and not flag:
206273
error = error + f"Error: {mode.value} cannot be not None."
207274
elif mode not in metrics and flag:

0 commit comments

Comments
 (0)