Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/aspire/classification/averager2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,8 @@ def __init__(
if self.alignment_src.dtype != src.dtype:
raise RuntimeError("Currently `alignment_src.dtype` must equal `src.dtype`")

self.mask = grid_2d(src.L, normalized=False)["r"] < src.L // 2

super().__init__(composite_basis, src, composite_basis, dtype=dtype)

def _phase_cross_correlation(self, img0, img1):
Expand Down Expand Up @@ -547,7 +549,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
# Result arrays
M = len(images)
rotations_k = np.zeros(M, dtype=self.dtype)
correlations_k = np.zeros(M, dtype=self.dtype)
correlations_k = np.full(M, -np.inf, dtype=self.dtype)
shifts_k = np.zeros((M, 2), dtype=int)

# De-Mean, note images is mutated and should be a `copy`.
Expand Down Expand Up @@ -655,8 +657,8 @@ def _reddychatterji(self, images, class_k, reflection_k):
# Hack
regis_img_estimated = rotate(regis_img, r)
regis_img_rotated_p180 = rotate(regis_img, r + 180)
da = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
db = np.dot(fixed_img.flatten(), regis_img_rotated_p180.flatten())
da = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
db = np.dot(fixed_img[self.mask], regis_img_rotated_p180[self.mask])
if db > da:
regis_img_estimated = regis_img_rotated_p180
r += 180
Expand Down Expand Up @@ -702,7 +704,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
shift = None # For logger line

# Estimated `corr` metric
corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
corr = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
correlations_k[m] = corr

logger.debug(
Expand Down Expand Up @@ -962,9 +964,7 @@ def align(self, classes, reflections, basis_coefficients):
L = self.alignment_src.L

# Instantiate matrices for inner loop, and best results.
_rotations = np.zeros(classes.shape, dtype=self.dtype)
rotations = np.zeros(classes.shape, dtype=self.dtype)
_correlations = np.zeros(classes.shape, dtype=self.dtype)
correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf
shifts = np.zeros((*classes.shape, 2), dtype=int)

Expand All @@ -986,16 +986,16 @@ def align(self, classes, reflections, basis_coefficients):
# Don't shift the base image
images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy()

rotations[k], _, correlations[k] = self._reddychatterji(
_rotations, _, _correlations = self._reddychatterji(
images, classes[k], reflections[k]
)

# Where corr has improved
# update our rolling best results with this loop.
improved = _correlations > correlations
correlations = np.where(improved, _correlations, correlations)
rotations = np.where(improved, _rotations, rotations)
shifts = np.where(improved[..., np.newaxis], s, shifts)
improved = _correlations > correlations[k]
correlations[k] = np.where(improved, _correlations, correlations[k])
rotations[k] = np.where(improved, _rotations, rotations[k])
shifts[k] = np.where(improved[..., np.newaxis], s, shifts[k])
logger.debug(f"Shift {s} has improved {np.sum(improved)} results")

return rotations, shifts, correlations
Expand Down