
Commit efd9744

Change self.L and self.n to self.src.L and self.src.n for estimators (#550)
* changed self.L and self.n to self.src.L and self.src.n for estimators
* extraneous comma
Parent commit: b7a3398

4 files changed: +26 −30 lines
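The change is the same in every file: the estimators previously cached the image size `L` and the image count `n` on themselves at construction time, then read those copies; after this commit they read both directly from the attached source (`self.src.L`, `self.src.n`), so the source stays the single owner of that state and there are no cached attributes to keep in sync. A minimal sketch of the before/after pattern (toy classes, not ASPIRE code):

```python
# Toy illustration of the refactor; these classes are not part of ASPIRE.
class Source:
    def __init__(self, L, n):
        self.L = L  # image size in pixels
        self.n = n  # number of images


class EstimatorBefore:
    def __init__(self, src):
        self.src = src
        self.L = src.L  # cached copy; can drift from src
        self.n = src.n


class EstimatorAfter:
    def __init__(self, src):
        self.src = src  # read src.L / src.n at the point of use instead of caching
```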

src/aspire/covariance/covar.py

Lines changed: 8 additions & 7 deletions

```diff
@@ -42,9 +42,9 @@ def __getattr__(self, name):
 
     def compute_kernel(self):
         # TODO: Most of this stuff is duplicated in MeanEstimator - move up the hierarchy?
-        n = self.n
-        L = self.L
-        _2L = 2 * self.L
+        n = self.src.n
+        L = self.src.L
+        _2L = 2 * self.src.L
 
         kernel = np.zeros((_2L, _2L, _2L, _2L, _2L, _2L), dtype=self.dtype)
         sq_filters_f = np.square(evaluate_src_filters_on_grid(self.src))
@@ -168,22 +168,23 @@ def src_backward(self, mean_vol, noise_variance, shrink_method=None):
             contribution and expressed as coefficients of `basis`.
         """
         covar_b = np.zeros(
-            (self.L, self.L, self.L, self.L, self.L, self.L), dtype=self.dtype
+            (self.src.L, self.src.L, self.src.L, self.src.L, self.src.L, self.src.L),
+            dtype=self.dtype,
         )
 
-        for i in range(0, self.n, self.batch_size):
+        for i in range(0, self.src.n, self.batch_size):
             im = self.src.images(i, self.batch_size)
             batch_n = im.n_images
             im_centered = im - self.src.vol_forward(mean_vol, i, self.batch_size)
 
             im_centered_b = np.zeros(
-                (batch_n, self.L, self.L, self.L), dtype=self.dtype
+                (batch_n, self.src.L, self.src.L, self.src.L), dtype=self.dtype
             )
             for j in range(batch_n):
                 im_centered_b[j] = self.src.im_backward(Image(im_centered[j]), i + j)
             im_centered_b = Volume(im_centered_b).to_vec()
 
-            covar_b += vecmat_to_volmat(im_centered_b.T @ im_centered_b) / self.n
+            covar_b += vecmat_to_volmat(im_centered_b.T @ im_centered_b) / self.src.n
 
         covar_b_coeff = self.basis.mat_evaluate_t(covar_b)
         return self._shrink(covar_b_coeff, noise_variance, shrink_method)
```
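The `src_backward` loop above accumulates a sample covariance in batches: each batch of centered, backprojected images is flattened into row vectors `b`, and the update is `b.T @ b / src.n`, so summing over batches gives the full-dataset covariance in one pass. A self-contained numpy sketch of that accumulation pattern (generic arrays, not the ASPIRE `Image`/`Volume` types):

```python
import numpy as np


def batched_covariance(rows, batch_size):
    """Accumulate (1/n) * sum_i x_i x_i^T over batches of row vectors."""
    n, d = rows.shape
    covar = np.zeros((d, d), dtype=rows.dtype)
    for i in range(0, n, batch_size):
        b = rows[i : i + batch_size]  # one batch of centered vectors
        covar += b.T @ b / n          # same update as covar_b above
    return covar


rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 8))
x -= x.mean(axis=0)  # center, mirroring the mean-volume subtraction
assert np.allclose(batched_covariance(x, 128), x.T @ x / 1000)
```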

src/aspire/noise/noise.py

Lines changed: 9 additions & 11 deletions

```diff
@@ -27,8 +27,6 @@ def __init__(self, src, bgRadius=1, batchSize=512):
 
         self.src = src
         self.dtype = self.src.dtype
-        self.L = src.L
-        self.n = src.n
         self.bgRadius = bgRadius
         self.batchSize = batchSize
 
@@ -72,16 +70,16 @@ def _estimate_noise_variance(self):
         TODO: How's this initial estimate of variance different from the 'estimate' method?
         """
         # Run estimate using saved parameters
-        g2d = grid_2d(self.L, indexing="yx", dtype=self.dtype)
+        g2d = grid_2d(self.src.L, indexing="yx", dtype=self.dtype)
         mask = g2d["r"] >= self.bgRadius
 
         first_moment = 0
         second_moment = 0
-        for i in range(0, self.n, self.batchSize):
+        for i in range(0, self.src.n, self.batchSize):
             images = self.src.images(start=i, num=self.batchSize).asnumpy()
             images_masked = images * mask
 
-            _denominator = self.n * np.sum(mask)
+            _denominator = self.src.n * np.sum(mask)
             first_moment += np.sum(images_masked) / _denominator
             second_moment += np.sum(np.abs(images_masked**2)) / _denominator
         return second_moment - first_moment**2
@@ -100,7 +98,7 @@ def estimate(self):
         # AnisotropicNoiseEstimator.filter is an ArrayFilter.
         # We average the variance over all frequencies,
 
-        return np.mean(self.filter.evaluate_grid(self.L))
+        return np.mean(self.filter.evaluate_grid(self.src.L))
 
     def _create_filter(self, noise_psd=None):
         """
@@ -117,21 +115,21 @@ def estimate_noise_psd(self):
         TODO: How's this initial estimate of variance different from the 'estimate' method?
         """
         # Run estimate using saved parameters
-        g2d = grid_2d(self.L, indexing="yx", dtype=self.dtype)
+        g2d = grid_2d(self.src.L, indexing="yx", dtype=self.dtype)
         mask = g2d["r"] >= self.bgRadius
 
         mean_est = 0
-        noise_psd_est = np.zeros((self.L, self.L)).astype(self.src.dtype)
-        for i in range(0, self.n, self.batchSize):
+        noise_psd_est = np.zeros((self.src.L, self.src.L)).astype(self.src.dtype)
+        for i in range(0, self.src.n, self.batchSize):
             images = self.src.images(i, self.batchSize).asnumpy()
             images_masked = images * mask
 
-            _denominator = self.n * np.sum(mask)
+            _denominator = self.src.n * np.sum(mask)
             mean_est += np.sum(images_masked) / _denominator
             im_masked_f = xp.asnumpy(fft.centered_fft2(xp.asarray(images_masked)))
             noise_psd_est += np.sum(np.abs(im_masked_f**2), axis=0) / _denominator
 
-        mid = self.L // 2
+        mid = self.src.L // 2
         noise_psd_est[mid, mid] -= mean_est**2
 
         return noise_psd_est
```
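For reference, `_estimate_noise_variance` is the usual two-moment estimate var = E[x^2] - E[x]^2, taken only over background pixels (those at radius >= `bgRadius`) and accumulated batch by batch so the full image stack never has to be in memory at once. A standalone numpy sketch of the same computation, with a plain radial grid standing in for ASPIRE's `grid_2d` (an assumption about its normalization, not a copy of it):

```python
import numpy as np


def background_variance(images, bg_radius=1.0, batch_size=512):
    """var = E[x^2] - E[x]^2 over pixels at radius >= bg_radius, batched."""
    n, L, _ = images.shape
    g = np.linspace(-1, 1, L)
    r = np.hypot(*np.meshgrid(g, g, indexing="ij"))  # stand-in for grid_2d(L)["r"]
    mask = r >= bg_radius                            # background = outside the disc

    first_moment = 0.0
    second_moment = 0.0
    for i in range(0, n, batch_size):
        masked = images[i : i + batch_size] * mask   # zero out non-background pixels
        denom = n * np.sum(mask)                     # total count of background pixels
        first_moment += np.sum(masked) / denom
        second_moment += np.sum(masked**2) / denom
    return second_moment - first_moment**2


images = np.random.default_rng(0).normal(scale=2.0, size=(256, 33, 33))
print(background_variance(images))  # close to 4.0 for pure N(0, 2^2) noise
```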

src/aspire/reconstruction/estimator.py

Lines changed: 3 additions & 6 deletions

```diff
@@ -30,9 +30,6 @@ def __init__(self, src, basis, batch_size=512, preconditioner="circulant"):
         self.batch_size = batch_size
         self.preconditioner = preconditioner
 
-        self.L = src.L
-        self.n = src.n
-
         if not self.dtype == self.basis.dtype:
             logger.warning(
                 f"Inconsistent types in {self.dtype} Estimator."
@@ -84,11 +81,11 @@ def src_backward(self):
         :return: The adjoint mapping applied to the images, averaged over the whole dataset and expressed
             as coefficients of `basis`.
         """
-        mean_b = np.zeros((self.L, self.L, self.L), dtype=self.dtype)
+        mean_b = np.zeros((self.src.L, self.src.L, self.src.L), dtype=self.dtype)
 
-        for i in range(0, self.n, self.batch_size):
+        for i in range(0, self.src.n, self.batch_size):
             im = self.src.images(i, self.batch_size)
-            batch_mean_b = self.src.im_backward(im, i) / self.n
+            batch_mean_b = self.src.im_backward(im, i) / self.src.n
             mean_b += batch_mean_b.astype(self.dtype)
 
         res = self.basis.evaluate_t(mean_b)
```
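This `src_backward` is the same batching idea applied to a running mean: each batch's backprojection is divided by the total count `src.n` before it is added, so summing the batch contributions yields the dataset average in a single pass. A tiny sketch of that pattern with plain arrays (toy code, not ASPIRE's adjoint mapping):

```python
import numpy as np


def batched_mean(vectors, batch_size):
    """Summing (batch contribution / n) over batches equals the overall mean."""
    n = len(vectors)
    total = np.zeros_like(vectors[0], dtype=np.float64)
    for i in range(0, n, batch_size):
        batch = vectors[i : i + batch_size]
        total += batch.sum(axis=0) / n  # mirrors im_backward(im, i) / self.src.n
    return total


x = np.arange(12.0).reshape(6, 2)
assert np.allclose(batched_mean(x, 4), x.mean(axis=0))
```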

src/aspire/reconstruction/mean.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -14,17 +14,17 @@
 
 class MeanEstimator(Estimator):
     def compute_kernel(self):
-        _2L = 2 * self.L
+        _2L = 2 * self.src.L
         kernel = np.zeros((_2L, _2L, _2L), dtype=self.dtype)
         sq_filters_f = np.square(evaluate_src_filters_on_grid(self.src))
 
-        for i in range(0, self.n, self.batch_size):
-            _range = np.arange(i, min(self.n, i + self.batch_size), dtype=int)
-            pts_rot = rotated_grids(self.L, self.src.rots[_range, :, :])
+        for i in range(0, self.src.n, self.batch_size):
+            _range = np.arange(i, min(self.src.n, i + self.batch_size), dtype=int)
+            pts_rot = rotated_grids(self.src.L, self.src.rots[_range, :, :])
             weights = sq_filters_f[:, :, _range]
             weights *= self.src.amplitudes[_range] ** 2
 
-            if self.L % 2 == 0:
+            if self.src.L % 2 == 0:
                 weights[0, :, :] = 0
                 weights[:, 0, :] = 0
 
@@ -33,7 +33,7 @@ def compute_kernel(self):
 
             kernel += (
                 1
-                / (self.n * self.L**4)
+                / (self.src.n * self.src.L**4)
                 * anufft(weights, pts_rot[::-1], (_2L, _2L, _2L), real=True)
             )
```
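For context, a rough usage sketch of `MeanEstimator` driven by a simulated source. The entry points below (`Simulation`, `FBBasis3D`, `estimate()`) are my assumption about the aspire-python API around this commit and may differ between versions; only the `MeanEstimator(src, basis, batch_size=...)` signature is taken from the diff itself:

```python
# Hedged sketch: check the names against your aspire-python version.
from aspire.basis import FBBasis3D
from aspire.reconstruction import MeanEstimator
from aspire.source import Simulation

L, n = 16, 128
src = Simulation(L=L, n=n)            # L and n now live only on the source
basis = FBBasis3D((L, L, L))          # 3D basis at the same resolution
estimator = MeanEstimator(src, basis, batch_size=64)
mean_vol = estimator.estimate()       # batches over src.n images of size src.L
```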
