3636
3737
3838@pytest .fixture
39- def image_rand_pos () -> Union [ torch . Tensor , np .ndarray ] :
39+ def image_rand_pos () -> np .ndarray :
4040 torch .random .manual_seed (1 )
4141 np .random .seed (0 )
4242 return (np .random .rand (3 , 4 , 4 , 4 ) * 1000.0 ).astype (ImageDataType .IMAGE .value )
4343
4444
4545@pytest .fixture
46- def image_rand_pos_gpu (image_rand_pos : Union [ torch . Tensor , np .ndarray ] ) -> Union [torch .Tensor , np .ndarray ]:
46+ def image_rand_pos_gpu (image_rand_pos : np .ndarray ) -> Union [torch .Tensor , np .ndarray ]:
4747 return torch .tensor (image_rand_pos ) if use_gpu else image_rand_pos
4848
4949
@@ -56,42 +56,50 @@ def assert_image_out_datatype(image_out: np.ndarray) -> None:
5656 "datatype that we force images to have."
5757
5858
59- def test_simplenorm_half (image_rand_pos : Union [ torch . Tensor , np .ndarray ] ) -> None :
59+ def test_simplenorm_half (image_rand_pos : np .ndarray ) -> None :
6060 image_out = photometric_normalization .simple_norm (image_rand_pos , mask_half , debug_mode = True )
6161 assert np .mean (image_out , dtype = np .float ) == approx (- 0.05052318 )
6262 for c in range (image_out .shape [0 ]):
6363 assert np .mean (image_out [c , mask_half > 0.5 ], dtype = np .float ) == approx (0 , abs = 1e-7 )
6464 assert_image_out_datatype (image_out )
6565
6666
67- def test_simplenorm_ones (image_rand_pos : Union [ torch . Tensor , np .ndarray ] ) -> None :
67+ def test_simplenorm_ones (image_rand_pos : np .ndarray ) -> None :
6868 image_out = photometric_normalization .simple_norm (image_rand_pos , mask_ones , debug_mode = True )
6969 assert np .mean (image_out ) == approx (0 , abs = 1e-7 )
7070 assert_image_out_datatype (image_out )
7171
7272
73- def test_mriwindowhalf (image_rand_pos : Union [torch .Tensor , np .ndarray ]) -> None :
74- image_out , status = photometric_normalization .mri_window (image_rand_pos , mask_half , (0 , 1 ), sharpen , tail )
73+ def test_3d_4d (image_rand_pos : np .ndarray ) -> None :
74+ normalization = photometric_normalization .PhotometricNormalization ()
75+ shape = image_rand_pos .shape
76+ spatial_shape = shape [1 :]
77+ assert normalization .transform (image_rand_pos ).shape == shape
78+ assert normalization .transform (image_rand_pos [0 ]).shape == spatial_shape
79+
80+
81+ def test_mriwindowhalf (image_rand_pos : np .ndarray ) -> None :
82+ image_out , _ = photometric_normalization .mri_window (image_rand_pos , mask_half , (0 , 1 ), sharpen , tail )
7583 assert np .mean (image_out ) == approx (0.2748852 )
7684 assert_image_out_datatype (image_out )
7785
7886
79- def test_mriwindowones (image_rand_pos : Union [ torch . Tensor , np .ndarray ] ) -> None :
80- image_out , status = photometric_normalization .mri_window (image_rand_pos , mask_ones , (0.0 , 1.0 ), sharpen , tail3 )
87+ def test_mriwindowones (image_rand_pos : np .ndarray ) -> None :
88+ image_out , _ = photometric_normalization .mri_window (image_rand_pos , mask_ones , (0.0 , 1.0 ), sharpen , tail3 )
8189 assert np .mean (image_out ) == approx (0.2748852 )
8290 assert_image_out_datatype (image_out )
8391
8492
85- def test_trimmed_norm_full (image_rand_pos : Union [ torch . Tensor , np .ndarray ] ) -> None :
86- image_out , status = photometric_normalization .normalize_trim (image_rand_pos , mask_ones ,
93+ def test_trimmed_norm_full (image_rand_pos : np .ndarray ) -> None :
94+ image_out , _ = photometric_normalization .normalize_trim (image_rand_pos , mask_ones ,
8795 output_range = (- 1 , 1 ), sharpen = 1 ,
8896 trim_percentiles = (1 , 99 ))
8997 assert np .mean (image_out , dtype = np .float ) == approx (- 0.08756259549409151 )
9098 assert_image_out_datatype (image_out )
9199
92100
93- def test_trimmed_norm_half (image_rand_pos : Union [ torch . Tensor , np .ndarray ] ) -> None :
94- image_out , status = photometric_normalization .normalize_trim (image_rand_pos , mask_half ,
101+ def test_trimmed_norm_half (image_rand_pos : np .ndarray ) -> None :
102+ image_out , _ = photometric_normalization .normalize_trim (image_rand_pos , mask_half ,
95103 output_range = (- 1 , 1 ), sharpen = 1 ,
96104 trim_percentiles = (1 , 99 ))
97105 assert np .mean (image_out , dtype = np .float ) == approx (- 0.4862089517215888 )
0 commit comments