Skip to content

Commit 8ada1e9

Browse files
authored
Merge pull request #583 from ComputationalCryoEM/deprecate_ensure
Deprecate ensure
2 parents b401824 + 0679e86 commit 8ada1e9

File tree

23 files changed

+145
-209
lines changed

23 files changed

+145
-209
lines changed

src/aspire/abinitio/commonline_sync.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from aspire.abinitio import CLOrient3D
6-
from aspire.utils import Rotation, ensure
6+
from aspire.utils import Rotation
77
from aspire.utils.matlab_compat import stable_eigsh
88

99
logger = logging.getLogger(__name__)
@@ -53,8 +53,8 @@ def estimate_rotations(self):
5353

5454
S = self.syncmatrix
5555
sz = S.shape
56-
ensure(sz[0] == sz[1], "syncmatrix must be a square matrix.")
57-
ensure(sz[0] % 2 == 0, "syncmatrix must be a square matrix of size 2Kx2K.")
56+
assert sz[0] == sz[1], "syncmatrix must be a square matrix."
57+
assert sz[0] % 2 == 0, "syncmatrix must be a square matrix of size 2Kx2K."
5858

5959
n_img = sz[0] // 2
6060

@@ -153,7 +153,7 @@ def syncmatrix_vote(self):
153153
sz = clmatrix.shape
154154
n_theta = self.n_theta
155155

156-
ensure(sz[0] == sz[1], "clmatrix must be a square matrix.")
156+
assert sz[0] == sz[1], "clmatrix must be a square matrix."
157157

158158
n_img = sz[0]
159159
S = np.eye(2 * n_img, dtype=self.dtype).reshape(n_img, 2, n_img, 2)

src/aspire/basis/basis.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from aspire.basis.basis_utils import num_besselj_zeros
77
from aspire.image import Image
8-
from aspire.utils import ensure, mdim_mat_fun_conj
8+
from aspire.utils import mdim_mat_fun_conj
99
from aspire.utils.matlab_compat import m_reshape
1010
from aspire.volume import Volume
1111

@@ -185,10 +185,9 @@ def expand(self, x):
185185

186186
x = x.reshape((-1, *self.sz))
187187

188-
ensure(
189-
x.shape[-self.ndim :] == self.sz,
190-
f"Last {self.ndim} dimensions of x must match {self.sz}.",
191-
)
188+
assert (
189+
x.shape[-self.ndim :] == self.sz
190+
), f"Last {self.ndim} dimensions of x must match {self.sz}."
192191

193192
operator = LinearOperator(
194193
shape=(self.count, self.count),

src/aspire/basis/basis_utils.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from numpy.polynomial.legendre import leggauss
1111
from scipy.special import jn, jv, sph_harm
1212

13-
from aspire.utils import ensure
1413
from aspire.utils.coor_trans import grid_2d, grid_3d
1514

1615
logger = logging.getLogger(__name__)
@@ -158,8 +157,8 @@ def real_sph_harmonic(j, m, theta, phi):
158157

159158

160159
def besselj_zeros(nu, k):
161-
ensure(k >= 3, "k must be >= 3")
162-
ensure(0 <= nu <= 1e7, "nu must be between 0 and 1e7")
160+
assert k >= 3, "k must be >= 3"
161+
assert 0 <= nu <= 1e7, "nu must be between 0 and 1e7"
163162

164163
z = np.zeros(k)
165164

@@ -198,10 +197,9 @@ def besselj_zeros(nu, k):
198197
z[n : n + j] = besselj_newton(nu, z0)
199198

200199
# Check to see that the sequence of zeros makes sense
201-
ensure(
202-
check_besselj_zeros(nu, z[n - 2 : n + j]),
203-
"Unable to properly estimate Bessel function zeros.",
204-
)
200+
assert check_besselj_zeros(
201+
nu, z[n - 2 : n + j]
202+
), "Unable to properly estimate Bessel function zeros."
205203

206204
# Check how far off we are
207205
err = (z[n : n + j] - z0) / np.diff(z[n - 1 : n + j])
@@ -236,10 +234,11 @@ def unique_coords_nd(N, ndim, shifted=False, normalized=True, dtype=np.float32):
236234
:param normalized: normalize the grid or not.
237235
:return: The unique polar coordinates in 2D or 3D
238236
"""
239-
ensure(
240-
ndim in (2, 3), "Only two- or three-dimensional basis functions are supported."
241-
)
242-
ensure(N > 0, "Number of grid points should be greater than 0.")
237+
assert ndim in (
238+
2,
239+
3,
240+
), "Only two- or three-dimensional basis functions are supported."
241+
assert N > 0, "Number of grid points should be greater than 0."
243242

244243
if ndim == 2:
245244
grid = grid_2d(

src/aspire/basis/fb_2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from aspire.basis import SteerableBasis2D
77
from aspire.basis.basis_utils import unique_coords_nd
88
from aspire.image import Image
9-
from aspire.utils import complex_type, ensure, real_type, roll_dim, unroll_dim
9+
from aspire.utils import complex_type, real_type, roll_dim, unroll_dim
1010
from aspire.utils.matlab_compat import m_flatten, m_reshape
1111

1212
logger = logging.getLogger(__name__)
@@ -37,8 +37,8 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
3737
"""
3838

3939
ndim = len(size)
40-
ensure(ndim == 2, "Only two-dimensional basis functions are supported.")
41-
ensure(len(set(size)) == 1, "Only square domains are supported.")
40+
assert ndim == 2, "Only two-dimensional basis functions are supported."
41+
assert len(set(size)) == 1, "Only square domains are supported."
4242
super().__init__(size, ell_max, dtype=dtype)
4343

4444
def _build(self):

src/aspire/basis/fb_3d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from aspire.basis import Basis
66
from aspire.basis.basis_utils import real_sph_harmonic, sph_bessel, unique_coords_nd
7-
from aspire.utils import ensure, roll_dim, unroll_dim
7+
from aspire.utils import roll_dim, unroll_dim
88
from aspire.utils.matlab_compat import m_flatten, m_reshape
99

1010
logger = logging.getLogger(__name__)
@@ -30,8 +30,8 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
3030
below the Nyquist frequency (default Inf).
3131
"""
3232
ndim = len(size)
33-
ensure(ndim == 3, "Only three-dimensional basis functions are supported.")
34-
ensure(len(set(size)) == 1, "Only cubic domains are supported.")
33+
assert ndim == 3, "Only three-dimensional basis functions are supported."
34+
assert len(set(size)) == 1, "Only cubic domains are supported."
3535

3636
super().__init__(size, ell_max, dtype=dtype)
3737

src/aspire/basis/polar_2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aspire.basis import Basis
66
from aspire.image import Image
77
from aspire.nufft import anufft, nufft
8-
from aspire.utils import ensure, real_type
8+
from aspire.utils import real_type
99

1010
logger = logging.getLogger(__name__)
1111

@@ -26,8 +26,8 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32):
2626
"""
2727

2828
ndim = len(size)
29-
ensure(ndim == 2, "Only two-dimensional grids are supported.")
30-
ensure(len(set(size)) == 1, "Only square domains are supported.")
29+
assert ndim == 2, "Only two-dimensional grids are supported."
30+
assert len(set(size)) == 1, "Only square domains are supported."
3131

3232
self.nrad = nrad
3333
self.ntheta = ntheta

src/aspire/covariance/covar.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from aspire.operators import evaluate_src_filters_on_grid
1515
from aspire.reconstruction import Estimator, FourierKernel, MeanEstimator
1616
from aspire.utils import (
17-
ensure,
1817
make_symmat,
1918
symmat_to_vec_iso,
2019
vec_to_symmat_iso,
@@ -197,10 +196,12 @@ def _shrink(self, covar_b_coeff, noise_variance, method=None):
197196
:param method: One of None/'frobenius_norm'/'operator_norm'/'soft_threshold'
198197
:return: Shrunk covariance matrix
199198
"""
200-
ensure(
201-
method in (None, "frobenius_norm", "operator_norm", "soft_threshold"),
202-
"Unsupported shrink method",
203-
)
199+
assert method in (
200+
None,
201+
"frobenius_norm",
202+
"operator_norm",
203+
"soft_threshold",
204+
), "Unsupported shrink method"
204205

205206
An = self.basis.mat_evaluate_t(self.mean_kernel.toeplitz())
206207
if method is None:

src/aspire/covariance/covar2d.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from aspire.operators import BlkDiagMatrix, RadialCTFFilter
88
from aspire.optimization import conj_grad, fill_struct
9-
from aspire.utils import ensure, make_symmat
9+
from aspire.utils import make_symmat
1010
from aspire.utils.matlab_compat import m_reshape
1111

1212
logger = logging.getLogger(__name__)
@@ -22,10 +22,11 @@ def shrink_covar(covar, noise_var, gamma, shrinker="frobenius_norm"):
2222
:return: The shrinked covariance matrix
2323
"""
2424

25-
ensure(
26-
shrinker in ("frobenius_norm", "operator_norm", "soft_threshold"),
27-
"Unsupported shrink method",
28-
)
25+
assert shrinker in (
26+
"frobenius_norm",
27+
"operator_norm",
28+
"soft_threshold",
29+
), "Unsupported shrink method"
2930

3031
lambs, eig_vec = eig(make_symmat(covar))
3132

@@ -96,7 +97,7 @@ def __init__(self, basis):
9697
"""
9798
self.basis = basis
9899
self.dtype = self.basis.dtype
99-
ensure(basis.ndim == 2, "Only two-dimensional basis functions are needed.")
100+
assert basis.ndim == 2, "Only two-dimensional basis functions are needed."
100101

101102
def _get_mean(self, coeffs):
102103
"""
@@ -327,7 +328,7 @@ def identity(x):
327328

328329
def precond_fun(S, x):
329330
p = np.size(S, 0)
330-
ensure(np.size(x) == p * p, "The sizes of S and x are not consistent.")
331+
assert np.size(x) == p * p, "The sizes of S and x are not consistent."
331332
x = m_reshape(x, (p, p))
332333
y = S @ x @ S
333334
y = m_reshape(y, (p**2,))
@@ -632,7 +633,7 @@ def _solve_covar(self, A_covar, b_covar, M, covar_est_opt):
632633

633634
def precond_fun(S, x):
634635
p = np.size(S, 0)
635-
ensure(np.size(x) == p * p, "The sizes of S and x are not consistent.")
636+
assert np.size(x) == p * p, "The sizes of S and x are not consistent."
636637
x = m_reshape(x, (p, p))
637638
y = S @ x @ S
638639
y = m_reshape(y, (p**2,))

src/aspire/image/image.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import aspire.volume
1010
from aspire.nufft import anufft
1111
from aspire.numeric import fft, xp
12-
from aspire.utils import ensure
1312
from aspire.utils.coor_trans import grid_2d
1413
from aspire.utils.matrix import anorm
1514

@@ -144,7 +143,7 @@ def __init__(self, data, dtype=None):
144143
self.n_images = self.shape[0]
145144
self.res = self.shape[1]
146145

147-
ensure(data.shape[1] == data.shape[2], "Only square ndarrays are supported.")
146+
assert data.shape[1] == data.shape[2], "Only square ndarrays are supported."
148147

149148
def __getitem__(self, item):
150149
return self.data[item]
@@ -279,12 +278,11 @@ def _im_translate(self, shifts):
279278
shifts = shifts[np.newaxis, :]
280279
n_shifts = shifts.shape[0]
281280

282-
ensure(shifts.shape[-1] == 2, "shifts must be nx2")
281+
assert shifts.shape[-1] == 2, "shifts must be nx2"
283282

284-
ensure(
285-
n_shifts == 1 or n_shifts == self.n_images,
286-
"number of shifts must be 1 or match the number of images",
287-
)
283+
assert (
284+
n_shifts == 1 or n_shifts == self.n_images
285+
), "number of shifts must be 1 or match the number of images"
288286
# Cast shifts to this instance's internal dtype
289287
shifts = shifts.astype(self.dtype)
290288

@@ -330,10 +328,9 @@ def backproject(self, rot_matrices):
330328

331329
L = self.res
332330

333-
ensure(
334-
self.n_images == rot_matrices.shape[0],
335-
"Number of rotation matrices must match the number of images",
336-
)
331+
assert (
332+
self.n_images == rot_matrices.shape[0]
333+
), "Number of rotation matrices must match the number of images"
337334

338335
# TODO: rotated_grids might as well give us correctly shaped array in the first place
339336
pts_rot = aspire.volume.rotated_grids(L, rot_matrices)

src/aspire/image/preprocess.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from scipy.special import erf
66

77
from aspire.numeric import fft, xp
8-
from aspire.utils import ensure
98

109
logger = logging.getLogger(__name__)
1110

@@ -129,10 +128,9 @@ def downsample(insamples, szout, mask=None):
129128
:return: An array consists of the blurred and downsampled objects.
130129
"""
131130

132-
ensure(
133-
insamples.ndim - 1 == np.size(szout),
134-
"The number of downsampling dimensions is not the same as that of objects.",
135-
)
131+
assert insamples.ndim - 1 == np.size(
132+
szout
133+
), "The number of downsampling dimensions is not the same as that of objects."
136134

137135
L_in = insamples.shape[1]
138136
L_out = szout[0]

0 commit comments

Comments
 (0)