diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 6b0089d5c2e5..b9718e67f279 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -112,7 +112,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
         self.fir_kernel = fir_kernel
         self.out_channels = out_channels
 
-    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `upsample_2d()` followed by `Conv2d()`.
 
         Args:
@@ -151,34 +151,46 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
             convW = weight.shape[3]
             inC = weight.shape[1]
 
-            p = (kernel.shape[0] - factor) - (convW - 1)
+            pad_value = (kernel.shape[0] - factor) - (convW - 1)
             stride = (factor, factor)
             # Determine data dimensions.
-            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+            output_shape = (
+                (hidden_states.shape[2] - 1) * factor + convH,
+                (hidden_states.shape[3] - 1) * factor + convW,
+            )
             output_padding = (
-                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
-                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
+                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
             )
             assert output_padding[0] >= 0 and output_padding[1] >= 0
             inC = weight.shape[1]
-            num_groups = x.shape[1] // inC
+            num_groups = hidden_states.shape[1] // inC
 
             # Transpose weights.
             weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
             weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
             weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
 
-            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+            inverse_conv = F.conv_transpose2d(
+                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+            )
 
-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+            output = upfirdn2d_native(
+                inverse_conv,
+                torch.tensor(kernel, device=inverse_conv.device),
+                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
+            )
         else:
-            p = kernel.shape[0] - factor
-            x = upfirdn2d_native(
-                x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+            pad_value = kernel.shape[0] - factor
+            output = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                up=factor,
+                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
             )
 
-        return x
+        return output
 
     def forward(self, hidden_states):
         if self.use_conv:
@@ -200,7 +212,7 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
         self.use_conv = use_conv
         self.out_channels = out_channels
 
-    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `Conv2d()` followed by `downsample_2d()`.
 
         Args:
@@ -232,20 +244,29 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
         if self.use_conv:
             _, _, convH, convW = weight.shape
-            p = (kernel.shape[0] - factor) + (convW - 1)
-            s = [factor, factor]
-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
-            x = F.conv2d(x, weight, stride=s, padding=0)
+            pad_value = (kernel.shape[0] - factor) + (convW - 1)
+            stride_value = [factor, factor]
+            upfirdn_input = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )
+            hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
         else:
-            p = kernel.shape[0] - factor
-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+            pad_value = kernel.shape[0] - factor
+            hidden_states = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                down=factor,
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )
 
-        return x
+        return hidden_states
 
     def forward(self, hidden_states):
         if self.use_conv:
-            hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
-            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
@@ -332,17 +353,17 @@ def __init__(
         if self.use_in_shortcut:
             self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
-    def forward(self, x, temb):
-        hidden_states = x
+    def forward(self, input_tensor, temb):
+        hidden_states = input_tensor
 
         hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
 
         if self.upsample is not None:
-            x = self.upsample(x)
+            input_tensor = self.upsample(input_tensor)
             hidden_states = self.upsample(hidden_states)
         elif self.downsample is not None:
-            x = self.downsample(x)
+            input_tensor = self.downsample(input_tensor)
             hidden_states = self.downsample(hidden_states)
 
         hidden_states = self.conv1(hidden_states)
@@ -358,19 +379,19 @@ def forward(self, x, temb):
         hidden_states = self.conv2(hidden_states)
 
         if self.conv_shortcut is not None:
-            x = self.conv_shortcut(x)
+            input_tensor = self.conv_shortcut(input_tensor)
 
-        out = (x + hidden_states) / self.output_scale_factor
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
 
-        return out
+        return output_tensor
 
 
 class Mish(torch.nn.Module):
-    def forward(self, x):
-        return x * torch.tanh(torch.nn.functional.softplus(x))
+    def forward(self, hidden_states):
+        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
 
 
-def upsample_2d(x, kernel=None, factor=2, gain=1):
+def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Upsample2D a batch of 2D images with the given filter.
 
     Args:
@@ -397,11 +418,16 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
     kernel /= torch.sum(kernel)
 
     kernel = kernel * (gain * (factor**2))
-    p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+    pad_value = kernel.shape[0] - factor
+    return upfirdn2d_native(
+        hidden_states,
+        kernel.to(device=hidden_states.device),
+        up=factor,
+        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+    )
 
 
-def downsample_2d(x, kernel=None, factor=2, gain=1):
+def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Downsample2D a batch of 2D images with the given filter.
 
     Args:
@@ -429,8 +455,10 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
     kernel /= torch.sum(kernel)
 
     kernel = kernel * gain
-    p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+    pad_value = kernel.shape[0] - factor
+    return upfirdn2d_native(
+        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+    )
 
 
 def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -441,6 +469,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
     _, channel, in_h, in_w = input.shape
     input = input.reshape(-1, in_h, in_w, 1)
 
+    # Rename this variable (`input`); it shadows a builtin. sonarlint(python:S5806)
     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape
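
For reviewers who want to confirm that this PR only renames variables and reflows call sites, the two module-level helpers can be smoke-tested on a dummy tensor. A minimal sketch, assuming the import path matches this file at the time of the change (diffusers.models.resnet) and the default FIR kernel defined earlier in the file (not shown in these hunks); the padding arithmetic above makes the spatial size scale by exactly `factor` regardless of kernel length, so the asserts hold either way:

import torch

from diffusers.models.resnet import downsample_2d, upsample_2d

# A small batch of 8-channel, 4x4 feature maps.
hidden_states = torch.randn(2, 8, 4, 4)

# upsample_2d: zero-insert by `factor`, then FIR-filter with padding
# chosen so that H and W exactly double.
up = upsample_2d(hidden_states, factor=2)
assert up.shape == (2, 8, 8, 8)

# downsample_2d: FIR-filter, then keep every `factor`-th sample so H and W halve.
down = downsample_2d(hidden_states, factor=2)
assert down.shape == (2, 8, 2, 2)

The output values are identical before and after the patch; only the argument name (`x` to `hidden_states`) and the formatting of the `upfirdn2d_native` calls changed.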