Skip to content

Commit ed09404

Browse files
committed
Fix k bug in BFSReddyChetterji
1 parent 6cb389e commit ed09404

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

src/aspire/classification/averager2d.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
547547
# Result arrays
548548
M = len(images)
549549
rotations_k = np.zeros(M, dtype=self.dtype)
550-
correlations_k = np.zeros(M, dtype=self.dtype)
550+
correlations_k = np.full(M, -np.inf, dtype=self.dtype)
551551
shifts_k = np.zeros((M, 2), dtype=int)
552552

553553
# De-Mean, note images is mutated and should be a `copy`.
@@ -962,9 +962,7 @@ def align(self, classes, reflections, basis_coefficients):
962962
L = self.alignment_src.L
963963

964964
# Instantiate matrices for inner loop, and best results.
965-
_rotations = np.zeros(classes.shape, dtype=self.dtype)
966965
rotations = np.zeros(classes.shape, dtype=self.dtype)
967-
_correlations = np.zeros(classes.shape, dtype=self.dtype)
968966
correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf
969967
shifts = np.zeros((*classes.shape, 2), dtype=int)
970968

@@ -986,16 +984,16 @@ def align(self, classes, reflections, basis_coefficients):
986984
# Don't shift the base image
987985
images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy()
988986

989-
rotations[k], _, correlations[k] = self._reddychatterji(
987+
_rotations, _, _correlations = self._reddychatterji(
990988
images, classes[k], reflections[k]
991989
)
992990

993991
# Where corr has improved
994992
# update our rolling best results with this loop.
995-
improved = _correlations > correlations
996-
correlations = np.where(improved, _correlations, correlations)
997-
rotations = np.where(improved, _rotations, rotations)
998-
shifts = np.where(improved[..., np.newaxis], s, shifts)
993+
improved = _correlations > correlations[k]
994+
correlations[k] = np.where(improved, _correlations, correlations[k])
995+
rotations[k] = np.where(improved, _rotations, rotations[k])
996+
shifts[k] = np.where(improved[..., np.newaxis], s, shifts[k])
999997
logger.debug(f"Shift {s} has improved {np.sum(improved)} results")
1000998

1001999
return rotations, shifts, correlations

0 commit comments

Comments
 (0)