5 | 5 | import logging |
6 | 6 | import time |
7 | 7 | from pathlib import Path |
8 | | -from typing import Tuple |
| 8 | +from typing import Optional, Tuple |
9 | 9 | from unittest import mock |
10 | 10 | from unittest.mock import Mock |
11 | 11 |
@@ -99,49 +99,102 @@ def create_train_and_test_data_small_dataset(image_size: TupleInt3, |
99 | 99 | return target_dir |
100 | 100 |
101 | 101 |
| 102 | +@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows") |
| 103 | +@pytest.mark.parametrize("perform_cross_validation", [True, False]) |
| 104 | +def test_model_inference_train_and_test_default(test_output_dirs: OutputFolderForTests, |
| 105 | + perform_cross_validation: bool) -> None: |
| 106 | + """ |
| 107 | + Test inference defaults with ModelProcessing.DEFAULT. |
| 108 | +
| 109 | + :param test_output_dirs: Test output directories. |
| 110 | + :param perform_cross_validation: Whether to test with cross validation. |
| 111 | + :return: None. |
| 112 | + """ |
| 113 | + run_model_inference_train_and_test(test_output_dirs, |
| 114 | + perform_cross_validation, |
| 115 | + model_proc=ModelProcessing.DEFAULT) |
| 116 | + |
| 117 | + |
102 | 118 | @pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows") |
103 | 119 | @pytest.mark.parametrize("perform_cross_validation", [True, False]) |
104 | 120 | @pytest.mark.parametrize("inference_on_set", [(True, False, False), (False, True, False), (False, False, True)]) |
105 | 121 | def test_model_inference_train_and_test(test_output_dirs: OutputFolderForTests, |
106 | 122 | perform_cross_validation: bool, |
107 | 123 | inference_on_set: Tuple[bool, bool, bool]) -> None: |
| 124 | + """ |
| 125 | + Test inference overrides with ModelProcessing.DEFAULT. |
| 126 | +
| 127 | + :param test_output_dirs: Test output directories. |
| 128 | + :param perform_cross_validation: Whether to test with cross validation. |
| 129 | +    :param inference_on_set: Overrides for inference on the train, validation, and test sets.
| 130 | + :return: None. |
| 131 | + """ |
108 | 132 | (inference_on_train_set, inference_on_val_set, inference_on_test_set) = inference_on_set |
109 | 133 | run_model_inference_train_and_test(test_output_dirs, |
110 | 134 | perform_cross_validation, |
111 | | - inference_on_train_set, |
112 | | - inference_on_val_set, |
113 | | - inference_on_test_set, |
114 | | - False, |
115 | | - False, |
116 | | - False, |
117 | | - ModelProcessing.DEFAULT) |
| 135 | + inference_on_train_set=inference_on_train_set, |
| 136 | + inference_on_val_set=inference_on_val_set, |
| 137 | + inference_on_test_set=inference_on_test_set, |
| 138 | + model_proc=ModelProcessing.DEFAULT) |
| 139 | + |
| 140 | + |
| 141 | +@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows") |
| 142 | +def test_ensemble_model_inference_train_and_test_default(test_output_dirs: OutputFolderForTests) -> None: |
| 143 | + """ |
| 144 | + Test inference defaults with ModelProcessing.ENSEMBLE_CREATION. |
| 145 | +
| 146 | + :param test_output_dirs: Test output directories. |
| 147 | + :return: None. |
| 148 | + """ |
| 149 | + run_model_inference_train_and_test(test_output_dirs, |
| 150 | + True, |
| 151 | + model_proc=ModelProcessing.ENSEMBLE_CREATION) |
118 | 152 |
119 | 153 |
120 | 154 | @pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows") |
121 | 155 | @pytest.mark.parametrize("ensemble_inference_on_set", [(True, False, False), (False, True, False), (False, False, True)]) |
122 | 156 | def test_ensemble_model_inference_train_and_test(test_output_dirs: OutputFolderForTests, |
123 | 157 | ensemble_inference_on_set: Tuple[bool, bool, bool]) -> None: |
| 158 | + """ |
| 159 | + Test inference overrides with ModelProcessing.ENSEMBLE_CREATION. |
| 160 | +
| 161 | + :param test_output_dirs: Test output directories. |
| 162 | +    :param ensemble_inference_on_set: Overrides for ensemble inference on the train,
| 163 | +        validation, and test sets.
| 164 | + :return: None. |
| 165 | + """ |
124 | 166 | (ensemble_inference_on_train_set, ensemble_inference_on_val_set, ensemble_inference_on_test_set) = ensemble_inference_on_set |
125 | 167 | run_model_inference_train_and_test(test_output_dirs, |
126 | 168 | True, |
127 | | - False, |
128 | | - False, |
129 | | - False, |
130 | | - ensemble_inference_on_train_set, |
131 | | - ensemble_inference_on_val_set, |
132 | | - ensemble_inference_on_test_set, |
133 | | - ModelProcessing.ENSEMBLE_CREATION) |
| 169 | + ensemble_inference_on_train_set=ensemble_inference_on_train_set, |
| 170 | + ensemble_inference_on_val_set=ensemble_inference_on_val_set, |
| 171 | + ensemble_inference_on_test_set=ensemble_inference_on_test_set, |
| 172 | + model_proc=ModelProcessing.ENSEMBLE_CREATION) |
134 | 173 |
135 | 174 |
136 | 175 | def run_model_inference_train_and_test(test_output_dirs: OutputFolderForTests, |
137 | 176 | perform_cross_validation: bool, |
138 | | - inference_on_train_set: bool, |
139 | | - inference_on_val_set: bool, |
140 | | - inference_on_test_set: bool, |
141 | | - ensemble_inference_on_train_set: bool, |
142 | | - ensemble_inference_on_val_set: bool, |
143 | | - ensemble_inference_on_test_set: bool, |
144 | | - model_proc: ModelProcessing) -> None: |
| 177 | + inference_on_train_set: Optional[bool] = None, |
| 178 | + inference_on_val_set: Optional[bool] = None, |
| 179 | + inference_on_test_set: Optional[bool] = None, |
| 180 | + ensemble_inference_on_train_set: Optional[bool] = None, |
| 181 | + ensemble_inference_on_val_set: Optional[bool] = None, |
| 182 | + ensemble_inference_on_test_set: Optional[bool] = None, |
| 183 | + model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> None: |
| 184 | + """ |
| 185 | +    Test that running inference produces the expected output metrics, files, folders, and calls to upload_folder.
| 186 | +
| 187 | + :param test_output_dirs: Test output directories. |
| 188 | + :param perform_cross_validation: Whether to test with cross validation. |
| 189 | +    :param inference_on_train_set: Override for inference on the train data set.
| 190 | +    :param inference_on_val_set: Override for inference on the validation data set.
| 191 | +    :param inference_on_test_set: Override for inference on the test data set.
| 192 | +    :param ensemble_inference_on_train_set: Override for ensemble inference on the train data set.
| 193 | +    :param ensemble_inference_on_val_set: Override for ensemble inference on the validation data set.
| 194 | +    :param ensemble_inference_on_test_set: Override for ensemble inference on the test data set.
| 195 | + :param model_proc: Model processing to test. |
| 196 | + :return: None. |
| 197 | + """ |
145 | 198 | dummy_model = DummyModel() |
146 | 199 |
147 | 200 | config = PassThroughModel() |
@@ -202,6 +255,20 @@ def run_model_inference_train_and_test(test_output_dirs: OutputFolderForTests, |
202 | 255 | if mode in metrics: |
203 | 256 | metric = metrics[mode] |
204 | 257 | assert isinstance(metric, InferenceMetricsForSegmentation) |
| 258 | + |
| 259 | + if flag is None: |
| 260 | + # No override supplied, calculate the expected default: |
| 261 | + if model_proc == ModelProcessing.DEFAULT: |
| 262 | + if not perform_cross_validation: |
| 263 | + # If a "normal" run then default to val or test. |
| 264 | + flag = mode in (ModelExecutionMode.VAL, ModelExecutionMode.TEST) |
| 265 | + else: |
| 266 | + # If an ensemble child then default to never. |
| 267 | + flag = False |
| 268 | + else: |
| 269 | + # If an ensemble then default to test only. |
| 270 | + flag = mode == ModelExecutionMode.TEST |
| 271 | + |
205 | 272 | if mode in metrics and not flag: |
206 | 273 | error = error + f"Error: {mode.value} cannot be not None." |
207 | 274 | elif mode not in metrics and flag: |