@@ -466,6 +466,8 @@ def __init__(
466466 if self .alignment_src .dtype != src .dtype :
467467 raise RuntimeError ("Currently `alignment_src.dtype` must equal `src.dtype`" )
468468
469+ self .mask = grid_2d (src .L , normalized = False )["r" ] < src .L // 2
470+
469471 super ().__init__ (composite_basis , src , composite_basis , dtype = dtype )
470472
471473 def _phase_cross_correlation (self , img0 , img1 ):
@@ -547,7 +549,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
547549 # Result arrays
548550 M = len (images )
549551 rotations_k = np .zeros (M , dtype = self .dtype )
550- correlations_k = np .zeros ( M , dtype = self .dtype )
552+ correlations_k = np .full ( M , - np . inf , dtype = self .dtype )
551553 shifts_k = np .zeros ((M , 2 ), dtype = int )
552554
553555 # De-Mean, note images is mutated and should be a `copy`.
@@ -655,8 +657,8 @@ def _reddychatterji(self, images, class_k, reflection_k):
655657 # Hack
656658 regis_img_estimated = rotate (regis_img , r )
657659 regis_img_rotated_p180 = rotate (regis_img , r + 180 )
658- da = np .dot (fixed_img . flatten () , regis_img_estimated . flatten () )
659- db = np .dot (fixed_img . flatten () , regis_img_rotated_p180 . flatten () )
660+ da = np .dot (fixed_img [ self . mask ] , regis_img_estimated [ self . mask ] )
661+ db = np .dot (fixed_img [ self . mask ] , regis_img_rotated_p180 [ self . mask ] )
660662 if db > da :
661663 regis_img_estimated = regis_img_rotated_p180
662664 r += 180
@@ -702,7 +704,7 @@ def _reddychatterji(self, images, class_k, reflection_k):
702704 shift = None # For logger line
703705
704706 # Estimated `corr` metric
705- corr = np .dot (fixed_img . flatten () , regis_img_estimated . flatten () )
707+ corr = np .dot (fixed_img [ self . mask ] , regis_img_estimated [ self . mask ] )
706708 correlations_k [m ] = corr
707709
708710 logger .debug (
@@ -962,9 +964,7 @@ def align(self, classes, reflections, basis_coefficients):
962964 L = self .alignment_src .L
963965
964966 # Instantiate matrices for inner loop, and best results.
965- _rotations = np .zeros (classes .shape , dtype = self .dtype )
966967 rotations = np .zeros (classes .shape , dtype = self .dtype )
967- _correlations = np .zeros (classes .shape , dtype = self .dtype )
968968 correlations = np .ones (classes .shape , dtype = self .dtype ) * - np .inf
969969 shifts = np .zeros ((* classes .shape , 2 ), dtype = int )
970970
@@ -986,16 +986,16 @@ def align(self, classes, reflections, basis_coefficients):
986986 # Don't shift the base image
987987 images [1 :] = Image (unshifted_images [1 :]).shift (s ).asnumpy ()
988988
989- rotations [ k ] , _ , correlations [ k ] = self ._reddychatterji (
989+ _rotations , _ , _correlations = self ._reddychatterji (
990990 images , classes [k ], reflections [k ]
991991 )
992992
993993 # Where corr has improved
994994 # update our rolling best results with this loop.
995- improved = _correlations > correlations
996- correlations = np .where (improved , _correlations , correlations )
997- rotations = np .where (improved , _rotations , rotations )
998- shifts = np .where (improved [..., np .newaxis ], s , shifts )
995+ improved = _correlations > correlations [ k ]
996+ correlations [ k ] = np .where (improved , _correlations , correlations [ k ] )
997+ rotations [ k ] = np .where (improved , _rotations , rotations [ k ] )
998+ shifts [ k ] = np .where (improved [..., np .newaxis ], s , shifts [ k ] )
999999 logger .debug (f"Shift { s } has improved { np .sum (improved )} results" )
10001000
10011001 return rotations , shifts , correlations
0 commit comments