Skip to content

Commit 35716dd

Browse files
committed
Revert "replace dropout_prob by dropout in vae"
This reverts commit cd7fb4f.
1 parent 825248b commit 35716dd

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/diffusers/models/vae_flax.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class FlaxResnetBlock2D(nn.Module):
127127

128128
in_channels: int
129129
out_channels: int = None
130-
dropout: float = 0.0
130+
dropout_prob: float = 0.0
131131
use_nin_shortcut: bool = None
132132
dtype: jnp.dtype = jnp.float32
133133

@@ -144,7 +144,7 @@ def setup(self):
144144
)
145145

146146
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
147-
self.dropout_layer = nn.Dropout(self.dropout)
147+
self.dropout = nn.Dropout(self.dropout_prob)
148148
self.conv2 = nn.Conv(
149149
out_channels,
150150
kernel_size=(3, 3),
@@ -173,7 +173,7 @@ def __call__(self, hidden_states, deterministic=True):
173173

174174
hidden_states = self.norm2(hidden_states)
175175
hidden_states = nn.swish(hidden_states)
176-
hidden_states = self.dropout_layer(hidden_states, deterministic)
176+
hidden_states = self.dropout(hidden_states, deterministic)
177177
hidden_states = self.conv2(hidden_states)
178178

179179
if self.conv_shortcut is not None:
@@ -284,7 +284,7 @@ def setup(self):
284284
res_block = FlaxResnetBlock2D(
285285
in_channels=in_channels,
286286
out_channels=self.out_channels,
287-
dropout=self.dropout,
287+
dropout_prob=self.dropout,
288288
dtype=self.dtype,
289289
)
290290
resnets.append(res_block)
@@ -335,7 +335,7 @@ def setup(self):
335335
res_block = FlaxResnetBlock2D(
336336
in_channels=in_channels,
337337
out_channels=self.out_channels,
338-
dropout=self.dropout,
338+
dropout_prob=self.dropout,
339339
dtype=self.dtype,
340340
)
341341
resnets.append(res_block)
@@ -383,7 +383,7 @@ def setup(self):
383383
FlaxResnetBlock2D(
384384
in_channels=self.in_channels,
385385
out_channels=self.in_channels,
386-
dropout=self.dropout,
386+
dropout_prob=self.dropout,
387387
dtype=self.dtype,
388388
)
389389
]
@@ -399,7 +399,7 @@ def setup(self):
399399
res_block = FlaxResnetBlock2D(
400400
in_channels=self.in_channels,
401401
out_channels=self.in_channels,
402-
dropout=self.dropout,
402+
dropout_prob=self.dropout,
403403
dtype=self.dtype,
404404
)
405405
resnets.append(res_block)

0 commit comments

Comments
 (0)