diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 5db97f1aa3..9a91b77949 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -466,6 +466,8 @@ def __init__( if self.alignment_src.dtype != src.dtype: raise RuntimeError("Currently `alignment_src.dtype` must equal `src.dtype`") + self.mask = grid_2d(src.L, normalized=False)["r"] < src.L // 2 + super().__init__(composite_basis, src, composite_basis, dtype=dtype) def _phase_cross_correlation(self, img0, img1): @@ -547,7 +549,7 @@ def _reddychatterji(self, images, class_k, reflection_k): # Result arrays M = len(images) rotations_k = np.zeros(M, dtype=self.dtype) - correlations_k = np.zeros(M, dtype=self.dtype) + correlations_k = np.full(M, -np.inf, dtype=self.dtype) shifts_k = np.zeros((M, 2), dtype=int) # De-Mean, note images is mutated and should be a `copy`. @@ -655,8 +657,8 @@ def _reddychatterji(self, images, class_k, reflection_k): # Hack regis_img_estimated = rotate(regis_img, r) regis_img_rotated_p180 = rotate(regis_img, r + 180) - da = np.dot(fixed_img.flatten(), regis_img_estimated.flatten()) - db = np.dot(fixed_img.flatten(), regis_img_rotated_p180.flatten()) + da = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask]) + db = np.dot(fixed_img[self.mask], regis_img_rotated_p180[self.mask]) if db > da: regis_img_estimated = regis_img_rotated_p180 r += 180 @@ -702,7 +704,7 @@ def _reddychatterji(self, images, class_k, reflection_k): shift = None # For logger line # Estimated `corr` metric - corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten()) + corr = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask]) correlations_k[m] = corr logger.debug( @@ -962,9 +964,7 @@ def align(self, classes, reflections, basis_coefficients): L = self.alignment_src.L # Instantiate matrices for inner loop, and best results. - _rotations = np.zeros(classes.shape, dtype=self.dtype) rotations = np.zeros(classes.shape, dtype=self.dtype) - _correlations = np.zeros(classes.shape, dtype=self.dtype) correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf shifts = np.zeros((*classes.shape, 2), dtype=int) @@ -986,16 +986,16 @@ def align(self, classes, reflections, basis_coefficients): # Don't shift the base image images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy() - rotations[k], _, correlations[k] = self._reddychatterji( + _rotations, _, _correlations = self._reddychatterji( images, classes[k], reflections[k] ) # Where corr has improved # update our rolling best results with this loop. - improved = _correlations > correlations - correlations = np.where(improved, _correlations, correlations) - rotations = np.where(improved, _rotations, rotations) - shifts = np.where(improved[..., np.newaxis], s, shifts) + improved = _correlations > correlations[k] + correlations[k] = np.where(improved, _correlations, correlations[k]) + rotations[k] = np.where(improved, _rotations, rotations[k]) + shifts[k] = np.where(improved[..., np.newaxis], s, shifts[k]) logger.debug(f"Shift {s} has improved {np.sum(improved)} results") return rotations, shifts, correlations