@@ -127,7 +127,7 @@ class FlaxResnetBlock2D(nn.Module):
127127
128128 in_channels : int
129129 out_channels : int = None
130- dropout : float = 0.0
130+ dropout_prob : float = 0.0
131131 use_nin_shortcut : bool = None
132132 dtype : jnp .dtype = jnp .float32
133133
@@ -144,7 +144,7 @@ def setup(self):
144144 )
145145
146146 self .norm2 = nn .GroupNorm (num_groups = 32 , epsilon = 1e-6 )
147- self .dropout_layer = nn .Dropout (self .dropout )
147+ self .dropout = nn .Dropout (self .dropout_prob )
148148 self .conv2 = nn .Conv (
149149 out_channels ,
150150 kernel_size = (3 , 3 ),
@@ -173,7 +173,7 @@ def __call__(self, hidden_states, deterministic=True):
173173
174174 hidden_states = self .norm2 (hidden_states )
175175 hidden_states = nn .swish (hidden_states )
176- hidden_states = self .dropout_layer (hidden_states , deterministic )
176+ hidden_states = self .dropout (hidden_states , deterministic )
177177 hidden_states = self .conv2 (hidden_states )
178178
179179 if self .conv_shortcut is not None :
@@ -284,7 +284,7 @@ def setup(self):
284284 res_block = FlaxResnetBlock2D (
285285 in_channels = in_channels ,
286286 out_channels = self .out_channels ,
287- dropout = self .dropout ,
287+ dropout_prob = self .dropout ,
288288 dtype = self .dtype ,
289289 )
290290 resnets .append (res_block )
@@ -335,7 +335,7 @@ def setup(self):
335335 res_block = FlaxResnetBlock2D (
336336 in_channels = in_channels ,
337337 out_channels = self .out_channels ,
338- dropout = self .dropout ,
338+ dropout_prob = self .dropout ,
339339 dtype = self .dtype ,
340340 )
341341 resnets .append (res_block )
@@ -383,7 +383,7 @@ def setup(self):
383383 FlaxResnetBlock2D (
384384 in_channels = self .in_channels ,
385385 out_channels = self .in_channels ,
386- dropout = self .dropout ,
386+ dropout_prob = self .dropout ,
387387 dtype = self .dtype ,
388388 )
389389 ]
@@ -399,7 +399,7 @@ def setup(self):
399399 res_block = FlaxResnetBlock2D (
400400 in_channels = self .in_channels ,
401401 out_channels = self .in_channels ,
402- dropout = self .dropout ,
402+ dropout_prob = self .dropout ,
403403 dtype = self .dtype ,
404404 )
405405 resnets .append (res_block )
0 commit comments