@@ -34,15 +34,15 @@ class FlaxAutoencoderKLOutput(BaseOutput):
3434 Output of AutoencoderKL encoding method.
3535
3636 Args:
37- latent_dist (`DiagonalGaussianDistribution `):
38- Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution `.
39- `DiagonalGaussianDistribution ` allows for sampling latents from the distribution.
37+ latent_dist (`FlaxDiagonalGaussianDistribution `):
38+ Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution `.
39+ `FlaxDiagonalGaussianDistribution ` allows for sampling latents from the distribution.
4040 """
4141
42- latent_dist : "DiagonalGaussianDistribution "
42+ latent_dist : "FlaxDiagonalGaussianDistribution "
4343
4444
45- class Upsample2D (nn .Module ):
45+ class FlaxUpsample2D (nn .Module ):
4646 in_channels : int
4747 dtype : jnp .dtype = jnp .float32
4848
@@ -66,7 +66,7 @@ def __call__(self, hidden_states):
6666 return hidden_states
6767
6868
69- class Downsample2D (nn .Module ):
69+ class FlaxDownsample2D (nn .Module ):
7070 in_channels : int
7171 dtype : jnp .dtype = jnp .float32
7272
@@ -86,7 +86,7 @@ def __call__(self, hidden_states):
8686 return hidden_states
8787
8888
89- class ResnetBlock2D (nn .Module ):
89+ class FlaxResnetBlock2D (nn .Module ):
9090 in_channels : int
9191 out_channels : int = None
9292 dropout_prob : float = 0.0
@@ -144,7 +144,7 @@ def __call__(self, hidden_states, deterministic=True):
144144 return hidden_states + residual
145145
146146
147- class AttentionBlock (nn .Module ):
147+ class FlaxAttentionBlock (nn .Module ):
148148 channels : int
149149 num_head_channels : int = None
150150 dtype : jnp .dtype = jnp .float32
@@ -201,7 +201,7 @@ def __call__(self, hidden_states):
201201 return hidden_states
202202
203203
204- class DownEncoderBlock2D (nn .Module ):
204+ class FlaxDownEncoderBlock2D (nn .Module ):
205205 in_channels : int
206206 out_channels : int
207207 dropout : float = 0.0
@@ -214,7 +214,7 @@ def setup(self):
214214 for i in range (self .num_layers ):
215215 in_channels = self .in_channels if i == 0 else self .out_channels
216216
217- res_block = ResnetBlock2D (
217+ res_block = FlaxResnetBlock2D (
218218 in_channels = in_channels ,
219219 out_channels = self .out_channels ,
220220 dropout_prob = self .dropout ,
@@ -224,19 +224,19 @@ def setup(self):
224224 self .resnets = resnets
225225
226226 if self .add_downsample :
227- self .downsample = Downsample2D (self .out_channels , dtype = self .dtype )
227+ self .downsamplers_0 = FlaxDownsample2D (self .out_channels , dtype = self .dtype )
228228
229229 def __call__ (self , hidden_states , deterministic = True ):
230230 for resnet in self .resnets :
231231 hidden_states = resnet (hidden_states , deterministic = deterministic )
232232
233233 if self .add_downsample :
234- hidden_states = self .downsample (hidden_states )
234+ hidden_states = self .downsamplers_0 (hidden_states )
235235
236236 return hidden_states
237237
238238
239- class UpEncoderBlock2D (nn .Module ):
239+ class FlaxUpEncoderBlock2D (nn .Module ):
240240 in_channels : int
241241 out_channels : int
242242 dropout : float = 0.0
@@ -248,7 +248,7 @@ def setup(self):
248248 resnets = []
249249 for i in range (self .num_layers ):
250250 in_channels = self .in_channels if i == 0 else self .out_channels
251- res_block = ResnetBlock2D (
251+ res_block = FlaxResnetBlock2D (
252252 in_channels = in_channels ,
253253 out_channels = self .out_channels ,
254254 dropout_prob = self .dropout ,
@@ -259,19 +259,19 @@ def setup(self):
259259 self .resnets = resnets
260260
261261 if self .add_upsample :
262- self .upsample = Upsample2D (self .out_channels , dtype = self .dtype )
262+ self .upsamplers_0 = FlaxUpsample2D (self .out_channels , dtype = self .dtype )
263263
264264 def __call__ (self , hidden_states , deterministic = True ):
265265 for resnet in self .resnets :
266266 hidden_states = resnet (hidden_states , deterministic = deterministic )
267267
268268 if self .add_upsample :
269- hidden_states = self .upsample (hidden_states )
269+ hidden_states = self .upsamplers_0 (hidden_states )
270270
271271 return hidden_states
272272
273273
274- class UNetMidBlock2D (nn .Module ):
274+ class FlaxUNetMidBlock2D (nn .Module ):
275275 in_channels : int
276276 dropout : float = 0.0
277277 num_layers : int = 1
@@ -281,7 +281,7 @@ class UNetMidBlock2D(nn.Module):
281281 def setup (self ):
282282 # there is always at least one resnet
283283 resnets = [
284- ResnetBlock2D (
284+ FlaxResnetBlock2D (
285285 in_channels = self .in_channels ,
286286 out_channels = self .in_channels ,
287287 dropout_prob = self .dropout ,
@@ -292,12 +292,12 @@ def setup(self):
292292 attentions = []
293293
294294 for _ in range (self .num_layers ):
295- attn_block = AttentionBlock (
295+ attn_block = FlaxAttentionBlock (
296296 channels = self .in_channels , num_head_channels = self .attn_num_head_channels , dtype = self .dtype
297297 )
298298 attentions .append (attn_block )
299299
300- res_block = ResnetBlock2D (
300+ res_block = FlaxResnetBlock2D (
301301 in_channels = self .in_channels ,
302302 out_channels = self .in_channels ,
303303 dropout_prob = self .dropout ,
@@ -317,7 +317,7 @@ def __call__(self, hidden_states, deterministic=True):
317317 return hidden_states
318318
319319
320- class Encoder (nn .Module ):
320+ class FlaxEncoder (nn .Module ):
321321 in_channels : int = 3
322322 out_channels : int = 3
323323 down_block_types : Tuple [str ] = ("DownEncoderBlock2D" ,)
@@ -347,7 +347,7 @@ def setup(self):
347347 output_channel = block_out_channels [i ]
348348 is_final_block = i == len (block_out_channels ) - 1
349349
350- down_block = DownEncoderBlock2D (
350+ down_block = FlaxDownEncoderBlock2D (
351351 in_channels = input_channel ,
352352 out_channels = output_channel ,
353353 num_layers = self .layers_per_block ,
@@ -358,7 +358,7 @@ def setup(self):
358358 self .down_blocks = down_blocks
359359
360360 # middle
361- self .mid_block = UNetMidBlock2D (
361+ self .mid_block = FlaxUNetMidBlock2D (
362362 in_channels = block_out_channels [- 1 ], attn_num_head_channels = None , dtype = self .dtype
363363 )
364364
@@ -392,7 +392,7 @@ def __call__(self, sample, deterministic: bool = True):
392392 return sample
393393
394394
395- class Decoder (nn .Module ):
395+ class FlaxDecoder (nn .Module ):
396396 dtype : jnp .dtype = jnp .float32
397397 in_channels : int = 3
398398 out_channels : int = 3
@@ -415,7 +415,7 @@ def setup(self):
415415 )
416416
417417 # middle
418- self .mid_block = UNetMidBlock2D (
418+ self .mid_block = FlaxUNetMidBlock2D (
419419 in_channels = block_out_channels [- 1 ], attn_num_head_channels = None , dtype = self .dtype
420420 )
421421
@@ -429,7 +429,7 @@ def setup(self):
429429
430430 is_final_block = i == len (block_out_channels ) - 1
431431
432- up_block = UpEncoderBlock2D (
432+ up_block = FlaxUpEncoderBlock2D (
433433 in_channels = prev_output_channel ,
434434 out_channels = output_channel ,
435435 num_layers = self .layers_per_block + 1 ,
@@ -469,7 +469,7 @@ def __call__(self, sample, deterministic: bool = True):
469469 return sample
470470
471471
472- class DiagonalGaussianDistribution (object ):
472+ class FlaxDiagonalGaussianDistribution (object ):
473473 def __init__ (self , parameters , deterministic = False ):
474474 # Last axis to account for channels-last
475475 self .mean , self .logvar = jnp .split (parameters , 2 , axis = - 1 )
@@ -521,7 +521,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
521521 dtype : jnp .dtype = jnp .float32
522522
523523 def setup (self ):
524- self .encoder = Encoder (
524+ self .encoder = FlaxEncoder (
525525 in_channels = self .config .in_channels ,
526526 out_channels = self .config .latent_channels ,
527527 down_block_types = self .config .down_block_types ,
@@ -532,7 +532,7 @@ def setup(self):
532532 double_z = True ,
533533 dtype = self .dtype ,
534534 )
535- self .decoder = Decoder (
535+ self .decoder = FlaxDecoder (
536536 in_channels = self .config .latent_channels ,
537537 out_channels = self .config .out_channels ,
538538 up_block_types = self .config .up_block_types ,
@@ -572,7 +572,7 @@ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
572572
573573 hidden_states = self .encoder (sample , deterministic = deterministic )
574574 moments = self .quant_conv (hidden_states )
575- posterior = DiagonalGaussianDistribution (moments )
575+ posterior = FlaxDiagonalGaussianDistribution (moments )
576576
577577 if not return_dict :
578578 return (posterior ,)
0 commit comments