diff --git a/timm/layers/norm.py b/timm/layers/norm.py
index 79a19bf323..ec082da2ef 100644
--- a/timm/layers/norm.py
+++ b/timm/layers/norm.py
@@ -104,7 +104,9 @@ def __init__(
         super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         return x
 
 
@@ -146,7 +148,9 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         return x
 
@@ -282,7 +286,8 @@ def reset_parameters(self) -> None:
             nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -381,7 +386,8 @@ def reset_parameters(self) -> None:
             nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -470,7 +476,8 @@ def reset_parameters(self) -> None:
             nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -562,6 +569,7 @@ def reset_parameters(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         return x
diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py
index 7dbb5e0f2c..d362a95079 100644
--- a/timm/layers/norm_act.py
+++ b/timm/layers/norm_act.py
@@ -482,7 +482,9 @@ def __init__(
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
     def forward(self, x):
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -540,7 +542,9 @@ def __init__(
 
     def forward(self, x):
         x = x.permute(0, 2, 3, 1)
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         x = self.drop(x)
         x = self.act(x)
@@ -605,7 +609,8 @@ def __init__(
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -667,7 +672,8 @@ def __init__(
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         x = self.drop(x)
         x = self.act(x)
         return x
diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index 7f2f5aa341..d1976944da 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -2881,22 +2881,35 @@ def test_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
 
 @register_model
 def test_efficientnet_gn(pretrained=False, **kwargs) -> EfficientNet:
+
     model = _gen_test_efficientnet(
-        'test_efficientnet_gn', pretrained=pretrained, norm_layer=partial(GroupNormAct, group_size=8), **kwargs)
+        'test_efficientnet_gn',
+        pretrained=pretrained,
+        norm_layer=kwargs.pop('norm_layer', partial(GroupNormAct, group_size=8)),
+        **kwargs
+    )
     return model
 
 
 @register_model
 def test_efficientnet_ln(pretrained=False, **kwargs) -> EfficientNet:
     model = _gen_test_efficientnet(
-        'test_efficientnet_ln', pretrained=pretrained, norm_layer=LayerNormAct2d, **kwargs)
+        'test_efficientnet_ln',
+        pretrained=pretrained,
+        norm_layer=kwargs.pop('norm_layer', LayerNormAct2d),
+        **kwargs
+    )
     return model
 
 
 @register_model
 def test_efficientnet_evos(pretrained=False, **kwargs) -> EfficientNet:
     model = _gen_test_efficientnet(
-        'test_efficientnet_evos', pretrained=pretrained, norm_layer=partial(EvoNorm2dS0, group_size=8), **kwargs)
+        'test_efficientnet_evos',
+        pretrained=pretrained,
+        norm_layer=kwargs.pop('norm_layer', partial(EvoNorm2dS0, group_size=8)),
+        **kwargs
+    )
     return model
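Note on the norm.py / norm_act.py hunks: these fp32 variants upcast the activation with `x.float()`, but the affine parameters previously went into `F.layer_norm` / `rms_norm` in whatever dtype the module was holding. If the model has been converted to lower precision (e.g. `model.to(torch.bfloat16)`), that mixes a float32 input with bfloat16 parameters, which PyTorch's layer_norm rejects on most backends; the `is not None` guards keep `affine=False` / `elementwise_affine=False` configurations (where weight and bias are `None`) working. A minimal repro sketch, illustrative only and not part of the patch:

```python
import torch
import torch.nn.functional as F

ln = torch.nn.LayerNorm(8).to(torch.bfloat16)   # affine params are now bf16
x = torch.randn(2, 8, dtype=torch.bfloat16)

# Pre-patch path: fp32 input with bf16 weight/bias typically raises a
# dtype-mismatch RuntimeError (exact message varies by backend/version).
try:
    F.layer_norm(x.float(), ln.normalized_shape, ln.weight, ln.bias, ln.eps)
except RuntimeError as err:
    print('pre-patch:', err)

# Post-patch path: upcast the params alongside the input, guarding the
# no-affine case where weight/bias are None.
weight = ln.weight.float() if ln.weight is not None else None
bias = ln.bias.float() if ln.bias is not None else None
y = F.layer_norm(x.float(), ln.normalized_shape, weight, bias, ln.eps).to(x.dtype)
print('post-patch:', y.dtype)  # bfloat16, restored to the input dtype
```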
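The efficientnet.py hunks turn the hard-coded `norm_layer` keyword into an overridable default: `kwargs.pop('norm_layer', default)` takes a caller-supplied layer out of `**kwargs` (so it isn't forwarded twice) and falls back to the original value otherwise. Before the change, passing `norm_layer` to these test models raised a duplicate-keyword `TypeError`. A rough sketch of the pattern, with `_gen_stub` as a hypothetical stand-in for `_gen_test_efficientnet`:

```python
from functools import partial
import torch.nn as nn

def _gen_stub(variant, pretrained=False, **kwargs):
    # hypothetical stand-in for _gen_test_efficientnet: report the norm layer
    return kwargs['norm_layer']

def make_before(pretrained=False, **kwargs):
    # pre-patch shape: norm_layer is hard-coded, so a caller-supplied one collides
    return _gen_stub('gn', pretrained=pretrained,
                     norm_layer=partial(nn.GroupNorm, 8), **kwargs)

def make_after(pretrained=False, **kwargs):
    # post-patch shape: pop() prefers the caller's norm_layer, else the default
    return _gen_stub('gn', pretrained=pretrained,
                     norm_layer=kwargs.pop('norm_layer', partial(nn.GroupNorm, 8)),
                     **kwargs)

print(make_after(norm_layer=nn.BatchNorm2d))  # caller override is honored
print(make_after())                           # default partial(GroupNorm, 8) used
try:
    make_before(norm_layer=nn.BatchNorm2d)
except TypeError as err:
    print(err)  # got multiple values for keyword argument 'norm_layer'
```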