@@ -23,7 +23,7 @@ class Averager2D(ABC):
2323
2424 def __init__ (self , composite_basis , source , batch_size = 512 , dtype = None ):
2525 """
26- :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/ FFB2D)
26+ :param composite_basis: Basis to be used during class average composition (eg FFB2D)
2727 :param source: Source of original images.
2828 :param dtype: Numpy dtype to be used during alignment.
2929 """
@@ -67,8 +67,8 @@ def average(
6767
6868 Should return an Image source of synthetic class averages.
6969
70- :param classes: class indices (refering to src). (n_img, n_nbor)
71- :param reflections: Bool representing whether to reflect image in `classes`
70+ :param classes: class indices (refering to src). (n_img, n_nbor).
71+ :param reflections: Bool representing whether to reflect image in `classes`.
7272 :coefs: Optional basis coefs (could avoid recomputing).
7373 :return: Stack of Synthetic Class Average images as Image instance.
7474 """
@@ -106,7 +106,7 @@ def __init__(
106106 """
107107 :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D)
108108 :param source: Source of original images.
109- :param alignment_basis: Optional, basis to be used during alignment (eg FSPCA)
109+ :param alignment_basis: Optional, basis to be used only during alignment (eg FSPCA)
110110 :param dtype: Numpy dtype to be used during alignment.
111111 """
112112
@@ -541,13 +541,9 @@ def _reddychatterji(self, images, class_k, reflection_k):
541541
542542 # Result arrays
543543 M = len (images )
544- rotations_k = np .empty (M , dtype = self .dtype )
545- correlations_k = np .empty (M , dtype = self .dtype )
546- shifts_k = np .empty ((M , 2 ), dtype = self .dtype )
547- # Initialize for Image 0, others will populate in loop.
548- rotations_k [0 ] = 0
549- correlations_k [0 ] = 0
550- shifts_k [0 ] = 0
544+ rotations_k = np .zeros (M , dtype = self .dtype )
545+ correlations_k = np .zeros (M , dtype = self .dtype )
546+ shifts_k = np .zeros ((M , 2 ), dtype = int )
551547
552548 # De-Mean, note images is mutated and should be a `copy`.
553549 images -= images .mean (axis = (- 1 , - 2 ))[:, np .newaxis , np .newaxis ]
@@ -964,7 +960,7 @@ def align(self, classes, reflections, basis_coefficients):
964960
965961 # We'll brute force all shifts in a grid.
966962 g = grid_2d (L , normalized = False )
967- disc = g ["r" ] <= L // 8 # make param later
963+ disc = g ["r" ] <= L // self . radius
968964 X , Y = g ["x" ][disc ], g ["y" ][disc ]
969965
970966 for k in trange (n_classes ):
0 commit comments