Skip to content

Commit 05fce78

Browse files
authored
Merge pull request #449 from ComputationalCryoEM/aspire_image_source_403
ArrayImageSource admit numpy
2 parents b893bef + 6cf86ac commit 05fce78

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

src/aspire/source/image.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,11 +763,22 @@ class ArrayImageSource(ImageSource):
763763
def __init__(self, im, metadata=None, angles=None):
764764
"""
765765
Initialize from an `Image` object
766-
:param im: An `Image` object representing image data served up by this `ImageSource`
766+
:param im: An `Image` or Numpy array object representing image data served up by this `ImageSource`.
767+
In the case of a Numpy array, attempts to create an 'Image' object.
767768
:param metadata: A Dataframe of metadata information corresponding to this ImageSource's images
768769
:param angles: Optional n-by-3 array of rotation angles corresponding to `im`.
769770
"""
770771

772+
if not isinstance(im, Image):
773+
logger.info("Attempting to create an Image object from Numpy array.")
774+
try:
775+
im = Image(im)
776+
except Exception as e:
777+
raise RuntimeError(
778+
"Creating Image object from Numpy array failed."
779+
f" Original error: {str(e)}"
780+
)
781+
771782
super().__init__(
772783
L=im.res, n=im.n_images, dtype=im.dtype, metadata=metadata, memory=None
773784
)

tests/test_adaptive_support.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pytest
77

88
from aspire.denoising import adaptive_support
9-
from aspire.image import Image
109
from aspire.source import ArrayImageSource
1110
from aspire.utils import gaussian_2d
1211

@@ -37,7 +36,7 @@ def testAdaptiveSupportBadThreshold(self):
3736
"""
3837

3938
discs = np.empty((self.size, self.size)) # Intentional Dummy Data
40-
img_src = ArrayImageSource(Image(discs))
39+
img_src = ArrayImageSource(discs)
4140

4241
with pytest.raises(ValueError, match=r"Given energy_threshold.*"):
4342
_ = adaptive_support(img_src, -0.5)
@@ -69,7 +68,7 @@ def test_adaptive_support_F(self):
6968
)
7069

7170
# Setup ImageSource like objects
72-
img_src = ArrayImageSource(Image(imgs))
71+
img_src = ArrayImageSource(imgs)
7372

7473
for ref, threshold in self.references.items():
7574
c, R = adaptive_support(img_src, threshold)

tests/test_array_image_source.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,32 @@ def testArrayImageSource(self):
5555
im = src.images(start=0, num=np.inf) # returns Image instance
5656
self.assertTrue(np.allclose(im.asnumpy(), self.ims_np))
5757

58+
def testArrayImageSourceFromNumpy(self):
59+
"""
60+
An Array can be wrapped in an ArrayImageSource when we need to deal with ImageSource objects.
61+
62+
This checks round trip conversion does not crash and returns identity.
63+
"""
64+
65+
# Create an ArrayImageSource directly from Numpy array
66+
src = ArrayImageSource(self.ims_np)
67+
68+
# Ask the Source for all images in the stack as a Numpy array
69+
ims_np = src.images(start=0, num=np.inf).asnumpy()
70+
71+
# Comparison should be yield identity
72+
self.assertTrue(np.allclose(ims_np, self.ims_np))
73+
74+
def testArrayImageSourceNumpyError(self):
75+
"""
76+
Test that ArrayImageSource when instantiated with incorrect input
77+
gives appropriate error.
78+
"""
79+
80+
# Test we raise with expected message from getter.
81+
with raises(RuntimeError, match=r"Creating Image object from Numpy.*"):
82+
_ = ArrayImageSource(np.empty((3, 2, 1)))
83+
5884
def testArrayImageSourceAngGetterError(self):
5985
"""
6086
Test that ArrayImageSource when instantiated without required

0 commit comments

Comments
 (0)