Skip to content

Commit c30a890

Browse files
Add crop_pad_3d with tests (#679)
* add initial crop_pad_3d * crop 3d test * square crop 3d * fixed crop3d test * add padding * pad 3d tests * extra tests * fix comment * typo
1 parent c2831e2 commit c30a890

File tree

3 files changed

+104
-1
lines changed

3 files changed

+104
-1
lines changed

src/aspire/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .coor_trans import ( # isort:skip
22
common_line_from_rots,
33
crop_pad_2d,
4+
crop_pad_3d,
45
get_aligned_rotations,
56
get_rots_mse,
67
grid_1d,

src/aspire/utils/coor_trans.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,29 @@ def crop_pad_2d(im, size, fill_value=0):
333333
else:
334334
# target size is between mat_x and mat_y
335335
raise ValueError("Cannot crop and pad an image at the same time.")
336+
337+
338+
def crop_pad_3d(im, size, fill_value=0):
339+
im_y, im_x, im_z = im.shape
340+
# shift terms
341+
start_x = math.floor(im_x / 2) - math.floor(size / 2)
342+
start_y = math.floor(im_y / 2) - math.floor(size / 2)
343+
start_z = math.floor(im_z / 2) - math.floor(size / 2)
344+
345+
# cropping
346+
if size <= min(im_y, im_x, im_z):
347+
return im[
348+
start_y : start_y + size, start_x : start_x + size, start_z : start_z + size
349+
]
350+
# padding
351+
elif size >= max(im_y, im_x, im_z):
352+
to_return = fill_value * np.ones((size, size, size), dtype=im.dtype)
353+
to_return[
354+
-start_y : im_y - start_y,
355+
-start_x : im_x - start_x,
356+
-start_z : im_z - start_z,
357+
] = im
358+
return to_return
359+
else:
360+
# target size is between min and max of (im_y, im_x, im_z)
361+
raise ValueError("Cannot crop and pad a volume at the same time.")

tests/test_coor_trans.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from aspire.utils import (
77
Rotation,
88
crop_pad_2d,
9+
crop_pad_3d,
910
get_aligned_rotations,
1011
grid_2d,
1112
grid_3d,
@@ -115,6 +116,36 @@ def testSquareCrop2D(self):
115116
test_a = np.diag(np.arange(8))
116117
self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 8)))
117118

119+
def testSquareCrop3D(self):
120+
# even to even
121+
a = np.zeros((8, 8, 8))
122+
# pad it with the parts that will be cropped off from a 10x10x10
123+
a = np.pad(a, ((1, 1), (1, 1), (1, 1)), "constant", constant_values=1)
124+
# after cropping
125+
test_a = np.zeros((8, 8, 8))
126+
self.assertTrue(np.array_equal(crop_pad_3d(a, 8), test_a))
127+
128+
# even to odd
129+
a = np.zeros((7, 7, 7))
130+
# pad it with the parts that will be cropped off from a 10x10x10
131+
a = np.pad(a, ((2, 1), (2, 1), (2, 1)), "constant", constant_values=1)
132+
test_a = np.zeros((7, 7, 7))
133+
self.assertTrue(np.array_equal(crop_pad_3d(a, 7), test_a))
134+
135+
# odd to odd
136+
a = np.zeros((7, 7, 7))
137+
# pad it with the parts that will be cropped off from a 9x9x9
138+
a = np.pad(a, ((1, 1), (1, 1), (1, 1)), "constant", constant_values=1)
139+
test_a = np.zeros((7, 7, 7))
140+
self.assertTrue(np.array_equal(crop_pad_3d(a, 7), test_a))
141+
142+
# odd to even
143+
a = np.zeros((8, 8, 8))
144+
# pad it with the parts that will be cropped off from 11x11x11
145+
a = np.pad(a, ((1, 2), (1, 2), (1, 2)), "constant", constant_values=1)
146+
test_a = np.zeros((8, 8, 8))
147+
self.assertTrue(np.array_equal(crop_pad_3d(a, 8), test_a))
148+
118149
def testSquarePad2D(self):
119150
# Test even/odd cases based on the convention that the center of a sequence of length n
120151
# is (n+1)/2 if n is odd and n/2 + 1 if even.
@@ -151,6 +182,31 @@ def testSquarePad2D(self):
151182
test_a = np.diag([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
152183
self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 10)))
153184

185+
def testSquarePad3D(self):
186+
# even to even
187+
a = np.zeros((8, 8, 8))
188+
# after padding to 10x10x10
189+
test_a = np.pad(a, ((1, 1), (1, 1), (1, 1)), "constant", constant_values=1)
190+
self.assertTrue(np.array_equal(crop_pad_3d(a, 10, fill_value=1), test_a))
191+
192+
# even to odd
193+
a = np.zeros((8, 8, 8))
194+
# after padding to 11x11x11
195+
test_a = np.pad(a, ((1, 2), (1, 2), (1, 2)), "constant", constant_values=1)
196+
self.assertTrue(np.array_equal(crop_pad_3d(a, 11, fill_value=1), test_a))
197+
198+
# odd to odd
199+
a = np.zeros((7, 7, 7))
200+
# after padding to 9x9x9
201+
test_a = np.pad(a, ((1, 1), (1, 1), (1, 1)), "constant", constant_values=1)
202+
self.assertTrue(np.array_equal(crop_pad_3d(a, 9, fill_value=1), test_a))
203+
204+
# odd to even
205+
a = np.zeros((7, 7, 7))
206+
# after padding to 10x10x10
207+
test_a = np.pad(a, ((2, 1), (2, 1), (2, 1)), "constant", constant_values=1)
208+
self.assertTrue(np.array_equal(crop_pad_3d(a, 10, fill_value=1), test_a))
209+
154210
def testRectCrop2D(self):
155211
# Additional sanity checks for rectangular cropping case
156212

@@ -240,10 +296,17 @@ def testRectPad2D(self):
240296
def testCropPad2DError(self):
241297
with self.assertRaises(ValueError) as e:
242298
_ = crop_pad_2d(np.zeros((6, 10)), 8)
243-
self.assertTrue(
299+
self.assertEqual(
244300
"Cannot crop and pad an image at the same time.", str(e.exception)
245301
)
246302

303+
def testCropPad3DError(self):
304+
with self.assertRaises(ValueError) as e:
305+
_ = crop_pad_3d(np.zeros((6, 8, 10)), 8)
306+
self.assertEqual(
307+
"Cannot crop and pad a volume at the same time.", str(e.exception)
308+
)
309+
247310
def testCrop2DDtype(self):
248311
# crop_pad_2d must return an array of the same dtype it was given
249312
# in particular, because the method is used for Fourier downsampling
@@ -252,10 +315,23 @@ def testCrop2DDtype(self):
252315
crop_pad_2d(np.eye(10).astype("complex"), 5).dtype, np.dtype("complex128")
253316
)
254317

318+
def testCrop3DDtype(self):
319+
self.assertEqual(
320+
crop_pad_3d(np.ones((8, 8, 8)).astype("complex"), 5).dtype,
321+
np.dtype("complex128"),
322+
)
323+
255324
def testCrop2DFillValue(self):
256325
# make sure the fill value is as expected
257326
# we are padding from an odd to an even dimension
258327
# so the padded column is added to the left
259328
a = np.ones((4, 3))
260329
b = crop_pad_2d(a, 4, fill_value=-1)
261330
self.assertTrue(np.array_equal(b[:, 0], np.array([-1, -1, -1, -1])))
331+
332+
def testCrop3DFillValue(self):
333+
# make sure the fill value is expected. Since we are padding from odd to even
334+
# the padded side is added to the 0-end of dimension 3
335+
a = np.ones((4, 4, 3))
336+
b = crop_pad_3d(a, 4, fill_value=-1)
337+
self.assertTrue(np.array_equal(b[:, :, 0], -1 * np.ones((4, 4))))

0 commit comments

Comments
 (0)