@@ -34,21 +34,21 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
3434 else :
3535 self .Conv2d_0 = conv
3636
37- def forward (self , x ):
38- assert x .shape [1 ] == self .channels
37+ def forward (self , hidden_states ):
38+ assert hidden_states .shape [1 ] == self .channels
3939 if self .use_conv_transpose :
40- return self .conv (x )
40+ return self .conv (hidden_states )
4141
42- x = F .interpolate (x , scale_factor = 2.0 , mode = "nearest" )
42+ hidden_states = F .interpolate (hidden_states , scale_factor = 2.0 , mode = "nearest" )
4343
4444 # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
4545 if self .use_conv :
4646 if self .name == "conv" :
47- x = self .conv (x )
47+ hidden_states = self .conv (hidden_states )
4848 else :
49- x = self .Conv2d_0 (x )
49+ hidden_states = self .Conv2d_0 (hidden_states )
5050
51- return x
51+ return hidden_states
5252
5353
5454class Downsample2D (nn .Module ):
@@ -84,16 +84,16 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
8484 else :
8585 self .conv = conv
8686
87- def forward (self , x ):
88- assert x .shape [1 ] == self .channels
87+ def forward (self , hidden_states ):
88+ assert hidden_states .shape [1 ] == self .channels
8989 if self .use_conv and self .padding == 0 :
9090 pad = (0 , 1 , 0 , 1 )
91- x = F .pad (x , pad , mode = "constant" , value = 0 )
91+ hidden_states = F .pad (hidden_states , pad , mode = "constant" , value = 0 )
9292
93- assert x .shape [1 ] == self .channels
94- x = self .conv (x )
93+ assert hidden_states .shape [1 ] == self .channels
94+ hidden_states = self .conv (hidden_states )
9595
96- return x
96+ return hidden_states
9797
9898
9999class FirUpsample2D (nn .Module ):
@@ -174,12 +174,12 @@ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
174174
175175 return x
176176
177- def forward (self , x ):
177+ def forward (self , hidden_states ):
178178 if self .use_conv :
179- height = self ._upsample_2d (x , self .Conv2d_0 .weight , kernel = self .fir_kernel )
179+ height = self ._upsample_2d (hidden_states , self .Conv2d_0 .weight , kernel = self .fir_kernel )
180180 height = height + self .Conv2d_0 .bias .reshape (1 , - 1 , 1 , 1 )
181181 else :
182- height = self ._upsample_2d (x , kernel = self .fir_kernel , factor = 2 )
182+ height = self ._upsample_2d (hidden_states , kernel = self .fir_kernel , factor = 2 )
183183
184184 return height
185185
@@ -236,14 +236,14 @@ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
236236
237237 return x
238238
239- def forward (self , x ):
239+ def forward (self , hidden_states ):
240240 if self .use_conv :
241- x = self ._downsample_2d (x , weight = self .Conv2d_0 .weight , kernel = self .fir_kernel )
242- x = x + self .Conv2d_0 .bias .reshape (1 , - 1 , 1 , 1 )
241+ hidden_states = self ._downsample_2d (hidden_states , weight = self .Conv2d_0 .weight , kernel = self .fir_kernel )
242+ hidden_states = hidden_states + self .Conv2d_0 .bias .reshape (1 , - 1 , 1 , 1 )
243243 else :
244- x = self ._downsample_2d (x , kernel = self .fir_kernel , factor = 2 )
244+ hidden_states = self ._downsample_2d (hidden_states , kernel = self .fir_kernel , factor = 2 )
245245
246- return x
246+ return hidden_states
247247
248248
249249class ResnetBlock2D (nn .Module ):
0 commit comments