Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit a9fd52c

Browse files
authored
BUG: Fix missing channels dimension in normalization (#701)
* Fix missing channels dimension in normalization * Update CHANGELOG * Add test for 3D and 4D input images * Move conversion to NumPy array
1 parent edc72ed commit a9fd52c

File tree

4 files changed

+33
-17
lines changed

4 files changed

+33
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ gets uploaded to AzureML, by skipping all test folders.
9898

9999
#### Fixed
100100

101+
- ([#701](https://github.com/microsoft/InnerEye-DeepLearning/pull/701)) Fix 3D images expected to be 4D for intensity normalization.
101102
- ([#704](https://github.com/microsoft/InnerEye-DeepLearning/pull/704)) Add submodules to sys.path to fix autodoc's warning.
102103
- ([#699](https://github.com/microsoft/InnerEye-DeepLearning/pull/699)) Fix Sphinx warnings.
103104
- ([#682](https://github.com/microsoft/InnerEye-DeepLearning/pull/682)) Ensure the shape of input patches is compatible with model constraints.

InnerEye/ML/photometric_normalization.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def transform(self, image: Union[np.ndarray, torch.Tensor],
8484
else:
8585
mask = np.ones_like(image)
8686

87+
is3d = image.ndim == 3
88+
if is3d:
89+
image = image[np.newaxis]
90+
8791
self.status_of_most_recent_call = None
8892
if self.norm_method == PhotometricNormalizationMethod.Unchanged:
8993
image_out = image
@@ -116,7 +120,10 @@ def transform(self, image: Union[np.ndarray, torch.Tensor],
116120
raise ValueError("Unknown normalization method {}".format(self.norm_method))
117121
if patient_id is not None and self.status_of_most_recent_call is not None:
118122
logging.debug(f"Photonorm patient {patient_id}: {self.status_of_most_recent_call}")
119-
check_array_range(image_out, error_prefix="Normalized image")
123+
check_array_range(np.asarray(image_out), error_prefix="Normalized image")
124+
125+
if is3d:
126+
image_out = image_out[0]
120127

121128
return image_out
122129

InnerEye/ML/utils/image_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,16 +389,16 @@ def get_center_crop(image: NumpyOrTorch, crop_shape: TupleInt3) -> NumpyOrTorch:
389389

390390

391391
def check_array_range(data: np.ndarray, expected_range: Optional[Range] = None,
392-
error_prefix: str = None) -> None:
392+
error_prefix: Optional[str] = None) -> None:
393393
"""
394394
Checks if all values in the given array fall into the expected range. If not, raises a
395-
ValueError, and prints out statistics about the values that fell outside the expected range.
395+
``ValueError``, and prints out statistics about the values that fell outside the expected range.
396396
If no range is provided, it checks that all values in the array are finite (that is, they are not
397-
infinity and not np.nan
397+
infinity and not ``np.nan``).
398398
399399
:param data: The array to check. It can have any size.
400400
:param expected_range: The interval that all array elements must fall into. The first entry is the lower
401-
bound, the second entry is the upper bound.
401+
bound, the second entry is the upper bound.
402402
:param error_prefix: A string to use as the prefix for the error message.
403403
"""
404404
if expected_range is None:

Tests/ML/test_normalize.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636

3737

3838
@pytest.fixture
39-
def image_rand_pos() -> Union[torch.Tensor, np.ndarray]:
39+
def image_rand_pos() -> np.ndarray:
4040
torch.random.manual_seed(1)
4141
np.random.seed(0)
4242
return (np.random.rand(3, 4, 4, 4) * 1000.0).astype(ImageDataType.IMAGE.value)
4343

4444

4545
@pytest.fixture
46-
def image_rand_pos_gpu(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
46+
def image_rand_pos_gpu(image_rand_pos: np.ndarray) -> Union[torch.Tensor, np.ndarray]:
4747
return torch.tensor(image_rand_pos) if use_gpu else image_rand_pos
4848

4949

@@ -56,42 +56,50 @@ def assert_image_out_datatype(image_out: np.ndarray) -> None:
5656
"datatype that we force images to have."
5757

5858

59-
def test_simplenorm_half(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
59+
def test_simplenorm_half(image_rand_pos: np.ndarray) -> None:
6060
image_out = photometric_normalization.simple_norm(image_rand_pos, mask_half, debug_mode=True)
6161
assert np.mean(image_out, dtype=np.float) == approx(-0.05052318)
6262
for c in range(image_out.shape[0]):
6363
assert np.mean(image_out[c, mask_half > 0.5], dtype=np.float) == approx(0, abs=1e-7)
6464
assert_image_out_datatype(image_out)
6565

6666

67-
def test_simplenorm_ones(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
67+
def test_simplenorm_ones(image_rand_pos: np.ndarray) -> None:
6868
image_out = photometric_normalization.simple_norm(image_rand_pos, mask_ones, debug_mode=True)
6969
assert np.mean(image_out) == approx(0, abs=1e-7)
7070
assert_image_out_datatype(image_out)
7171

7272

73-
def test_mriwindowhalf(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
74-
image_out, status = photometric_normalization.mri_window(image_rand_pos, mask_half, (0, 1), sharpen, tail)
73+
def test_3d_4d(image_rand_pos: np.ndarray) -> None:
74+
normalization = photometric_normalization.PhotometricNormalization()
75+
shape = image_rand_pos.shape
76+
spatial_shape = shape[1:]
77+
assert normalization.transform(image_rand_pos).shape == shape
78+
assert normalization.transform(image_rand_pos[0]).shape == spatial_shape
79+
80+
81+
def test_mriwindowhalf(image_rand_pos: np.ndarray) -> None:
82+
image_out, _ = photometric_normalization.mri_window(image_rand_pos, mask_half, (0, 1), sharpen, tail)
7583
assert np.mean(image_out) == approx(0.2748852)
7684
assert_image_out_datatype(image_out)
7785

7886

79-
def test_mriwindowones(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
80-
image_out, status = photometric_normalization.mri_window(image_rand_pos, mask_ones, (0.0, 1.0), sharpen, tail3)
87+
def test_mriwindowones(image_rand_pos: np.ndarray) -> None:
88+
image_out, _ = photometric_normalization.mri_window(image_rand_pos, mask_ones, (0.0, 1.0), sharpen, tail3)
8189
assert np.mean(image_out) == approx(0.2748852)
8290
assert_image_out_datatype(image_out)
8391

8492

85-
def test_trimmed_norm_full(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
86-
image_out, status = photometric_normalization.normalize_trim(image_rand_pos, mask_ones,
93+
def test_trimmed_norm_full(image_rand_pos: np.ndarray) -> None:
94+
image_out, _ = photometric_normalization.normalize_trim(image_rand_pos, mask_ones,
8795
output_range=(-1, 1), sharpen=1,
8896
trim_percentiles=(1, 99))
8997
assert np.mean(image_out, dtype=np.float) == approx(-0.08756259549409151)
9098
assert_image_out_datatype(image_out)
9199

92100

93-
def test_trimmed_norm_half(image_rand_pos: Union[torch.Tensor, np.ndarray]) -> None:
94-
image_out, status = photometric_normalization.normalize_trim(image_rand_pos, mask_half,
101+
def test_trimmed_norm_half(image_rand_pos: np.ndarray) -> None:
102+
image_out, _ = photometric_normalization.normalize_trim(image_rand_pos, mask_half,
95103
output_range=(-1, 1), sharpen=1,
96104
trim_percentiles=(1, 99))
97105
assert np.mean(image_out, dtype=np.float) == approx(-0.4862089517215888)

0 commit comments

Comments
 (0)