@@ -559,7 +559,7 @@ def setup(self):
559559
560560 def init_weights (self , rng : jax .random .PRNGKey ) -> FrozenDict :
561561 # init input tensors
562- sample_shape = (1 , self .in_channels , self .sample_size , self .sample_size )
562+ sample_shape = (1 , self .sample_size , self .sample_size , self .in_channels )
563563 sample = jnp .zeros (sample_shape , dtype = jnp .float32 )
564564
565565 params_rng , dropout_rng , gaussian_rng = jax .random .split (rng , 3 )
@@ -568,8 +568,6 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
568568 return self .init (rngs , sample )["params" ]
569569
570570 def encode (self , sample , deterministic : bool = True , return_dict : bool = True ):
571- sample = jnp .transpose (sample , (0 , 2 , 3 , 1 ))
572-
573571 hidden_states = self .encoder (sample , deterministic = deterministic )
574572 moments = self .quant_conv (hidden_states )
575573 posterior = DiagonalGaussianDistribution (moments )
@@ -586,8 +584,6 @@ def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
586584 hidden_states = self .post_quant_conv (latents )
587585 hidden_states = self .decoder (hidden_states , deterministic = deterministic )
588586
589- hidden_states = jnp .transpose (hidden_states , (0 , 3 , 1 , 2 ))
590-
591587 if not return_dict :
592588 return (hidden_states ,)
593589
0 commit comments