@@ -89,7 +89,7 @@ def __call__(self, hidden_states):
8989class FlaxResnetBlock2D (nn .Module ):
9090 in_channels : int
9191 out_channels : int = None
92- dropout_prob : float = 0.0
92+ dropout : float = 0.0
9393 use_nin_shortcut : bool = None
9494 dtype : jnp .dtype = jnp .float32
9595
@@ -106,7 +106,7 @@ def setup(self):
106106 )
107107
108108 self .norm2 = nn .GroupNorm (num_groups = 32 , epsilon = 1e-6 )
109- self .dropout = nn .Dropout (self .dropout_prob )
109+ self .dropout_layer = nn .Dropout (self .dropout )
110110 self .conv2 = nn .Conv (
111111 out_channels ,
112112 kernel_size = (3 , 3 ),
@@ -135,7 +135,7 @@ def __call__(self, hidden_states, deterministic=True):
135135
136136 hidden_states = self .norm2 (hidden_states )
137137 hidden_states = nn .swish (hidden_states )
138- hidden_states = self .dropout (hidden_states , deterministic )
138+ hidden_states = self .dropout_layer (hidden_states , deterministic )
139139 hidden_states = self .conv2 (hidden_states )
140140
141141 if self .conv_shortcut is not None :
@@ -217,7 +217,7 @@ def setup(self):
217217 res_block = FlaxResnetBlock2D (
218218 in_channels = in_channels ,
219219 out_channels = self .out_channels ,
220- dropout_prob = self .dropout ,
220+ dropout = self .dropout ,
221221 dtype = self .dtype ,
222222 )
223223 resnets .append (res_block )
@@ -251,7 +251,7 @@ def setup(self):
251251 res_block = FlaxResnetBlock2D (
252252 in_channels = in_channels ,
253253 out_channels = self .out_channels ,
254- dropout_prob = self .dropout ,
254+ dropout = self .dropout ,
255255 dtype = self .dtype ,
256256 )
257257 resnets .append (res_block )
@@ -284,7 +284,7 @@ def setup(self):
284284 FlaxResnetBlock2D (
285285 in_channels = self .in_channels ,
286286 out_channels = self .in_channels ,
287- dropout_prob = self .dropout ,
287+ dropout = self .dropout ,
288288 dtype = self .dtype ,
289289 )
290290 ]
@@ -300,7 +300,7 @@ def setup(self):
300300 res_block = FlaxResnetBlock2D (
301301 in_channels = self .in_channels ,
302302 out_channels = self .in_channels ,
303- dropout_prob = self .dropout ,
303+ dropout = self .dropout ,
304304 dtype = self .dtype ,
305305 )
306306 resnets .append (res_block )
0 commit comments