Skip to content

Commit a2fba0b

Browse files
committed
restrict corr to a disc
1 parent ed09404 commit a2fba0b

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/aspire/classification/averager2d.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,8 @@ def __init__(
466466
if self.alignment_src.dtype != src.dtype:
467467
raise RuntimeError("Currently `alignment_src.dtype` must equal `src.dtype`")
468468

469+
self.mask = grid_2d(src.L, normalized=False)["r"] < src.L // 2
470+
469471
super().__init__(composite_basis, src, composite_basis, dtype=dtype)
470472

471473
def _phase_cross_correlation(self, img0, img1):
@@ -655,8 +657,8 @@ def _reddychatterji(self, images, class_k, reflection_k):
655657
# Hack
656658
regis_img_estimated = rotate(regis_img, r)
657659
regis_img_rotated_p180 = rotate(regis_img, r + 180)
658-
da = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
659-
db = np.dot(fixed_img.flatten(), regis_img_rotated_p180.flatten())
660+
da = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
661+
db = np.dot(fixed_img[self.mask], regis_img_rotated_p180[self.mask])
660662
if db > da:
661663
regis_img_estimated = regis_img_rotated_p180
662664
r += 180
@@ -702,7 +704,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
702704
shift = None # For logger line
703705

704706
# Estimated `corr` metric
705-
corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten())
707+
corr = np.dot(fixed_img[self.mask], regis_img_estimated[self.mask])
706708
correlations_k[m] = corr
707709

708710
logger.debug(

0 commit comments

Comments
 (0)