@@ -112,7 +112,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
         self.fir_kernel = fir_kernel
         self.out_channels = out_channels

-    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `upsample_2d()` followed by `Conv2d()`.

         Args:
@@ -151,34 +151,46 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
             convW = weight.shape[3]
             inC = weight.shape[1]

-            p = (kernel.shape[0] - factor) - (convW - 1)
+            pad_value = (kernel.shape[0] - factor) - (convW - 1)

             stride = (factor, factor)
             # Determine data dimensions.
-            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+            output_shape = (
+                (hidden_states.shape[2] - 1) * factor + convH,
+                (hidden_states.shape[3] - 1) * factor + convW,
+            )
             output_padding = (
-                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
-                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
+                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
             )
             assert output_padding[0] >= 0 and output_padding[1] >= 0
             inC = weight.shape[1]
-            num_groups = x.shape[1] // inC
+            num_groups = hidden_states.shape[1] // inC

             # Transpose weights.
             weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
             weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
             weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

-            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+            inverse_conv = F.conv_transpose2d(
+                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+            )

-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+            output = upfirdn2d_native(
+                inverse_conv,
+                torch.tensor(kernel, device=inverse_conv.device),
+                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
+            )
         else:
-            p = kernel.shape[0] - factor
-            x = upfirdn2d_native(
-                x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+            pad_value = kernel.shape[0] - factor
+            output = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                up=factor,
+                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
             )

-        return x
+        return output

     def forward(self, hidden_states):
         if self.use_conv:
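Aside: in the non-fused `else` branch above, the upsample is zero-stuffing followed by FIR filtering, and the padding arithmetic fully determines the output size. A minimal sketch of that arithmetic, assuming the library's default `fir_kernel=(1, 3, 3, 1)` and `factor=2` (the kernel setup mirrors the module-level `upsample_2d` further down):

    import torch

    kernel = torch.tensor([1.0, 3.0, 3.0, 1.0])
    kernel = torch.outer(kernel, kernel)    # separable 1D taps -> 4x4 filter
    kernel /= torch.sum(kernel)             # normalize to unit DC gain

    factor, gain = 2, 1
    kernel = kernel * (gain * (factor**2))  # compensate for zero-stuffing
    pad_value = kernel.shape[0] - factor    # 4 - 2 = 2
    pad = ((pad_value + 1) // 2 + factor - 1, pad_value // 2)  # (2, 1)
    # upfirdn2d_native(hidden_states, kernel, up=factor, pad=pad)
    # then maps (N, C, H, W) -> (N, C, 2H, 2W).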
@@ -200,7 +212,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
         self.use_conv = use_conv
         self.out_channels = out_channels

-    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `Conv2d()` followed by `downsample_2d()`.

         Args:
@@ -232,20 +244,29 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):

         if self.use_conv:
             _, _, convH, convW = weight.shape
-            p = (kernel.shape[0] - factor) + (convW - 1)
-            s = [factor, factor]
-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
-            x = F.conv2d(x, weight, stride=s, padding=0)
+            pad_value = (kernel.shape[0] - factor) + (convW - 1)
+            stride_value = [factor, factor]
+            upfirdn_input = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )
+            hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
         else:
-            p = kernel.shape[0] - factor
-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+            pad_value = kernel.shape[0] - factor
+            hidden_states = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                down=factor,
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )

-        return x
+        return hidden_states

     def forward(self, hidden_states):
         if self.use_conv:
-            hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
-            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

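The `else` branch of `_downsample_2d` is the mirror image: FIR-filter first, then keep every `factor`-th sample. Same sketch, same assumed kernel; note there is no `factor**2` gain correction on the way down:

    import torch

    kernel = torch.tensor([1.0, 3.0, 3.0, 1.0])
    kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    factor, gain = 2, 1
    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor          # 2
    pad = ((pad_value + 1) // 2, pad_value // 2)  # (1, 1)
    # upfirdn2d_native(hidden_states, kernel, down=factor, pad=pad)
    # maps (N, C, H, W) -> (N, C, H // 2, W // 2) for even H, W.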
@@ -332,17 +353,17 @@ def __init__(
         if self.use_in_shortcut:
             self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

-    def forward(self, x, temb):
-        hidden_states = x
+    def forward(self, input_tensor, temb):
+        hidden_states = input_tensor

         hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)

         if self.upsample is not None:
-            x = self.upsample(x)
+            input_tensor = self.upsample(input_tensor)
             hidden_states = self.upsample(hidden_states)
         elif self.downsample is not None:
-            x = self.downsample(x)
+            input_tensor = self.downsample(input_tensor)
             hidden_states = self.downsample(hidden_states)

         hidden_states = self.conv1(hidden_states)
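Worth noting for the rename above: `forward` resamples both `input_tensor` and `hidden_states` because the skip connection adds them elementwise, so the shortcut must track the main branch through any resampling. A toy shape check (hypothetical stand-in resampler and shapes; `output_scale_factor` assumed to be 1.0):

    import torch

    upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")  # stand-in
    input_tensor = torch.randn(1, 64, 32, 32)
    hidden_states = upsample(input_tensor)  # main branch -> (1, 64, 64, 64)
    input_tensor = upsample(input_tensor)   # shortcut must match
    output_tensor = (input_tensor + hidden_states) / 1.0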
@@ -358,19 +379,19 @@ def forward(self, x, temb):
         hidden_states = self.conv2(hidden_states)

         if self.conv_shortcut is not None:
-            x = self.conv_shortcut(x)
+            input_tensor = self.conv_shortcut(input_tensor)

-        out = (x + hidden_states) / self.output_scale_factor
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

-        return out
+        return output_tensor


 class Mish(torch.nn.Module):
-    def forward(self, x):
-        return x * torch.tanh(torch.nn.functional.softplus(x))
+    def forward(self, hidden_states):
+        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


-def upsample_2d(x, kernel=None, factor=2, gain=1):
+def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Upsample2D a batch of 2D images with the given filter.

     Args:
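The renamed `Mish.forward` computes `x * tanh(softplus(x))`; assuming PyTorch >= 1.9, this matches the built-in `torch.nn.functional.mish`, which gives a quick sanity check that the rename changed nothing numerically:

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    assert torch.allclose(x * torch.tanh(F.softplus(x)), F.mish(x))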
@@ -397,11 +418,16 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
         kernel /= torch.sum(kernel)

     kernel = kernel * (gain * (factor**2))
-    p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+    pad_value = kernel.shape[0] - factor
+    return upfirdn2d_native(
+        hidden_states,
+        kernel.to(device=hidden_states.device),
+        up=factor,
+        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+    )


-def downsample_2d(x, kernel=None, factor=2, gain=1):
+def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Downsample2D a batch of 2D images with the given filter.

     Args:
@@ -429,8 +455,10 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
         kernel /= torch.sum(kernel)

     kernel = kernel * gain
-    p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+    pad_value = kernel.shape[0] - factor
+    return upfirdn2d_native(
+        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+    )


 def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
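For reference, a hypothetical round trip through the two renamed module-level helpers (assuming they are imported from this module, e.g. `diffusers.models.resnet`; input is NCHW):

    import torch
    # from diffusers.models.resnet import upsample_2d, downsample_2d  # assumed path

    x = torch.randn(1, 3, 8, 8)
    up = upsample_2d(x, kernel=[1, 3, 3, 1], factor=2)       # -> (1, 3, 16, 16)
    down = downsample_2d(up, kernel=[1, 3, 3, 1], factor=2)  # -> (1, 3, 8, 8)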
@@ -441,6 +469,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):

     _, channel, in_h, in_w = input.shape
     input = input.reshape(-1, in_h, in_w, 1)
+    # Rename this variable (input); it shadows a builtin.sonarlint(python:S5806)

     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape