@@ -171,6 +171,46 @@ def forward(
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
 
 
+class AdaLayerNormZeroPruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+        super().__init__()
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
 class AdaLayerNormZeroSingle(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
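
The pruned variant above carries no projection layer of its own: `forward` chunks the conditioning tensor it receives straight into the six modulation terms. A minimal usage sketch (illustrative only, not part of the commit; it assumes the class is in scope and that `emb` was projected externally to `6 * embedding_dim`, which is what `emb.chunk(6, dim=1)` needs when no `num_embeddings` embedder is configured):

import torch

dim = 64
norm = AdaLayerNormZeroPruned(embedding_dim=dim)  # no num_embeddings -> caller supplies `emb` directly

x = torch.randn(2, 128, dim)   # (batch, seq_len, embedding_dim)
emb = torch.randn(2, 6 * dim)  # externally projected conditioning: six modulation chunks of size `dim`

x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb=emb)
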
@@ -203,6 +243,35 @@ def forward(
         return x, gate_msa
 
 
+class AdaLayerNormZeroSinglePruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
253+ """
254+
255+ def __init__ (self , embedding_dim : int , norm_type = "layer_norm" , bias = True ):
256+ super ().__init__ ()
257+
258+ if norm_type == "layer_norm" :
259+ self .norm = nn .LayerNorm (embedding_dim , elementwise_affine = False , eps = 1e-6 )
260+ else :
261+ raise ValueError (
262+ f"Unsupported `norm_type` ({ norm_type } ) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa
+
+
 class LuminaRMSNormZero(nn.Module):
     """
     Norm layer adaptive RMS normalization zero.
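
The single-stream pruned variant follows the same pattern with three modulation terms instead of six. A minimal sketch under the same assumptions (class in scope, `emb` precomputed to `3 * embedding_dim`):

import torch

dim = 64
norm = AdaLayerNormZeroSinglePruned(embedding_dim=dim)

x = torch.randn(2, 128, dim)
emb = torch.randn(2, 3 * dim)  # scale_msa / shift_msa / gate_msa concatenated along dim 1

x_mod, gate_msa = norm(x, emb=emb)
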
@@ -305,6 +374,50 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class AdaLayerNormContinuousPruned(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+    ):
+        super().__init__()
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `emb` is upcast to float32 (needed for hunyuanDiT)
+        shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class AdaLayerNormContinuous(nn.Module):
     r"""
     Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
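
For `AdaLayerNormContinuousPruned` the conditioning is likewise consumed as-is: `forward` splits `emb` into a shift/scale pair, and `conditioning_embedding_dim` is accepted but not used by the code shown, since the projection it would feed has been pruned. A minimal sketch under the same assumptions (class in scope, `emb` precomputed to `2 * embedding_dim`):

import torch

dim = 64
norm = AdaLayerNormContinuousPruned(embedding_dim=dim, conditioning_embedding_dim=dim)

x = torch.randn(2, 128, dim)
emb = torch.randn(2, 2 * dim)  # shift and scale, concatenated along dim 1

out = norm(x, emb)  # same shape as x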