Skip to content

Commit b7a3398

Browse files
authored
Merge pull request #590 from ComputationalCryoEM/character_zero
BFSReddyChetterji k Bug
2 parents b19503c + a2fba0b commit b7a3398

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/aspire/classification/averager2d.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,8 @@ def __init__(
466466
if self.alignment_src.dtype != src.dtype:
467467
raise RuntimeError("Currently `alignment_src.dtype` must equal `src.dtype`")
468468

469+
self.mask = grid_2d(src.L, normalized=False)["r"] < src.L // 2
470+
469471
super().__init__(composite_basis, src, composite_basis, dtype=dtype)
470472

471473
def _phase_cross_correlation(self, img0, img1):
@@ -547,7 +549,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
547549
# Result arrays
548550
M = len(images)
549551
rotations_k = np.zeros(M, dtype=self.dtype)
550-
correlations_k = np.zeros(M, dtype=self.dtype)
552+
correlations_k = np.full(M, -np.inf, dtype=self.dtype)
551553
shifts_k = np.zeros((M, 2), dtype=int)
552554

553555
# De-Mean, note images is mutated and should be a `copy`.
@@ -655,8 +657,8 @@ def _reddychatterji(self, images, class_k, reflection_k):
655657
# Hack
656658
regis_img_estimated = rotate(regis_img, r)
657659
regis_img_rotated_p180 = rotate(regis_img, r + 180)
658-
da = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
659-
db = np.dot(fixed_img.flatten(), regis_img_rotated_p180.flatten())
660+
da = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
661+
db = np.dot(fixed_img[self.mask], regis_img_rotated_p180[self.mask])
660662
if db > da:
661663
regis_img_estimated = regis_img_rotated_p180
662664
r += 180
@@ -702,7 +704,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
702704
shift = None # For logger line
703705

704706
# Estimated `corr` metric
705-
corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
707+
corr = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
706708
correlations_k[m] = corr
707709

708710
logger.debug(
@@ -962,9 +964,7 @@ def align(self, classes, reflections, basis_coefficients):
962964
L = self.alignment_src.L
963965

964966
# Instantiate matrices for inner loop, and best results.
965-
_rotations = np.zeros(classes.shape, dtype=self.dtype)
966967
rotations = np.zeros(classes.shape, dtype=self.dtype)
967-
_correlations = np.zeros(classes.shape, dtype=self.dtype)
968968
correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf
969969
shifts = np.zeros((*classes.shape, 2), dtype=int)
970970

@@ -986,16 +986,16 @@ def align(self, classes, reflections, basis_coefficients):
986986
# Don't shift the base image
987987
images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy()
988988

989-
rotations[k], _, correlations[k] = self._reddychatterji(
989+
_rotations, _, _correlations = self._reddychatterji(
990990
images, classes[k], reflections[k]
991991
)
992992

993993
# Where corr has improved
994994
# update our rolling best results with this loop.
995-
improved = _correlations > correlations
996-
correlations = np.where(improved, _correlations, correlations)
997-
rotations = np.where(improved, _rotations, rotations)
998-
shifts = np.where(improved[..., np.newaxis], s, shifts)
995+
improved = _correlations > correlations[k]
996+
correlations[k] = np.where(improved, _correlations, correlations[k])
997+
rotations[k] = np.where(improved, _rotations, rotations[k])
998+
shifts[k] = np.where(improved[..., np.newaxis], s, shifts[k])
999999
logger.debug(f"Shift {s} has improved {np.sum(improved)} results")
10001000

10011001
return rotations, shifts, correlations

0 commit comments

Comments
 (0)