Skip to content

Commit a7058f4

Browse files
authored
Renamed x -> hidden_states in resnet.py (#676)
renamed x to hidden_states
1 parent 3dacbb9 commit a7058f4

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

src/diffusers/models/resnet.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,21 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
3434
else:
3535
self.Conv2d_0 = conv
3636

37-
def forward(self, x):
38-
assert x.shape[1] == self.channels
37+
def forward(self, hidden_states):
38+
assert hidden_states.shape[1] == self.channels
3939
if self.use_conv_transpose:
40-
return self.conv(x)
40+
return self.conv(hidden_states)
4141

42-
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
42+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
4343

4444
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
4545
if self.use_conv:
4646
if self.name == "conv":
47-
x = self.conv(x)
47+
hidden_states = self.conv(hidden_states)
4848
else:
49-
x = self.Conv2d_0(x)
49+
hidden_states = self.Conv2d_0(hidden_states)
5050

51-
return x
51+
return hidden_states
5252

5353

5454
class Downsample2D(nn.Module):
@@ -84,16 +84,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
8484
else:
8585
self.conv = conv
8686

87-
def forward(self, x):
88-
assert x.shape[1] == self.channels
87+
def forward(self, hidden_states):
88+
assert hidden_states.shape[1] == self.channels
8989
if self.use_conv and self.padding == 0:
9090
pad = (0, 1, 0, 1)
91-
x = F.pad(x, pad, mode="constant", value=0)
91+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
9292

93-
assert x.shape[1] == self.channels
94-
x = self.conv(x)
93+
assert hidden_states.shape[1] == self.channels
94+
hidden_states = self.conv(hidden_states)
9595

96-
return x
96+
return hidden_states
9797

9898

9999
class FirUpsample2D(nn.Module):
@@ -174,12 +174,12 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
174174

175175
return x
176176

177-
def forward(self, x):
177+
def forward(self, hidden_states):
178178
if self.use_conv:
179-
height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
179+
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
180180
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
181181
else:
182-
height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
182+
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
183183

184184
return height
185185

@@ -236,14 +236,14 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
236236

237237
return x
238238

239-
def forward(self, x):
239+
def forward(self, hidden_states):
240240
if self.use_conv:
241-
x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
242-
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
241+
hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
242+
hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
243243
else:
244-
x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
244+
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
245245

246-
return x
246+
return hidden_states
247247

248248

249249
class ResnetBlock2D(nn.Module):

0 commit comments

Comments
 (0)