@@ -123,13 +123,13 @@ class Multiply(SymmetricXform):
123123 def __init__ (self , factor ):
124124 """
125125 Initialize a Multiply Xform using specified factors
126- :param factor: An ndarray of scalar factors to use for amplitude multiplication.
126+ :param factor: A float/int or an ndarray of scalar factors to use for amplitude multiplication.
127127 """
128128 super ().__init__ ()
129129 self .multipliers = np .array (factor )
130130
131131 def _forward (self , im , indices ):
132- if self .multipliers .size == 1 :
132+ if self .multipliers .size == 1 : # if we have a scalar multiplier
133133 im_new = im * self .multipliers
134134 else :
135135 im_new = im * self .multipliers [indices ]
@@ -138,12 +138,8 @@ def _forward(self, im, indices):
138138
139139 def __str__ (self ):
140140 if self .multipliers .size == 1 :
141- str_out = (self .__class__ .__name__ + ' by same number of '
142- + str (self .multipliers ))
143- else :
144- str_out = self .__class__ .__name__ + ' by difference numbers'
145-
146- return str_out
141+ return f'Multiply ({ self .multipliers } )'
142+ return 'Multiply (multiple)'
147143
148144
149145class Shift (LinearXform ):
@@ -154,36 +150,31 @@ class Shift(LinearXform):
154150 def __init__ (self , shifts ):
155151 """
156152 Initialize a Shift Xform using a Numpy array of shift values.
157- :param shifts: An ndarray of shape (n, 2)
153+ :param shifts: An ndarray of shape (2) or ( n, 2)
158154 """
159155 super ().__init__ ()
160156 self .shifts = np .array (shifts )
161- self .n = shifts .shape [0 ]
162157
163158 def _forward (self , im , indices ):
164- if self .n == 1 :
159+ if self .shifts . ndim == 1 :
165160 im_new = im .shift (self .shifts )
166161 else :
167162 im_new = im .shift (self .shifts [indices ])
168163
169164 return im_new
170165
171166 def _adjoint (self , im , indices ):
172- if self .n == 1 :
167+ if self .shifts . ndim == 1 :
173168 im_new = im .shift (- self .shifts )
174169 else :
175170 im_new = im .shift (- self .shifts [indices ])
176171
177172 return im_new
178173
179174 def __str__ (self ):
180- if self .n == 1 :
181- str_out = (self .__class__ .__name__ + ' by same number of '
182- + str (self .shifts ))
183- else :
184- str_out = self .__class__ .__name__ + ' by difference numbers'
185-
186- return str_out
175+ if self .shifts .ndim == 1 :
176+ return f'Shift ({ self .shifts } )'
177+ return 'Shift (multiple)'
187178
188179
189180class Downsample (LinearXform ):
@@ -198,12 +189,11 @@ def _forward(self, im, indices):
198189 return im .downsample (self .resolution )
199190
200191 def _adjoint (self , im , indices ):
201- # TODO: Implement upsampling with zero-padding
192+ # TODO: Implement up-sampling with zero-padding
202193 raise NotImplementedError ('Adjoint of downsampling not implemented yet.' )
203194
204195 def __str__ (self ):
205- return (self .__class__ .__name__ + ' at resolution of '
206- + str (self .resolution ))
196+ return f'Downsample (Resolution { self .resolution } )'
207197
208198
209199class FilterXform (SymmetricXform ):
@@ -222,8 +212,7 @@ def _forward(self, im, indices):
222212 return im .filter (self .filter )
223213
224214 def __str__ (self ):
225- return (self .__class__ .__name__ + ' with filter of '
226- + str (self .filter ))
215+ return f'FilterXform ({ self .filter } )'
227216
228217
229218class Add (Xform ):
@@ -234,7 +223,7 @@ class Add(Xform):
234223 def __init__ (self , addend ):
235224 """
236225 Initialize an Add Xform using a Numpy array of predefined values.
237- :param addend: An ndarray of shape (n,)
226+ :param addend: An float/int or an ndarray of shape (n,)
238227 """
239228 super ().__init__ ()
240229 self .addend = np .array (addend )
@@ -249,12 +238,8 @@ def _forward(self, im, indices):
249238
250239 def __str__ (self ):
251240 if self .addend .size == 1 :
252- str_out = (self .__class__ .__name__ + ' with same number of '
253- + str (self .addend ))
254- else :
255- str_out = self .__class__ .__name__ + ' with different numbers'
256-
257- return str_out
241+ return f'Add ({ self .addend } )'
242+ return 'Add (multiple)'
258243
259244
260245class FlipXform (Xform ):
@@ -282,8 +267,7 @@ def _forward(self, im, indices):
282267 return Image (im_out )
283268
284269 def __str__ (self ):
285- return (self .__class__ .__name__ + ' with filters of '
286- + str (self .filters [0 ]))
270+ return f'FlipXform ({ self .filters [0 ]} )'
287271
288272
289273class LambdaXform (Xform ):
@@ -310,8 +294,7 @@ def _forward(self, im, indices):
310294 return Image (im_out )
311295
312296 def __str__ (self ):
313- return (self .__class__ .__name__ + ' with function of '
314- + self .lambda_fun .__name__ )
297+ return f'LambdaXform ({ self .lambda_fun .__name__ } )'
315298
316299
317300class NoiseAdder (Xform ):
@@ -445,6 +428,9 @@ def __init__(self, xforms=None, memory=None):
445428 self .memory = memory
446429 self .active = True
447430
431+ def __str__ (self ):
432+ return '\n ' .join ([f'{ xform } ' for xform in self .xforms ])
433+
448434 def add_xform (self , xform ):
449435 """
450436 Add a single `Xform` object at the end of the pipeline.
0 commit comments