@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
             Output channels
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm.
         use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
             Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
     in_channels: int
     out_channels: int = None
     dropout: float = 0.0
+    groups: int = 32
     use_nin_shortcut: bool = None
     dtype: jnp.dtype = jnp.float32

     def setup(self):
         out_channels = self.in_channels if self.out_channels is None else self.out_channels

-        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.conv1 = nn.Conv(
             out_channels,
             kernel_size=(3, 3),
@@ -143,7 +146,7 @@ def setup(self):
             dtype=self.dtype,
         )

-        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.dropout_layer = nn.Dropout(self.dropout)
         self.conv2 = nn.Conv(
             out_channels,
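For readers skimming the diff, here is a minimal sketch of what the new `groups` field changes; the input shape, the 16-group value, and the `deterministic` keyword are illustrative assumptions, not part of this change:

```python
import jax
import jax.numpy as jnp

# Hypothetical usage: the group count must divide the channel count evenly,
# so 16 groups over 64 channels is valid where the old hard-coded 32 also was.
block = FlaxResnetBlock2D(in_channels=64, out_channels=64, groups=16)
params = block.init(
    jax.random.PRNGKey(0),
    jnp.zeros((1, 8, 8, 64)),  # NHWC input; shape chosen for illustration
    deterministic=True,        # disable dropout so no dropout RNG is needed
)
```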
@@ -191,20 +194,23 @@ class FlaxAttentionBlock(nn.Module):
             Input channels
         num_head_channels (:obj:`int`, *optional*, defaults to `None`):
             Number of attention heads
+        num_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`

     """
     channels: int
     num_head_channels: int = None
+    num_groups: int = 32
     dtype: jnp.dtype = jnp.float32

     def setup(self):
         self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

         dense = partial(nn.Dense, self.channels, dtype=self.dtype)

-        self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
         self.query, self.key, self.value = dense(), dense(), dense()
         self.proj_attn = dense()

@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_downsample: bool = True
     dtype: jnp.dtype = jnp.float32

@@ -285,6 +294,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -303,9 +313,9 @@ def __call__(self, hidden_states, deterministic=True):
         return hidden_states


-class FlaxUpEncoderBlock2D(nn.Module):
+class FlaxUpDecoderBlock2D(nn.Module):
     r"""
-    Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

     Parameters:
         in_channels (:obj:`int`):
@@ -316,15 +326,18 @@ class FlaxUpEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
-        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
-            Whether to add downsample layer
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add upsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
     in_channels: int
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_upsample: bool = True
     dtype: jnp.dtype = jnp.float32

@@ -336,6 +349,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet and Attention block group norm
         attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
             Number of attention heads for each attention block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     attn_num_head_channels: int = 1
     dtype: jnp.dtype = jnp.float32

     def setup(self):
+        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
         # there is always at least one resnet
         resnets = [
             FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
         ]
@@ -392,14 +412,18 @@ def setup(self):

         for _ in range(self.num_layers):
             attn_block = FlaxAttentionBlock(
-                channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
+                channels=self.in_channels,
+                num_head_channels=self.attn_num_head_channels,
+                num_groups=resnet_groups,
+                dtype=self.dtype,
             )
             attentions.append(attn_block)

             res_block = FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
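The fallback added to `FlaxUNetMidBlock2D.setup` above derives a group count from the channel width when `resnet_groups` is `None`. Restated as a standalone helper (the name `default_groups` is ours) to make the behavior concrete:

```python
def default_groups(in_channels: int) -> int:
    # Mirrors the diff's fallback: min(self.in_channels // 4, 32)
    return min(in_channels // 4, 32)

assert default_groups(512) == 32  # wide mid blocks keep the usual 32 groups
assert default_groups(64) == 16   # narrow blocks get a smaller, divisor-safe count
```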
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
             Tuple containing the number of output channels for each block
         layers_per_block (:obj:`int`, *optional*, defaults to `2`):
             Number of Resnet layer for each block
-        norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
+        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             norm num group
         act_fn (:obj:`str`, *optional*, defaults to `silu`):
             Activation function
@@ -483,6 +507,7 @@ def setup(self):
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block,
+                resnet_groups=self.norm_num_groups,
                 add_downsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -491,12 +516,15 @@ def setup(self):

         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )

         # end
         conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             conv_out_channels,
             kernel_size=(3, 3),
@@ -581,7 +609,10 @@ def setup(self):

         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )

         # upsampling
@@ -594,10 +625,11 @@ def setup(self):

             is_final_block = i == len(block_out_channels) - 1

-            up_block = FlaxUpEncoderBlock2D(
+            up_block = FlaxUpDecoderBlock2D(
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block + 1,
+                resnet_groups=self.norm_num_groups,
                 add_upsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -607,7 +639,7 @@ def setup(self):
         self.up_blocks = up_blocks

         # end
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
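Taken together, the change threads a single `norm_num_groups` value from the encoder and decoder down into every `GroupNorm` they construct. A hedged smoke test; constructor fields and the `__call__` signature beyond what the diff shows are assumptions about the surrounding file:

```python
import jax
import jax.numpy as jnp

# Hypothetical check that a non-default group count reaches every GroupNorm:
# 8 divides both 32 and 64, so initialization should succeed end to end.
encoder = FlaxEncoder(
    in_channels=3,
    out_channels=4,
    block_out_channels=(32, 64),
    layers_per_block=1,
    norm_num_groups=8,  # the new knob exercised by this change
    double_z=True,
)
sample = jnp.zeros((1, 32, 32, 3))
params = encoder.init(jax.random.PRNGKey(0), sample, deterministic=True)
```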