diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 52f01552c528..24c3b07e7cb6 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -95,9 +95,9 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= assert self.channels == self.out_channels self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - def forward(self, x): - assert x.shape[1] == self.channels - return self.conv(x) + def forward(self, inputs): + assert inputs.shape[1] == self.channels + return self.conv(inputs) class Upsample2D(nn.Module): @@ -431,13 +431,13 @@ def __init__(self, pad_mode="reflect"): self.pad = kernel_1d.shape[1] // 2 - 1 self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - def forward(self, x): - x = F.pad(x, (self.pad,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) - indices = torch.arange(x.shape[1], device=x.device) - kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) + def forward(self, inputs): + inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode) + weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) weight[indices, indices] = kernel - return F.conv2d(x, weight, stride=2) + return F.conv2d(inputs, weight, stride=2) class KUpsample2D(nn.Module): @@ -448,13 +448,13 @@ def __init__(self, pad_mode="reflect"): self.pad = kernel_1d.shape[1] // 2 - 1 self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - def forward(self, x): - x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) - indices = torch.arange(x.shape[1], device=x.device) - kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) + def forward(self, inputs): + inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(inputs.shape[1], device=inputs.device) + kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1) weight[indices, indices] = kernel - return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) + return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1) class ResnetBlock2D(nn.Module): @@ -664,13 +664,13 @@ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): self.group_norm = nn.GroupNorm(n_groups, out_channels) self.mish = nn.Mish() - def forward(self, x): - x = self.conv1d(x) - x = rearrange_dims(x) - x = self.group_norm(x) - x = rearrange_dims(x) - x = self.mish(x) - return x + def forward(self, inputs): + intermediate_repr = self.conv1d(inputs) + intermediate_repr = rearrange_dims(intermediate_repr) + intermediate_repr = self.group_norm(intermediate_repr) + intermediate_repr = rearrange_dims(intermediate_repr) + output = self.mish(intermediate_repr) + return output # unet_rl.py @@ -687,10 +687,10 @@ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() ) - def forward(self, x, t): + def forward(self, inputs, t): """ Args: - x : [ batch_size x inp_channels x horizon ] + inputs : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: @@ -698,9 +698,9 @@ def forward(self, x, t): """ t = self.time_emb_act(t) t = self.time_emb(t) - out = self.conv_in(x) + rearrange_dims(t) + out = self.conv_in(inputs) + rearrange_dims(t) out = self.conv_out(out) - return out + self.residual_conv(x) + return out + self.residual_conv(inputs) def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):