Skip to content

Commit cc93aab

Browse files
committed
Improve Multiply, Shift, and Add Xform for issue #233
1 parent 06f1719 commit cc93aab

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

src/aspire/source/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,10 @@ def invert_contrast(self, batch_size=512):
416416

417417
if signal_mean < noise_mean:
418418
logger.info('Need to invert contrast')
419-
scale_factor = -1.0 * np.ones(self.n)
419+
scale_factor = np.array(-1.0)
420420
else:
421421
logger.info('No need to invert contrast')
422-
scale_factor = 1.0 * np.ones(self.n)
422+
scale_factor = np.array(1.0)
423423

424424
logger.info('Adding Scaling Xform to end of generation pipeline')
425425
self.generation_pipeline.add_xform(Multiply(scale_factor))

src/aspire/source/xform.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,21 @@ def __init__(self, factor):
133133
self.multipliers = factor
134134

135135
def _forward(self, im, indices):
136-
return im * self.multipliers[indices]
136+
if self.multipliers.size == 1:
137+
return im * self.multipliers
138+
else:
139+
return im * self.multipliers[indices]
140+
141+
def __str__(self):
142+
"""
143+
Show class name and related scaling information
144+
:return: A string of class name and related information
145+
"""
146+
if self.multipliers.size == 1:
147+
return (self.__class__.__name__ + ' by same number of '
148+
+ str(self.multipliers))
149+
else:
150+
return self.__class__.__name__ + ' by difference numbers'
137151

138152

139153
class Shift(LinearXform):
@@ -151,10 +165,27 @@ def __init__(self, shifts):
151165
self.n = shifts.shape[0]
152166

153167
def _forward(self, im, indices):
154-
return im.shift(self.shifts[indices])
168+
if self.n == 1:
169+
return im.shift(self.shifts)
170+
else:
171+
return im.shift(self.shifts[indices])
155172

156173
def _adjoint(self, im, indices):
157-
return im.shift(-self.shifts[indices])
174+
if self.n == 1:
175+
return im.shift(-self.shifts)
176+
else:
177+
return im.shift(-self.shifts[indices])
178+
179+
def __str__(self):
180+
"""
181+
Show class name and related shift information
182+
:return: A string of class name and related information
183+
"""
184+
if self.n == 1:
185+
return (self.__class__.__name__ + ' by same number of '
186+
+ str(self.shifts))
187+
else:
188+
return self.__class__.__name__ + ' by difference numbers'
158189

159190

160191
class Downsample(LinearXform):
@@ -219,7 +250,21 @@ def __init__(self, addend):
219250
self.addend = addend
220251

221252
def _forward(self, im, indices):
222-
return im + self.addend[indices]
253+
if self.addend.size == 1:
254+
return im + self.addend
255+
else:
256+
return im + self.addend[indices]
257+
258+
def __str__(self):
259+
"""
260+
Show class name and related Add information
261+
:return: A string of class name and related information
262+
"""
263+
if self.addend.size == 1:
264+
return (self.__class__.__name__ + ' with same number of '
265+
+ str(self.addend))
266+
else:
267+
return self.__class__.__name__ + ' with different numbers'
223268

224269

225270
class FlipXform(Xform):

0 commit comments

Comments
 (0)