Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
else:
self.Conv2d_0 = conv

def forward(self, x):
assert x.shape[1] == self.channels
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
return self.conv(hidden_states)

x = F.interpolate(x, scale_factor=2.0, mode="nearest")
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
x = self.conv(x)
hidden_states = self.conv(hidden_states)
else:
x = self.Conv2d_0(x)
hidden_states = self.Conv2d_0(hidden_states)

return x
return hidden_states


class Downsample2D(nn.Module):
Expand Down Expand Up @@ -84,16 +84,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
else:
self.conv = conv

def forward(self, x):
assert x.shape[1] == self.channels
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

assert x.shape[1] == self.channels
x = self.conv(x)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)

return x
return hidden_states


class FirUpsample2D(nn.Module):
Expand Down Expand Up @@ -174,12 +174,12 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):

return x

def forward(self, x):
def forward(self, hidden_states):
if self.use_conv:
height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

return height

Expand Down Expand Up @@ -236,14 +236,14 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):

return x

def forward(self, x):
def forward(self, hidden_states):
if self.use_conv:
x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

return x
return hidden_states


class ResnetBlock2D(nn.Module):
Expand Down