From 54265efeba71ed969b44e26a9217f18349af59cf Mon Sep 17 00:00:00 2001 From: i-am-epic Date: Thu, 29 Sep 2022 22:17:19 +0530 Subject: [PATCH 1/8] renamed single letter variables --- src/diffusers/models/resnet.py | 86 +++++++++++++++++----------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 97f3c02a8ccf..dcbb191e37e4 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -34,21 +34,21 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, x): - assert x.shape[1] == self.channels + def forward(self, input_tensor): + assert input_tensor.shape[1] == self.channels if self.use_conv_transpose: - return self.conv(x) + return self.conv(input_tensor) - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + upsample_input = F.interpolate(input_tensor, scale_factor=2.0, mode="nearest") # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - x = self.conv(x) + output_tensor = self.conv(upsample_input) else: - x = self.Conv2d_0(x) + output_tensor = self.Conv2d_0(upsample_input) - return x + return output_tensor class Downsample2D(nn.Module): @@ -84,16 +84,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= else: self.conv = conv - def forward(self, x): - assert x.shape[1] == self.channels + def forward(self, input_tensor): + assert input_tensor.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) + padded_input = F.pad(input_tensor, pad, mode="constant", value=0) - assert x.shape[1] == self.channels - x = self.conv(x) + assert padded_input.shape[1] == self.channels + output_tensor = self.conv(padded_input) - return x + return output_tensor class FirUpsample2D(nn.Module): @@ -106,7 +106,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= self.fir_kernel = fir_kernel self.out_channels = out_channels - def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1): """Fused `upsample_2d()` followed by `Conv2d()`. Args: @@ -149,30 +149,30 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): stride = (factor, factor) # Determine data dimensions. - output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) + output_shape = ((input_tensor.shape[2] - 1) * factor + convH, (input_tensor.shape[3] - 1) * factor + convW) output_padding = ( - output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, + output_shape[0] - (input_tensor.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (input_tensor.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 inC = weight.shape[1] - num_groups = x.shape[1] // inC + num_groups = input_tensor.shape[1] // inC # Transpose weights. 
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0) + inverse_conv = F.conv_transpose2d(input_tensor, weight, stride=stride, output_padding=output_padding, padding=0) - x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + output = upfirdn2d_native(inverse_conv, torch.tensor(kernel, device=inverse_conv.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) else: p = kernel.shape[0] - factor - x = upfirdn2d_native( - x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + output = upfirdn2d_native( + input_tensor, torch.tensor(kernel, device=input_tensor.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) ) - return x + return output def forward(self, x): if self.use_conv: @@ -236,14 +236,14 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): return x - def forward(self, x): + def forward(self, input_tensor): if self.use_conv: - x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + downsample_input = self._downsample_2d(input_tensor, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + output_tensor = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: - x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) + output_tensor = self._downsample_2d(input_tensor, kernel=self.fir_kernel, factor=2) - return x + return output_tensor class ResnetBlock2D(nn.Module): @@ -326,8 +326,8 @@ def __init__( if self.use_in_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - def forward(self, x, temb): - hidden_states = x + def forward(self, input_tensor, temb): + hidden_states = input_tensor # make sure hidden states is in float32 # when running in half-precision @@ -335,10 +335,10 @@ def forward(self, x, temb): hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: - x = self.upsample(x) + sample_input = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: - x = self.downsample(x) + sample_input = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) @@ -356,19 +356,19 @@ def forward(self, x, temb): hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - x = self.conv_shortcut(x) + sample_conv_input = self.conv_shortcut(sample_input) - out = (x + hidden_states) / self.output_scale_factor + output_tensor = (sample_conv_input + hidden_states) / self.output_scale_factor - return out + return output_tensor class Mish(torch.nn.Module): - def forward(self, x): - return x * torch.tanh(torch.nn.functional.softplus(x)) + def forward(self, input_tensor): + return input_tensor * torch.tanh(torch.nn.functional.softplus(input_tensor)) -def upsample_2d(x, kernel=None, factor=2, gain=1): +def upsample_2d(input_tensor, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. 
Args: @@ -395,11 +395,11 @@ def upsample_2d(x, kernel=None, factor=2, gain=1): kernel /= torch.sum(kernel) kernel = kernel * (gain * (factor**2)) - p = kernel.shape[0] - factor - return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + pad_value = kernel.shape[0] - factor + return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)) -def downsample_2d(x, kernel=None, factor=2, gain=1): +def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): r"""Downsample2D a batch of 2D images with the given filter. Args: @@ -427,8 +427,8 @@ def downsample_2d(x, kernel=None, factor=2, gain=1): kernel /= torch.sum(kernel) kernel = kernel * gain - p = kernel.shape[0] - factor - return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + pad_value = kernel.shape[0] - factor + return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)) def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): From e25c501359a339e0de3051c8f4ab4f7e638a67df Mon Sep 17 00:00:00 2001 From: NIKHIL A V <58301643+i-am-epic@users.noreply.github.com> Date: Thu, 29 Sep 2022 22:36:05 +0530 Subject: [PATCH 2/8] renamed x to meaningful variable in resnet.py Hello @patil-suraj can you verify it Thanks --- src/diffusers/models/resnet.py | 104 ++++++++++++++++----------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 97f3c02a8ccf..0eb29873bf2c 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -34,21 +34,21 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, x): - assert x.shape[1] == self.channels + def forward(self, input_tensor): + assert input_tensor.shape[1] == self.channels if self.use_conv_transpose: - return self.conv(x) + return self.conv(input_tensor) - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + upsample_input = F.interpolate(input_tensor, scale_factor=2.0, mode="nearest") # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - x = self.conv(x) + output_tensor = self.conv(upsample_input) else: - x = self.Conv2d_0(x) + output_tensor = self.Conv2d_0(upsample_input) - return x + return output_tensor class Downsample2D(nn.Module): @@ -84,16 +84,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= else: self.conv = conv - def forward(self, x): - assert x.shape[1] == self.channels + def forward(self, input_tensor): + assert input_tensor.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) + padded_input = F.pad(input_tensor, pad, mode="constant", value=0) - assert x.shape[1] == self.channels - x = self.conv(x) + assert padded_input.shape[1] == self.channels + output_tensor = self.conv(padded_input) - return x + return output_tensor class FirUpsample2D(nn.Module): @@ -106,7 +106,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= self.fir_kernel = fir_kernel self.out_channels = out_channels - def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1): """Fused 
`upsample_2d()` followed by `Conv2d()`. Args: @@ -149,37 +149,37 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): stride = (factor, factor) # Determine data dimensions. - output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) + output_shape = ((input_tensor.shape[2] - 1) * factor + convH, (input_tensor.shape[3] - 1) * factor + convW) output_padding = ( - output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (x.shape[3] - 1) * stride[1] - convW, + output_shape[0] - (input_tensor.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (input_tensor.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 inC = weight.shape[1] - num_groups = x.shape[1] // inC + num_groups = input_tensor.shape[1] // inC # Transpose weights. weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0) + inverse_conv = F.conv_transpose2d(input_tensor, weight, stride=stride, output_padding=output_padding, padding=0) - x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + output = upfirdn2d_native(inverse_conv, torch.tensor(kernel, device=inverse_conv.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) else: p = kernel.shape[0] - factor - x = upfirdn2d_native( - x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + output = upfirdn2d_native( + input_tensor, torch.tensor(kernel, device=input_tensor.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) ) - return x + return output - def forward(self, x): + def forward(self, input_tensor): if self.use_conv: - height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = self._upsample_2d(input_tensor, self.Conv2d_0.weight, kernel=self.fir_kernel) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: - height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2) + height = self._upsample_2d(input_tensor, kernel=self.fir_kernel, factor=2) return height @@ -194,7 +194,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= self.use_conv = use_conv self.out_channels = out_channels - def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): + def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1): """Fused `Conv2d()` followed by `downsample_2d()`. 
Args: @@ -226,24 +226,24 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1): if self.use_conv: _, _, convH, convW = weight.shape - p = (kernel.shape[0] - factor) + (convW - 1) + pad_value = (kernel.shape[0] - factor) + (convW - 1) s = [factor, factor] - x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2)) - x = F.conv2d(x, weight, stride=s, padding=0) + upfirdn_input = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), pad=((pad_value + 1) // 2, pad_value // 2)) + output_tensor = F.conv2d(upfirdn_input, weight, stride=s, padding=0) else: p = kernel.shape[0] - factor - x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + output_tensor = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2)) - return x + return output_tensor - def forward(self, x): + def forward(self, input_tensor): if self.use_conv: - x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + downsample_input = self._downsample_2d(input_tensor, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + output_tensor = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: - x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) + output_tensor = self._downsample_2d(input_tensor, kernel=self.fir_kernel, factor=2) - return x + return output_tensor class ResnetBlock2D(nn.Module): @@ -326,8 +326,8 @@ def __init__( if self.use_in_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - def forward(self, x, temb): - hidden_states = x + def forward(self, input_tensor, temb): + hidden_states = input_tensor # make sure hidden states is in float32 # when running in half-precision @@ -335,10 +335,10 @@ def forward(self, x, temb): hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: - x = self.upsample(x) + sample_input = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: - x = self.downsample(x) + sample_input = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) @@ -356,19 +356,19 @@ def forward(self, x, temb): hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - x = self.conv_shortcut(x) + sample_conv_input = self.conv_shortcut(sample_input) - out = (x + hidden_states) / self.output_scale_factor + output_tensor = (sample_conv_input + hidden_states) / self.output_scale_factor - return out + return output_tensor class Mish(torch.nn.Module): - def forward(self, x): - return x * torch.tanh(torch.nn.functional.softplus(x)) + def forward(self, input_tensor): + return input_tensor * torch.tanh(torch.nn.functional.softplus(input_tensor)) -def upsample_2d(x, kernel=None, factor=2, gain=1): +def upsample_2d(input_tensor, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. 
Args: @@ -395,11 +395,11 @@ def upsample_2d(x, kernel=None, factor=2, gain=1): kernel /= torch.sum(kernel) kernel = kernel * (gain * (factor**2)) - p = kernel.shape[0] - factor - return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + pad_value = kernel.shape[0] - factor + return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)) -def downsample_2d(x, kernel=None, factor=2, gain=1): +def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): r"""Downsample2D a batch of 2D images with the given filter. Args: @@ -427,8 +427,8 @@ def downsample_2d(x, kernel=None, factor=2, gain=1): kernel /= torch.sum(kernel) kernel = kernel * gain - p = kernel.shape[0] - factor - return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + pad_value = kernel.shape[0] - factor + return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)) def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): From 19a13d0d200e551f116aba6f371bede7cd695b6a Mon Sep 17 00:00:00 2001 From: NIKHIL A V <58301643+i-am-epic@users.noreply.github.com> Date: Fri, 30 Sep 2022 01:58:22 +0530 Subject: [PATCH 3/8] Reformatted using black --- src/diffusers/models/resnet.py | 46 +++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 5690bf34d4a0..22783e42b542 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -34,7 +34,6 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, input_tensor): assert input_tensor.shape[1] == self.channels if self.use_conv_transpose: @@ -42,11 +41,9 @@ def forward(self, input_tensor): upsample_input = F.interpolate(input_tensor, scale_factor=2.0, mode="nearest") - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - output_tensor = self.conv(upsample_input) else: output_tensor = self.Conv2d_0(upsample_input) @@ -54,7 +51,6 @@ def forward(self, input_tensor): return output_tensor - class Downsample2D(nn.Module): """ A downsampling layer with an optional convolution. 
@@ -88,7 +84,6 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= else: self.conv = conv - def forward(self, input_tensor): assert input_tensor.shape[1] == self.channels if self.use_conv and self.padding == 0: @@ -101,7 +96,6 @@ def forward(self, input_tensor): return output_tensor - class FirUpsample2D(nn.Module): def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() @@ -169,18 +163,26 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - inverse_conv = F.conv_transpose2d(input_tensor, weight, stride=stride, output_padding=output_padding, padding=0) + inverse_conv = F.conv_transpose2d( + input_tensor, weight, stride=stride, output_padding=output_padding, padding=0 + ) - output = upfirdn2d_native(inverse_conv, torch.tensor(kernel, device=inverse_conv.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + output = upfirdn2d_native( + inverse_conv, + torch.tensor(kernel, device=inverse_conv.device), + pad=((p + 1) // 2 + factor - 1, p // 2 + 1), + ) else: p = kernel.shape[0] - factor output = upfirdn2d_native( - input_tensor, torch.tensor(kernel, device=input_tensor.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + input_tensor, + torch.tensor(kernel, device=input_tensor.device), + up=factor, + pad=((p + 1) // 2 + factor - 1, p // 2), ) return output - def forward(self, input_tensor): if self.use_conv: height = self._upsample_2d(input_tensor, self.Conv2d_0.weight, kernel=self.fir_kernel) @@ -188,7 +190,6 @@ def forward(self, input_tensor): else: height = self._upsample_2d(input_tensor, kernel=self.fir_kernel, factor=2) - return height @@ -236,15 +237,20 @@ def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain= _, _, convH, convW = weight.shape pad_value = (kernel.shape[0] - factor) + (convW - 1) s = [factor, factor] - upfirdn_input = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), pad=((pad_value + 1) // 2, pad_value // 2)) + upfirdn_input = upfirdn2d_native( + input_tensor, + torch.tensor(kernel, device=input_tensor.device), + pad=((pad_value + 1) // 2, pad_value // 2), + ) output_tensor = F.conv2d(upfirdn_input, weight, stride=s, padding=0) else: p = kernel.shape[0] - factor - output_tensor = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2)) + output_tensor = upfirdn2d_native( + input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2) + ) return output_tensor - def forward(self, input_tensor): if self.use_conv: downsample_input = self._downsample_2d(input_tensor, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) @@ -255,7 +261,6 @@ def forward(self, input_tensor): return output_tensor - class ResnetBlock2D(nn.Module): def __init__( self, @@ -406,7 +411,12 @@ def upsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * (gain * (factor**2)) pad_value = kernel.shape[0] - factor - return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)) + return upfirdn2d_native( + input_tensor, + kernel.to(device=input_tensor.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) def downsample_2d(input_tensor, kernel=None, 
factor=2, gain=1): @@ -438,7 +448,9 @@ def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * gain pad_value = kernel.shape[0] - factor - return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)) + return upfirdn2d_native( + input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) + ) def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): From 9d23c6f9d93f82921cc1db990dd1ba9a45b2d67b Mon Sep 17 00:00:00 2001 From: NIKHIL A V <58301643+i-am-epic@users.noreply.github.com> Date: Thu, 29 Sep 2022 22:36:05 +0530 Subject: [PATCH 4/8] renamed x to meaningful variable in resnet.py Hello @patil-suraj can you verify it Thanks --- setup.py | 13 ++++--------- src/diffusers/models/resnet.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 20c9ea61f5f2..a965e7bfa12f 100644 --- a/setup.py +++ b/setup.py @@ -177,14 +177,7 @@ def run(self): extras["docs"] = deps_list("hf-doc-builder") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["test"] = deps_list( - "datasets", - "onnxruntime-gpu", - "pytest", - "pytest-timeout", - "pytest-xdist", - "scipy", - "torchvision", - "transformers" + "datasets", "onnxruntime-gpu", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers" ) extras["torch"] = deps_list("torch") @@ -193,7 +186,9 @@ def run(self): else: extras["flax"] = deps_list("jax", "jaxlib", "flax") -extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] +extras["dev"] = ( + extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] +) install_requires = [ deps["importlib_metadata"], diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 22783e42b542..e05e5047cb42 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -163,6 +163,7 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) +<<<<<<< HEAD inverse_conv = F.conv_transpose2d( input_tensor, weight, stride=stride, output_padding=output_padding, padding=0 ) @@ -183,6 +184,19 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) return output +======= + inverse_conv = F.conv_transpose2d(input_tensor, weight, stride=stride, output_padding=output_padding, padding=0) + + output = upfirdn2d_native(inverse_conv, torch.tensor(kernel, device=inverse_conv.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + else: + p = kernel.shape[0] - factor + output = upfirdn2d_native( + input_tensor, torch.tensor(kernel, device=input_tensor.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) + ) + + return output + +>>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) def forward(self, input_tensor): if self.use_conv: height = self._upsample_2d(input_tensor, self.Conv2d_0.weight, kernel=self.fir_kernel) @@ -237,6 +251,7 @@ def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain= _, _, convH, convW = weight.shape pad_value = (kernel.shape[0] - factor) + (convW - 1) s = [factor, factor] +<<<<<<< HEAD upfirdn_input = upfirdn2d_native( input_tensor, torch.tensor(kernel, 
device=input_tensor.device), @@ -248,6 +263,13 @@ def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain= output_tensor = upfirdn2d_native( input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2) ) +======= + upfirdn_input = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), pad=((pad_value + 1) // 2, pad_value // 2)) + output_tensor = F.conv2d(upfirdn_input, weight, stride=s, padding=0) + else: + p = kernel.shape[0] - factor + output_tensor = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2)) +>>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) return output_tensor @@ -411,12 +433,16 @@ def upsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * (gain * (factor**2)) pad_value = kernel.shape[0] - factor +<<<<<<< HEAD return upfirdn2d_native( input_tensor, kernel.to(device=input_tensor.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) +======= + return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)) +>>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): @@ -448,9 +474,13 @@ def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * gain pad_value = kernel.shape[0] - factor +<<<<<<< HEAD return upfirdn2d_native( input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) ) +======= + return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)) +>>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): From f9b060bb514346f562fc57d72d1e5a1651800057 Mon Sep 17 00:00:00 2001 From: i-am-epic Date: Fri, 30 Sep 2022 02:16:37 +0530 Subject: [PATCH 5/8] reformatted the files --- src/diffusers/models/resnet.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index e05e5047cb42..22783e42b542 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -163,7 +163,6 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) -<<<<<<< HEAD inverse_conv = F.conv_transpose2d( input_tensor, weight, stride=stride, output_padding=output_padding, padding=0 ) @@ -184,19 +183,6 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) return output -======= - inverse_conv = F.conv_transpose2d(input_tensor, weight, stride=stride, output_padding=output_padding, padding=0) - - output = upfirdn2d_native(inverse_conv, torch.tensor(kernel, device=inverse_conv.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) - else: - p = kernel.shape[0] - factor - output = upfirdn2d_native( - input_tensor, torch.tensor(kernel, device=input_tensor.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) - ) - - return output - ->>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) def forward(self, input_tensor): if self.use_conv: height = self._upsample_2d(input_tensor, self.Conv2d_0.weight, kernel=self.fir_kernel) @@ 
-251,7 +237,6 @@ def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain= _, _, convH, convW = weight.shape pad_value = (kernel.shape[0] - factor) + (convW - 1) s = [factor, factor] -<<<<<<< HEAD upfirdn_input = upfirdn2d_native( input_tensor, torch.tensor(kernel, device=input_tensor.device), @@ -263,13 +248,6 @@ def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain= output_tensor = upfirdn2d_native( input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2) ) -======= - upfirdn_input = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), pad=((pad_value + 1) // 2, pad_value // 2)) - output_tensor = F.conv2d(upfirdn_input, weight, stride=s, padding=0) - else: - p = kernel.shape[0] - factor - output_tensor = upfirdn2d_native(input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2)) ->>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) return output_tensor @@ -433,16 +411,12 @@ def upsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * (gain * (factor**2)) pad_value = kernel.shape[0] - factor -<<<<<<< HEAD return upfirdn2d_native( input_tensor, kernel.to(device=input_tensor.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) -======= - return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)) ->>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): @@ -474,13 +448,9 @@ def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * gain pad_value = kernel.shape[0] - factor -<<<<<<< HEAD return upfirdn2d_native( input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) ) -======= - return upfirdn2d_native(input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)) ->>>>>>> e25c501 (renamed x to meaningful variable in resnet.py) def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): From 29eebc742909002921930aba699166d5d4da6820 Mon Sep 17 00:00:00 2001 From: i-am-epic Date: Fri, 30 Sep 2022 02:36:11 +0530 Subject: [PATCH 6/8] modified unboundlocalerror in line 374 --- src/diffusers/models/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 22783e42b542..e119c580a064 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -371,9 +371,9 @@ def forward(self, input_tensor, temb): hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - sample_conv_input = self.conv_shortcut(sample_input) + sample_input = self.conv_shortcut(sample_input) - output_tensor = (sample_conv_input + hidden_states) / self.output_scale_factor + output_tensor = (sample_input + hidden_states) / self.output_scale_factor return output_tensor From deeb9af7a0f668b76a54c9b2553629799d2c1038 Mon Sep 17 00:00:00 2001 From: i-am-epic Date: Fri, 30 Sep 2022 02:59:20 +0530 Subject: [PATCH 7/8] removed referenced before error --- src/diffusers/models/resnet.py | 39 ++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index e119c580a064..0aab8deae5b7 100644 --- 
a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -39,14 +39,14 @@ def forward(self, input_tensor): if self.use_conv_transpose: return self.conv(input_tensor) - upsample_input = F.interpolate(input_tensor, scale_factor=2.0, mode="nearest") + output_tensor = F.interpolate(input_tensor, scale_factor=2.0, mode="nearest") # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - output_tensor = self.conv(upsample_input) + output_tensor = self.conv(output_tensor) else: - output_tensor = self.Conv2d_0(upsample_input) + output_tensor = self.Conv2d_0(output_tensor) return output_tensor @@ -88,10 +88,10 @@ def forward(self, input_tensor): assert input_tensor.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) - padded_input = F.pad(input_tensor, pad, mode="constant", value=0) + input_tensor = F.pad(input_tensor, pad, mode="constant", value=0) - assert padded_input.shape[1] == self.channels - output_tensor = self.conv(padded_input) + assert input_tensor.shape[1] == self.channels + output_tensor = self.conv(input_tensor) return output_tensor @@ -145,7 +145,7 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) convW = weight.shape[3] inC = weight.shape[1] - p = (kernel.shape[0] - factor) - (convW - 1) + pad_value = (kernel.shape[0] - factor) - (convW - 1) stride = (factor, factor) # Determine data dimensions. @@ -170,15 +170,15 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) output = upfirdn2d_native( inverse_conv, torch.tensor(kernel, device=inverse_conv.device), - pad=((p + 1) // 2 + factor - 1, p // 2 + 1), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), ) else: - p = kernel.shape[0] - factor + pad_value = kernel.shape[0] - factor output = upfirdn2d_native( input_tensor, torch.tensor(kernel, device=input_tensor.device), up=factor, - pad=((p + 1) // 2 + factor - 1, p // 2), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) return output @@ -236,17 +236,20 @@ def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain= if self.use_conv: _, _, convH, convW = weight.shape pad_value = (kernel.shape[0] - factor) + (convW - 1) - s = [factor, factor] + stride_value = [factor, factor] upfirdn_input = upfirdn2d_native( input_tensor, torch.tensor(kernel, device=input_tensor.device), pad=((pad_value + 1) // 2, pad_value // 2), ) - output_tensor = F.conv2d(upfirdn_input, weight, stride=s, padding=0) + output_tensor = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) else: - p = kernel.shape[0] - factor + pad_value = kernel.shape[0] - factor output_tensor = upfirdn2d_native( - input_tensor, torch.tensor(kernel, device=input_tensor.device), down=factor, pad=((p + 1) // 2, p // 2) + input_tensor, + torch.tensor(kernel, device=input_tensor.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), ) return output_tensor @@ -350,10 +353,10 @@ def forward(self, input_tensor, temb): hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: - sample_input = self.upsample(input_tensor) + input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: - sample_input = self.downsample(input_tensor) + input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) @@ -371,9 +374,9 @@ def forward(self, input_tensor, temb): 
hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - sample_input = self.conv_shortcut(sample_input) + input_tensor = self.conv_shortcut(input_tensor) - output_tensor = (sample_input + hidden_states) / self.output_scale_factor + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor From 9191c2f8b0d1ca01690562d07b6aeb6516a87b05 Mon Sep 17 00:00:00 2001 From: Nikhil A V Date: Sat, 1 Oct 2022 13:02:49 +0530 Subject: [PATCH 8/8] renamed single variable x -> hidden_state, p-> pad_value --- src/diffusers/models/resnet.py | 92 ++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 0aab8deae5b7..e06d28cca016 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -34,21 +34,21 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, input_tensor): - assert input_tensor.shape[1] == self.channels + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: - return self.conv(input_tensor) + return self.conv(hidden_states) - output_tensor = F.interpolate(input_tensor, scale_factor=2.0, mode="nearest") + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - output_tensor = self.conv(output_tensor) + hidden_states = self.conv(hidden_states) else: - output_tensor = self.Conv2d_0(output_tensor) + hidden_states = self.Conv2d_0(hidden_states) - return output_tensor + return hidden_states class Downsample2D(nn.Module): @@ -84,16 +84,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= else: self.conv = conv - def forward(self, input_tensor): - assert input_tensor.shape[1] == self.channels + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) - input_tensor = F.pad(input_tensor, pad, mode="constant", value=0) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - assert input_tensor.shape[1] == self.channels - output_tensor = self.conv(input_tensor) + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) - return output_tensor + return hidden_states class FirUpsample2D(nn.Module): @@ -106,7 +106,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= self.fir_kernel = fir_kernel self.out_channels = out_channels - def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1): + def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): """Fused `upsample_2d()` followed by `Conv2d()`. Args: @@ -149,14 +149,17 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) stride = (factor, factor) # Determine data dimensions. 
- output_shape = ((input_tensor.shape[2] - 1) * factor + convH, (input_tensor.shape[3] - 1) * factor + convW) + output_shape = ( + (hidden_states.shape[2] - 1) * factor + convH, + (hidden_states.shape[3] - 1) * factor + convW, + ) output_padding = ( - output_shape[0] - (input_tensor.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (input_tensor.shape[3] - 1) * stride[1] - convW, + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 inC = weight.shape[1] - num_groups = input_tensor.shape[1] // inC + num_groups = hidden_states.shape[1] // inC # Transpose weights. weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) @@ -164,7 +167,7 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) inverse_conv = F.conv_transpose2d( - input_tensor, weight, stride=stride, output_padding=output_padding, padding=0 + hidden_states, weight, stride=stride, output_padding=output_padding, padding=0 ) output = upfirdn2d_native( @@ -175,20 +178,20 @@ def _upsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1) else: pad_value = kernel.shape[0] - factor output = upfirdn2d_native( - input_tensor, - torch.tensor(kernel, device=input_tensor.device), + hidden_states, + torch.tensor(kernel, device=hidden_states.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) return output - def forward(self, input_tensor): + def forward(self, hidden_states): if self.use_conv: - height = self._upsample_2d(input_tensor, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: - height = self._upsample_2d(input_tensor, kernel=self.fir_kernel, factor=2) + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) return height @@ -203,7 +206,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel= self.use_conv = use_conv self.out_channels = out_channels - def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain=1): + def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): """Fused `Conv2d()` followed by `downsample_2d()`. 
Args: @@ -238,30 +241,30 @@ def _downsample_2d(self, input_tensor, weight=None, kernel=None, factor=2, gain= pad_value = (kernel.shape[0] - factor) + (convW - 1) stride_value = [factor, factor] upfirdn_input = upfirdn2d_native( - input_tensor, - torch.tensor(kernel, device=input_tensor.device), + hidden_states, + torch.tensor(kernel, device=hidden_states.device), pad=((pad_value + 1) // 2, pad_value // 2), ) - output_tensor = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) else: pad_value = kernel.shape[0] - factor - output_tensor = upfirdn2d_native( - input_tensor, - torch.tensor(kernel, device=input_tensor.device), + hidden_states = upfirdn2d_native( + hidden_states, + torch.tensor(kernel, device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2), ) - return output_tensor + return hidden_states - def forward(self, input_tensor): + def forward(self, hidden_states): if self.use_conv: - downsample_input = self._downsample_2d(input_tensor, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - output_tensor = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: - output_tensor = self._downsample_2d(input_tensor, kernel=self.fir_kernel, factor=2) + hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - return output_tensor + return hidden_states class ResnetBlock2D(nn.Module): @@ -382,11 +385,11 @@ def forward(self, input_tensor, temb): class Mish(torch.nn.Module): - def forward(self, input_tensor): - return input_tensor * torch.tanh(torch.nn.functional.softplus(input_tensor)) + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -def upsample_2d(input_tensor, kernel=None, factor=2, gain=1): +def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. Args: @@ -415,14 +418,14 @@ def upsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * (gain * (factor**2)) pad_value = kernel.shape[0] - factor return upfirdn2d_native( - input_tensor, - kernel.to(device=input_tensor.device), + hidden_states, + kernel.to(device=hidden_states.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) -def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): +def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Downsample2D a batch of 2D images with the given filter. Args: @@ -452,7 +455,7 @@ def downsample_2d(input_tensor, kernel=None, factor=2, gain=1): kernel = kernel * gain pad_value = kernel.shape[0] - factor return upfirdn2d_native( - input_tensor, kernel.to(device=input_tensor.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) + hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) ) @@ -464,6 +467,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): _, channel, in_h, in_w = input.shape input = input.reshape(-1, in_h, in_w, 1) + # Rename this variable (input); it shadows a builtin.sonarlint(python:S5806) _, in_h, in_w, minor = input.shape kernel_h, kernel_w = kernel.shape
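
The trailing sonarlint comment above flags that the `input` parameter of `upfirdn2d_native` shadows the Python builtin. None of the eight patches actually performs that rename, so here is a minimal sketch of what it could look like, limited to the prologue lines visible in the diff (the name `tensor` and the helper name `upfirdn2d_prologue` are illustrative assumptions, not part of this patch series):

    import torch

    def upfirdn2d_prologue(tensor, kernel):
        # Same opening steps as upfirdn2d_native, with the parameter renamed
        # from `input` to `tensor` so the builtin input() is no longer
        # shadowed (the issue sonarlint reports as python:S5806).
        _, channel, in_h, in_w = tensor.shape
        tensor = tensor.reshape(-1, in_h, in_w, 1)   # (N*C, H, W, 1)
        _, in_h, in_w, minor = tensor.shape
        kernel_h, kernel_w = kernel.shape
        return tensor, (channel, minor, kernel_h, kernel_w)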
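
The `p` -> `pad_value` rename also makes the FIR padding arithmetic easier to check by hand. A worked example for the standalone `upsample_2d`/`downsample_2d` paths, assuming the default `fir_kernel=(1, 3, 3, 1)` and `factor=2` used throughout this file:

    # Padding tuples passed to upfirdn2d_native for the default FIR kernel.
    kernel_len = 4                    # len((1, 3, 3, 1))
    factor = 2
    pad_value = kernel_len - factor   # 2

    # upsample_2d uses pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)
    up_pad = ((pad_value + 1) // 2 + factor - 1, pad_value // 2)
    # downsample_2d uses pad=((pad_value + 1) // 2, pad_value // 2)
    down_pad = ((pad_value + 1) // 2, pad_value // 2)

    print(up_pad)    # (2, 1)
    print(down_pad)  # (1, 1)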
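
Patches 6 and 7 address a referenced-before-assignment bug that the earlier renames introduced in `ResnetBlock2D.forward`: `sample_input` was only bound inside the `upsample`/`downsample` branches but was used unconditionally afterwards, so a plain forward pass raised UnboundLocalError. A stripped-down reproduction of just that control flow (illustrative only; the real method also runs norms, convolutions, and the time embedding):

    def forward_buggy(input_tensor, upsample=None):
        if upsample is not None:
            sample_input = upsample(input_tensor)  # bound only on this path
        return sample_input + 1  # UnboundLocalError when upsample is None

    def forward_fixed(input_tensor, upsample=None):
        if upsample is not None:
            input_tensor = upsample(input_tensor)  # rebind the parameter instead
        return input_tensor + 1  # always bound

    print(forward_fixed(3))  # 4
    # forward_buggy(3)       # raises UnboundLocalError

Rebinding `input_tensor` itself, as patch 7 does, keeps the skip connection defined on every path.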
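
Finally, the renamed `Mish` module still computes `hidden_states * torch.tanh(softplus(hidden_states))`. On PyTorch 1.9+ this should match the built-in `torch.nn.functional.mish`, which gives a quick sanity check that the rename changed nothing numerically (a sketch, assuming a recent enough PyTorch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4)
    manual = x * torch.tanh(F.softplus(x))    # the Mish module above
    print(torch.allclose(manual, F.mish(x)))  # True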