@@ -104,7 +104,9 @@ def __init__(
104104        super ().__init__ (num_channels , eps = eps , elementwise_affine = affine , ** kwargs )
105105
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply LayerNorm in float32 and cast the result back to the input dtype.

    Upcasting the input and the affine parameters keeps the reduction
    numerically stable for fp16/bf16 inputs.
    """
    # Affine params may be absent (elementwise_affine=False).
    w = None if self.weight is None else self.weight.float()
    b = None if self.bias is None else self.bias.float()
    out = F.layer_norm(x.float(), self.normalized_shape, w, b, self.eps)
    return out.to(x.dtype)
109111
110112
@@ -146,7 +148,9 @@ def __init__(
146148
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Channels-first LayerNorm: NCHW -> NHWC, fp32 layer_norm, back to NCHW.

    The normalization itself runs in float32 (input and affine params are
    upcast) and the result is cast back to the input's dtype.
    """
    y = x.permute(0, 2, 3, 1)
    # Affine params may be absent (elementwise_affine=False).
    w = None if self.weight is None else self.weight.float()
    b = None if self.bias is None else self.bias.float()
    y = F.layer_norm(y.float(), self.normalized_shape, w, b, self.eps).to(x.dtype)
    return y.permute(0, 3, 1, 2)
152156
@@ -282,7 +286,8 @@ def reset_parameters(self) -> None:
282286            nn .init .ones_ (self .weight )
283287
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """RMSNorm computed in float32; result cast back to the input dtype."""
    # weight can be None when the layer has no learnable scale.
    w = None if self.weight is None else self.weight.float()
    return rms_norm(x.float(), self.normalized_shape, w, self.eps).to(x.dtype)
287292
288293
@@ -381,7 +386,8 @@ def reset_parameters(self) -> None:
381386            nn .init .ones_ (self .weight )
382387
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """2D RMSNorm computed in float32; result cast back to the input dtype."""
    # weight can be None when the layer has no learnable scale.
    w = None if self.weight is None else self.weight.float()
    return rms_norm2d(x.float(), self.normalized_shape, w, self.eps).to(x.dtype)
386392
387393
@@ -470,7 +476,8 @@ def reset_parameters(self) -> None:
470476            nn .init .ones_ (self .weight )
471477
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """SimpleNorm computed in float32; result cast back to the input dtype."""
    # weight can be None when the layer has no learnable scale.
    w = None if self.weight is None else self.weight.float()
    return simple_norm(x.float(), self.normalized_shape, w, self.eps).to(x.dtype)
475482
476483
@@ -562,6 +569,7 @@ def reset_parameters(self) -> None:
562569
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Channels-first SimpleNorm: NCHW -> NHWC, fp32 norm, back to NCHW."""
    y = x.permute(0, 2, 3, 1)
    # weight can be None when the layer has no learnable scale.
    w = None if self.weight is None else self.weight.float()
    y = simple_norm(y.float(), self.normalized_shape, w, self.eps).to(x.dtype)
    return y.permute(0, 3, 1, 2)
0 commit comments