Skip to content
64 changes: 53 additions & 11 deletions src/aspire/basis/basis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings

import numpy as np
from scipy.sparse.linalg import LinearOperator
Expand All @@ -19,7 +20,7 @@ class Coef:

_allowed_dtypes = (np.float32, np.float64)

def __init__(self, basis, data, dtype=None):
def __init__(self, basis, data, pixel_size=None, dtype=None):
"""
A stack of one or more coefficient arrays.

Expand All @@ -32,6 +33,8 @@ def __init__(self, basis, data, dtype=None):
:param basis: `Basis` associated with `data` coefficients.
:param data: Numpy array containing image data with shape
`(..., count)`.
:param pixel_size: Pixel size of underlying image data in
angstroms, default `None`.
:param dtype: Optionally cast `data` to this dtype.
Defaults to `data.dtype`.

Expand Down Expand Up @@ -62,6 +65,7 @@ def __init__(self, basis, data, dtype=None):
)
self.basis = basis

self.pixel_size = pixel_size
self._data = data.astype(self.dtype, copy=False)
self.ndim = self._data.ndim
self.shape = self._data.shape
Expand Down Expand Up @@ -127,9 +131,31 @@ def _check_key_dims(self, key):
f"Coef stack_dim is {self.stack_ndim}, slice length must be =< {self.ndim}"
)

def _check_pixel_size(self, other):
"""
Check pixel size. In the case of only one of self or other having a pixel
size, use that pixel size. If self and other do not have matching pixel size
emit a warning and use self.pixel_size.
"""
px_sz = self.pixel_size # default

if isinstance(other, Coef):
if self.pixel_size is None:
px_sz = other.pixel_size
elif other.pixel_size is not None and not np.isclose(
self.pixel_size, other.pixel_size
):
warnings.warn(
f"Pixel sizes do not match. Using pixel size {self.pixel_size}.",
UserWarning,
stacklevel=2,
)

return px_sz

def __getitem__(self, key):
self._check_key_dims(key)
return self.__class__(self.basis, self._data[key])
return self.__class__(self.basis, self._data[key], pixel_size=self.pixel_size)

def __setitem__(self, key, value):
self._check_key_dims(key)
Expand Down Expand Up @@ -158,14 +184,16 @@ def stack_reshape(self, *args):
)

return self.__class__(
self.basis, self._data.reshape(*shape, self._data.shape[-1])
self.basis,
self._data.reshape(*shape, self._data.shape[-1]),
pixel_size=self.pixel_size,
)

def copy(self):
"""
Return a new `Coef` instance with a deep copy of the data.
"""
return self.__class__(self.basis, self._data.copy())
return self.__class__(self.basis, self._data.copy(), pixel_size=self.pixel_size)

def evaluate(self):
"""
Expand Down Expand Up @@ -219,10 +247,12 @@ def __mul__(self, other):
:return: `Coef` instance.
"""

px_sz = self._check_pixel_size(other)

if isinstance(other, Coef):
other = other._data

return self.__class__(self.basis, self._data * other)
return self.__class__(self.basis, self._data * other, pixel_size=px_sz)

def __add__(self, other):
"""
Expand All @@ -233,10 +263,12 @@ def __add__(self, other):
:return: `Coef` instance.
"""

px_sz = self._check_pixel_size(other)

if isinstance(other, Coef):
other = other._data

return self.__class__(self.basis, self._data + other)
return self.__class__(self.basis, self._data + other, pixel_size=px_sz)

def __sub__(self, other):
"""
Expand All @@ -247,10 +279,12 @@ def __sub__(self, other):
:return: `Coef` instance.
"""

px_sz = self._check_pixel_size(other)

if isinstance(other, Coef):
other = other._data

return self.__class__(self.basis, self._data - other)
return self.__class__(self.basis, self._data - other, pixel_size=px_sz)

def __neg__(self):
"""
Expand All @@ -259,7 +293,7 @@ def __neg__(self):
:return: `Coef` instance.
"""

return self.__class__(self.basis, -self._data)
return self.__class__(self.basis, -self._data, pixel_size=self.pixel_size)

@property
def size(self):
Expand Down Expand Up @@ -444,6 +478,9 @@ def evaluate(self, v):
if not isinstance(v, Coef):
raise TypeError(f"`evaluate` should be passed a `Coef`, received {type(v)}")

# Store pixel_size for passthrough
px_sz = v.pixel_size

# Flatten stack
stack_shape = v.stack_shape
v = v.stack_reshape(-1).asnumpy()
Expand All @@ -454,7 +491,7 @@ def evaluate(self, v):
x = x.reshape(*stack_shape, *self.sz)

# Return the appropriate class
return self._cls(x)
return self._cls(x, pixel_size=px_sz)

def _evaluate(self, v):
raise NotImplementedError("subclasses must implement this")
Expand All @@ -480,7 +517,10 @@ def evaluate_t(self, v):
f"{self.__class__.__name__}::evaluate_t"
f" passed numpy array instead of {self._cls}."
)
px_sz = None
else:
# Store pixel_size for passthrough
px_sz = v.pixel_size
v = v.asnumpy()

# Flatten stack, ndim is wrt Basis (2 or 3)
Expand All @@ -491,7 +531,7 @@ def evaluate_t(self, v):
# Restore stack shape
x = x.reshape(*stack_shape, self.count)

return Coef(self, x)
return Coef(self, x, pixel_size=px_sz)

def _evaluate_t(self, v):
raise NotImplementedError("Subclasses should implement this")
Expand Down Expand Up @@ -553,7 +593,9 @@ def expand(self, x, tol=None, atol=0):

"""

px_sz = None
if isinstance(x, Image) or isinstance(x, Volume):
px_sz = x.pixel_size
x = x.asnumpy()

if x.dtype != self.dtype:
Expand Down Expand Up @@ -598,4 +640,4 @@ def expand(self, x, tol=None, atol=0):
# return v coefficients with the last dimension of self.count
v = v.reshape((*sz_roll, self.count))

return Coef(self, v)
return Coef(self, v, pixel_size=px_sz)
8 changes: 6 additions & 2 deletions src/aspire/basis/fle_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,8 @@ def lowpass(self, coefs, bandlimit):
f"`coefs` should be a `Coef` instance, received {type(coefs)}."
)

px_sz = coefs.pixel_size

# Copy to mutate the coefs.
coefs = coefs.asnumpy().copy()

Expand All @@ -699,7 +701,7 @@ def lowpass(self, coefs, bandlimit):
k = k - 1
coefs[:, k + 1 :] = 0

return Coef(self, coefs)
return Coef(self, coefs, pixel_size=px_sz)

def radial_convolve(self, coefs, radial_img):
"""
Expand All @@ -720,6 +722,8 @@ def radial_convolve(self, coefs, radial_img):
"`radial_convolve` currently only implemented for 1D stacks."
)

px_sz = coefs.pixel_size

# Potentially migrate to GPU
coefs = xp.asarray(coefs.asnumpy())
radial_img = xp.asarray(radial_img)
Expand All @@ -743,7 +747,7 @@ def radial_convolve(self, coefs, radial_img):
coefs_conv = coefs_conv[..., self._fle_to_fb_indices]

# Return as Coef on host
return Coef(self, xp.asnumpy(coefs_conv))
return Coef(self, xp.asnumpy(coefs_conv), pixel_size=px_sz)

def _radial_convolve_weights(self, b):
"""
Expand Down
9 changes: 5 additions & 4 deletions src/aspire/basis/fspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def expand(self, x):

assert c_fspca.shape == (x.shape[0], self.count)

return Coef(self, c_fspca)
return Coef(self, c_fspca, pixel_size=x.pixel_size)

def evaluate_to_image_basis(self, c):
"""
Expand Down Expand Up @@ -379,7 +379,7 @@ def evaluate(self, c):
# corrected_c[:, self.angular_indices!=0] *= 2
# return corrected_c @ eigvecs.T

return Coef(self.basis, c @ eigvecs.T)
return Coef(self.basis, c @ eigvecs.T, pixel_size=c.pixel_size)

# TODO: Python>=3.8 @cached_property
def _get_compressed_indices(self):
Expand Down Expand Up @@ -477,6 +477,7 @@ def to_complex(self, coef):
"""
if not isinstance(coef, Coef):
raise TypeError(f"'coef' should be `Coef` instance, received {type(coef)}.")
px_sz = coef.pixel_size
coef = coef.asnumpy()

if coef.dtype not in (np.float64, np.float32):
Expand Down Expand Up @@ -514,7 +515,7 @@ def to_complex(self, coef):
for i, k in enumerate(ccoef_d.keys()):
ccoef[:, i] = ccoef_d[k]

return ComplexCoef(self, ccoef)
return ComplexCoef(self, ccoef, pixel_size=px_sz)

def to_real(self, complex_coef):
"""
Expand Down Expand Up @@ -554,7 +555,7 @@ def to_real(self, complex_coef):
coef[:, pos_i] = 2.0 * complex_coef[:, i].real
coef[:, neg_i] = -2.0 * complex_coef[:, i].imag

return Coef(self, coef)
return Coef(self, coef, pixel_size=complex_coef.pixel_size)

def calculate_bispectrum(
self, coef, flatten=False, filter_nonzero_freqs=False, freq_cutoff=None
Expand Down
3 changes: 2 additions & 1 deletion src/aspire/basis/steerable.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def rotate(self, coef, radians, refl=None):
if not isinstance(coef, Coef):
raise TypeError(f"`coef` must be `Coef` instance, received {type(coef)}.")

px_sz = coef.pixel_size
coef = coef.asnumpy()

# Covert radians to a broadcastable shape
Expand Down Expand Up @@ -226,7 +227,7 @@ def rotate(self, coef, radians, refl=None):
ks_neg
) - coef_pos * np.sin(ks_pos)

return Coef(self, coef)
return Coef(self, coef, pixel_size=px_sz)

def complex_rotate(self, complex_coef, radians, refl=None):
"""
Expand Down
1 change: 1 addition & 0 deletions src/aspire/denoising/class_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
n=self.averager.src.n,
dtype=self.averager.src.dtype,
symmetry_group=self.src.symmetry_group,
pixel_size=self.src.pixel_size,
)

# Any further operations should not mutate this instance.
Expand Down
8 changes: 7 additions & 1 deletion src/aspire/denoising/denoised_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ def __init__(self, src, denoiser):
:param denoiser: A `Denoiser` object for specifying a method for denoising
"""

super().__init__(src.L, src.n, dtype=src.dtype, metadata=src._metadata.copy())
super().__init__(
src.L,
src.n,
dtype=src.dtype,
pixel_size=src.pixel_size,
metadata=src._metadata.copy(),
)
# TODO, we can probably setup a reasonable default here.
self.denoiser = denoiser
if not isinstance(denoiser, Denoiser):
Expand Down
2 changes: 1 addition & 1 deletion src/aspire/reconstruction/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def estimate(self, b_coef=None, x0=None, tol=1e-5, regularizer=0):
if b_coef is None:
b_coef = self.src_backward()
est_coef = self.conj_grad(b_coef, x0=x0, tol=tol, regularizer=regularizer)
est = Coef(self.basis, est_coef).evaluate()
est = Coef(self.basis, est_coef, pixel_size=self.src.pixel_size).evaluate()

return est

Expand Down
9 changes: 8 additions & 1 deletion src/aspire/source/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,21 @@ def __init__(
)
symmetry_group = symmetry_group or self.vols.symmetry_group

if pixel_size and (pixel_size != self.vols.pixel_size):
logger.warning(
f"Overriding volume pixel size, {self.vols.pixel_size}, with "
f"user provided pixel size of {pixel_size} angstrom."
)
pixel_size = pixel_size or self.vols.pixel_size

# Infer the details from volume when possible.
super().__init__(
L=self.vols.resolution,
n=n,
dtype=self.vols.dtype,
memory=memory,
symmetry_group=symmetry_group,
pixel_size=self.vols.pixel_size,
pixel_size=pixel_size,
)

# If a user provides both `L` and `vols`, resolution should match.
Expand Down
7 changes: 6 additions & 1 deletion tests/_basis_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,16 @@ def testExpand(self, basis):
_class = self.getClass(basis)
# expand should take an Image/Volume and return a NumPy array of type
# basis.coefficient_dtype
px_sz = 1.234
result = basis.expand(
_class(np.zeros((basis.nres,) * basis.ndim, dtype=basis.dtype))
_class(
np.zeros((basis.nres,) * basis.ndim, dtype=basis.dtype),
pixel_size=px_sz,
)
)
assert isinstance(result, Coef)
assert result.dtype == basis.coefficient_dtype
np.testing.assert_array_equal(result.pixel_size, px_sz)

def testInitWithIntSize(self, basis):
# make sure we can instantiate with just an int as a shortcut
Expand Down
6 changes: 5 additions & 1 deletion tests/test_FFBbasis2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def testShift(self, basis):
v = Volume(
np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(
basis.dtype
)
),
pixel_size=1.234,
).downsample(basis.nres)

src = Simulation(L=basis.nres, n=n_img, vols=v, dtype=basis.dtype)
Expand Down Expand Up @@ -123,6 +124,9 @@ def testShift(self, basis):
logger.info(f"RMSE shifted image diffs {rmse}")
assert np.allclose(rmse, 0, atol=1e-5)

# Check pixel_size passthrough
np.testing.assert_array_equal(f_imgs.pixel_size, f_shifted_imgs.pixel_size)


params = [pytest.param(512, np.float32, marks=pytest.mark.expensive)]

Expand Down
4 changes: 4 additions & 0 deletions tests/test_class_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def class_sim_fixture(dtype, img_size):
C=1,
angles=true_rots.angles,
symmetry_group="C4", # For testing symmetry_group pass-through.
pixel_size=1.234, # For testing pixel_size pass-through
)
# Prefetch all the images
src = src.cache()
Expand Down Expand Up @@ -207,6 +208,9 @@ class averages.
# Check symmetry_group pass-through.
assert test_src.symmetry_group == class_sim_fixture.symmetry_group

# Check pixel_size pass-through.
np.testing.assert_array_equal(test_src.pixel_size, class_sim_fixture.pixel_size)


# Test the _HeapItem helper class
def test_heap_helper():
Expand Down
Loading
Loading