Skip to content

Commit 25b5914

Browse files
Merge branch 'develop' into ds_logical_tests
2 parents faf608b + b000a87 commit 25b5914

File tree

23 files changed

+296
-72
lines changed

23 files changed

+296
-72
lines changed

gallery/tutorials/lecture_feature_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
from aspire.noise import AnisotropicNoiseEstimator, WhiteNoiseEstimator
4343
from aspire.operators import FunctionFilter, RadialCTFFilter, ScalarFilter
4444
from aspire.source import RelionSource, Simulation
45-
from aspire.utils import Rotation
46-
from aspire.utils.coor_trans import (
45+
from aspire.utils import (
46+
Rotation,
4747
get_aligned_rotations,
4848
get_rots_mse,
4949
register_rotations,

gallery/tutorials/orient3d_simulation.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
from aspire.abinitio import CLSyncVoting
1616
from aspire.operators import RadialCTFFilter
1717
from aspire.source.simulation import Simulation
18-
from aspire.utils.coor_trans import (
19-
get_aligned_rotations,
20-
get_rots_mse,
21-
register_rotations,
22-
)
18+
from aspire.utils import get_aligned_rotations, get_rots_mse, register_rotations
2319
from aspire.volume import Volume
2420

2521
logger = logging.getLogger(__name__)

src/aspire/abinitio/commonline_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from aspire.abinitio.orientation_src import OrientEstSource
88
from aspire.basis import PolarBasis2D
9-
from aspire.utils.coor_trans import common_line_from_rots
9+
from aspire.utils import common_line_from_rots
1010
from aspire.utils.random import choice
1111

1212
logger = logging.getLogger(__name__)

src/aspire/basis/basis_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from numpy.polynomial.legendre import leggauss
1111
from scipy.special import jn, jv, sph_harm
1212

13-
from aspire.utils.coor_trans import grid_2d, grid_3d
13+
from aspire.utils import grid_2d, grid_3d
1414

1515
logger = logging.getLogger(__name__)
1616

src/aspire/classification/averager2d.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,8 @@ def __init__(
466466
if self.alignment_src.dtype != src.dtype:
467467
raise RuntimeError("Currently `alignment_src.dtype` must equal `src.dtype`")
468468

469+
self.mask = grid_2d(src.L, normalized=False)["r"] < src.L // 2
470+
469471
super().__init__(composite_basis, src, composite_basis, dtype=dtype)
470472

471473
def _phase_cross_correlation(self, img0, img1):
@@ -547,7 +549,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
547549
# Result arrays
548550
M = len(images)
549551
rotations_k = np.zeros(M, dtype=self.dtype)
550-
correlations_k = np.zeros(M, dtype=self.dtype)
552+
correlations_k = np.full(M, -np.inf, dtype=self.dtype)
551553
shifts_k = np.zeros((M, 2), dtype=int)
552554

553555
# De-Mean, note images is mutated and should be a `copy`.
@@ -655,8 +657,8 @@ def _reddychatterji(self, images, class_k, reflection_k):
655657
# Hack
656658
regis_img_estimated = rotate(regis_img, r)
657659
regis_img_rotated_p180 = rotate(regis_img, r + 180)
658-
da = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
659-
db = np.dot(fixed_img.flatten(), regis_img_rotated_p180.flatten())
660+
da = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
661+
db = np.dot(fixed_img[self.mask], regis_img_rotated_p180[self.mask])
660662
if db > da:
661663
regis_img_estimated = regis_img_rotated_p180
662664
r += 180
@@ -702,7 +704,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
702704
shift = None # For logger line
703705

704706
# Estimated `corr` metric
705-
corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
707+
corr = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
706708
correlations_k[m] = corr
707709

708710
logger.debug(
@@ -962,9 +964,7 @@ def align(self, classes, reflections, basis_coefficients):
962964
L = self.alignment_src.L
963965

964966
# Instantiate matrices for inner loop, and best results.
965-
_rotations = np.zeros(classes.shape, dtype=self.dtype)
966967
rotations = np.zeros(classes.shape, dtype=self.dtype)
967-
_correlations = np.zeros(classes.shape, dtype=self.dtype)
968968
correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf
969969
shifts = np.zeros((*classes.shape, 2), dtype=int)
970970

@@ -986,16 +986,16 @@ def align(self, classes, reflections, basis_coefficients):
986986
# Don't shift the base image
987987
images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy()
988988

989-
rotations[k], _, correlations[k] = self._reddychatterji(
989+
_rotations, _, _correlations = self._reddychatterji(
990990
images, classes[k], reflections[k]
991991
)
992992

993993
# Where corr has improved
994994
# update our rolling best results with this loop.
995-
improved = _correlations > correlations
996-
correlations = np.where(improved, _correlations, correlations)
997-
rotations = np.where(improved, _rotations, rotations)
998-
shifts = np.where(improved[..., np.newaxis], s, shifts)
995+
improved = _correlations > correlations[k]
996+
correlations[k] = np.where(improved, _correlations, correlations[k])
997+
rotations[k] = np.where(improved, _rotations, rotations[k])
998+
shifts[k] = np.where(improved[..., np.newaxis], s, shifts[k])
999999
logger.debug(f"Shift {s} has improved {np.sum(improved)} results")
10001000

10011001
return rotations, shifts, correlations

src/aspire/covariance/covar.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def __getattr__(self, name):
4242

4343
def compute_kernel(self):
4444
# TODO: Most of this stuff is duplicated in MeanEstimator - move up the hierarchy?
45-
n = self.n
46-
L = self.L
47-
_2L = 2 * self.L
45+
n = self.src.n
46+
L = self.src.L
47+
_2L = 2 * self.src.L
4848

4949
kernel = np.zeros((_2L, _2L, _2L, _2L, _2L, _2L), dtype=self.dtype)
5050
sq_filters_f = np.square(evaluate_src_filters_on_grid(self.src))
@@ -168,22 +168,23 @@ def src_backward(self, mean_vol, noise_variance, shrink_method=None):
168168
contribution and expressed as coefficients of `basis`.
169169
"""
170170
covar_b = np.zeros(
171-
(self.L, self.L, self.L, self.L, self.L, self.L), dtype=self.dtype
171+
(self.src.L, self.src.L, self.src.L, self.src.L, self.src.L, self.src.L),
172+
dtype=self.dtype,
172173
)
173174

174-
for i in range(0, self.n, self.batch_size):
175+
for i in range(0, self.src.n, self.batch_size):
175176
im = self.src.images(i, self.batch_size)
176177
batch_n = im.n_images
177178
im_centered = im - self.src.vol_forward(mean_vol, i, self.batch_size)
178179

179180
im_centered_b = np.zeros(
180-
(batch_n, self.L, self.L, self.L), dtype=self.dtype
181+
(batch_n, self.src.L, self.src.L, self.src.L), dtype=self.dtype
181182
)
182183
for j in range(batch_n):
183184
im_centered_b[j] = self.src.im_backward(Image(im_centered[j]), i + j)
184185
im_centered_b = Volume(im_centered_b).to_vec()
185186

186-
covar_b += vecmat_to_volmat(im_centered_b.T @ im_centered_b) / self.n
187+
covar_b += vecmat_to_volmat(im_centered_b.T @ im_centered_b) / self.src.n
187188

188189
covar_b_coeff = self.basis.mat_evaluate_t(covar_b)
189190
return self._shrink(covar_b_coeff, noise_variance, shrink_method)

src/aspire/ctf/ctf_estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
from aspire.numeric import fft
2222
from aspire.operators import voltage_to_wavelength
2323
from aspire.storage import StarFile
24-
from aspire.utils import abs2, complex_type
25-
from aspire.utils.coor_trans import grid_1d, grid_2d
24+
from aspire.utils import abs2, complex_type, grid_1d, grid_2d
2625

2726
logger = logging.getLogger(__name__)
2827

src/aspire/denoising/adaptive_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aspire.noise import WhiteNoiseEstimator
66
from aspire.numeric import fft
77
from aspire.source import ImageSource
8-
from aspire.utils.coor_trans import grid_2d
8+
from aspire.utils import grid_2d
99

1010
logger = logging.getLogger(__name__)
1111

src/aspire/image/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import aspire.volume
1010
from aspire.nufft import anufft
1111
from aspire.numeric import fft, xp
12-
from aspire.utils.coor_trans import grid_2d
12+
from aspire.utils import grid_2d
1313
from aspire.utils.matrix import anorm
1414

1515
logger = logging.getLogger(__name__)

src/aspire/noise/noise.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from aspire.numeric import fft, xp
66
from aspire.operators import ArrayFilter, ScalarFilter
7-
from aspire.utils.coor_trans import grid_2d
7+
from aspire.utils import grid_2d
88

99
logger = logging.getLogger(__name__)
1010

@@ -27,8 +27,6 @@ def __init__(self, src, bgRadius=1, batchSize=512):
2727

2828
self.src = src
2929
self.dtype = self.src.dtype
30-
self.L = src.L
31-
self.n = src.n
3230
self.bgRadius = bgRadius
3331
self.batchSize = batchSize
3432

@@ -72,16 +70,16 @@ def _estimate_noise_variance(self):
7270
TODO: How's this initial estimate of variance different from the 'estimate' method?
7371
"""
7472
# Run estimate using saved parameters
75-
g2d = grid_2d(self.L, indexing="yx", dtype=self.dtype)
73+
g2d = grid_2d(self.src.L, indexing="yx", dtype=self.dtype)
7674
mask = g2d["r"] >= self.bgRadius
7775

7876
first_moment = 0
7977
second_moment = 0
80-
for i in range(0, self.n, self.batchSize):
78+
for i in range(0, self.src.n, self.batchSize):
8179
images = self.src.images(start=i, num=self.batchSize).asnumpy()
8280
images_masked = images * mask
8381

84-
_denominator = self.n * np.sum(mask)
82+
_denominator = self.src.n * np.sum(mask)
8583
first_moment += np.sum(images_masked) / _denominator
8684
second_moment += np.sum(np.abs(images_masked**2)) / _denominator
8785
return second_moment - first_moment**2
@@ -100,7 +98,7 @@ def estimate(self):
10098
# AnisotropicNoiseEstimator.filter is an ArrayFilter.
10199
# We average the variance over all frequencies,
102100

103-
return np.mean(self.filter.evaluate_grid(self.L))
101+
return np.mean(self.filter.evaluate_grid(self.src.L))
104102

105103
def _create_filter(self, noise_psd=None):
106104
"""
@@ -117,21 +115,21 @@ def estimate_noise_psd(self):
117115
TODO: How's this initial estimate of variance different from the 'estimate' method?
118116
"""
119117
# Run estimate using saved parameters
120-
g2d = grid_2d(self.L, indexing="yx", dtype=self.dtype)
118+
g2d = grid_2d(self.src.L, indexing="yx", dtype=self.dtype)
121119
mask = g2d["r"] >= self.bgRadius
122120

123121
mean_est = 0
124-
noise_psd_est = np.zeros((self.L, self.L)).astype(self.src.dtype)
125-
for i in range(0, self.n, self.batchSize):
122+
noise_psd_est = np.zeros((self.src.L, self.src.L)).astype(self.src.dtype)
123+
for i in range(0, self.src.n, self.batchSize):
126124
images = self.src.images(i, self.batchSize).asnumpy()
127125
images_masked = images * mask
128126

129-
_denominator = self.n * np.sum(mask)
127+
_denominator = self.src.n * np.sum(mask)
130128
mean_est += np.sum(images_masked) / _denominator
131129
im_masked_f = xp.asnumpy(fft.centered_fft2(xp.asarray(images_masked)))
132130
noise_psd_est += np.sum(np.abs(im_masked_f**2), axis=0) / _denominator
133131

134-
mid = self.L // 2
132+
mid = self.src.L // 2
135133
noise_psd_est[mid, mid] -= mean_est**2
136134

137135
return noise_psd_est

0 commit comments

Comments
 (0)