Skip to content

Commit 6acc7ac

Browse files
vineetbansaljunchaoxia
authored andcommitted
some simplifications
1 parent f6bd250 commit 6acc7ac

File tree

1 file changed

+21
-35
lines changed

1 file changed

+21
-35
lines changed

src/aspire/source/xform.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,13 @@ class Multiply(SymmetricXform):
123123
def __init__(self, factor):
124124
"""
125125
Initialize a Multiply Xform using specified factors
126-
:param factor: An ndarray of scalar factors to use for amplitude multiplication.
126+
:param factor: A float/int or an ndarray of scalar factors to use for amplitude multiplication.
127127
"""
128128
super().__init__()
129129
self.multipliers = np.array(factor)
130130

131131
def _forward(self, im, indices):
132-
if self.multipliers.size == 1:
132+
if self.multipliers.size == 1: # if we have a scalar multiplier
133133
im_new = im * self.multipliers
134134
else:
135135
im_new = im * self.multipliers[indices]
@@ -138,12 +138,8 @@ def _forward(self, im, indices):
138138

139139
def __str__(self):
140140
if self.multipliers.size == 1:
141-
str_out = (self.__class__.__name__ + ' by same number of '
142-
+ str(self.multipliers))
143-
else:
144-
str_out = self.__class__.__name__ + ' by difference numbers'
145-
146-
return str_out
141+
return f'Multiply ({self.multipliers})'
142+
return 'Multiply (multiple)'
147143

148144

149145
class Shift(LinearXform):
@@ -154,36 +150,31 @@ class Shift(LinearXform):
154150
def __init__(self, shifts):
155151
"""
156152
Initialize a Shift Xform using a Numpy array of shift values.
157-
:param shifts: An ndarray of shape (n, 2)
153+
:param shifts: An ndarray of shape (2) or (n, 2)
158154
"""
159155
super().__init__()
160156
self.shifts = np.array(shifts)
161-
self.n = shifts.shape[0]
162157

163158
def _forward(self, im, indices):
164-
if self.n == 1:
159+
if self.shifts.ndim == 1:
165160
im_new = im.shift(self.shifts)
166161
else:
167162
im_new = im.shift(self.shifts[indices])
168163

169164
return im_new
170165

171166
def _adjoint(self, im, indices):
172-
if self.n == 1:
167+
if self.shifts.ndim == 1:
173168
im_new = im.shift(-self.shifts)
174169
else:
175170
im_new = im.shift(-self.shifts[indices])
176171

177172
return im_new
178173

179174
def __str__(self):
180-
if self.n == 1:
181-
str_out = (self.__class__.__name__ + ' by same number of '
182-
+ str(self.shifts))
183-
else:
184-
str_out = self.__class__.__name__ + ' by difference numbers'
185-
186-
return str_out
175+
if self.shifts.ndim == 1:
176+
return f'Shift ({self.shifts})'
177+
return 'Shift (multiple)'
187178

188179

189180
class Downsample(LinearXform):
@@ -198,12 +189,11 @@ def _forward(self, im, indices):
198189
return im.downsample(self.resolution)
199190

200191
def _adjoint(self, im, indices):
201-
# TODO: Implement upsampling with zero-padding
192+
# TODO: Implement up-sampling with zero-padding
202193
raise NotImplementedError('Adjoint of downsampling not implemented yet.')
203194

204195
def __str__(self):
205-
return (self.__class__.__name__ + ' at resolution of '
206-
+ str(self.resolution))
196+
return f'Downsample (Resolution {self.resolution})'
207197

208198

209199
class FilterXform(SymmetricXform):
@@ -222,8 +212,7 @@ def _forward(self, im, indices):
222212
return im.filter(self.filter)
223213

224214
def __str__(self):
225-
return (self.__class__.__name__ + ' with filter of '
226-
+ str(self.filter))
215+
return f'FilterXform ({self.filter})'
227216

228217

229218
class Add(Xform):
@@ -234,7 +223,7 @@ class Add(Xform):
234223
def __init__(self, addend):
235224
"""
236225
Initialize an Add Xform using a Numpy array of predefined values.
237-
:param addend: An ndarray of shape (n,)
226+
:param addend: An float/int or an ndarray of shape (n,)
238227
"""
239228
super().__init__()
240229
self.addend = np.array(addend)
@@ -249,12 +238,8 @@ def _forward(self, im, indices):
249238

250239
def __str__(self):
251240
if self.addend.size == 1:
252-
str_out = (self.__class__.__name__ + ' with same number of '
253-
+ str(self.addend))
254-
else:
255-
str_out = self.__class__.__name__ + ' with different numbers'
256-
257-
return str_out
241+
return f'Add ({self.addend})'
242+
return 'Add (multiple)'
258243

259244

260245
class FlipXform(Xform):
@@ -282,8 +267,7 @@ def _forward(self, im, indices):
282267
return Image(im_out)
283268

284269
def __str__(self):
285-
return (self.__class__.__name__ + ' with filters of '
286-
+ str(self.filters[0]))
270+
return f'FlipXform ({self.filters[0]})'
287271

288272

289273
class LambdaXform(Xform):
@@ -310,8 +294,7 @@ def _forward(self, im, indices):
310294
return Image(im_out)
311295

312296
def __str__(self):
313-
return (self.__class__.__name__ + ' with function of '
314-
+ self.lambda_fun.__name__)
297+
return f'LambdaXform ({self.lambda_fun.__name__})'
315298

316299

317300
class NoiseAdder(Xform):
@@ -445,6 +428,9 @@ def __init__(self, xforms=None, memory=None):
445428
self.memory = memory
446429
self.active = True
447430

431+
def __str__(self):
432+
return '\n'.join([f'{xform}' for xform in self.xforms])
433+
448434
def add_xform(self, xform):
449435
"""
450436
Add a single `Xform` object at the end of the pipeline.

0 commit comments

Comments
 (0)