Skip to content
10 changes: 5 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2'
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
]
Expand All @@ -72,11 +72,11 @@
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*',
'*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560']
NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*']
'*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*']
NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*']
else:
EXCLUDE_FILTERS = ['*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']

EXCLUDE_JIT_FILTERS = ['hiera_*']

Expand Down
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
Expand Down
4 changes: 3 additions & 1 deletion timm/layers/create_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d

_NORM_MAP = dict(
Expand All @@ -23,6 +23,8 @@
layernorm2d=LayerNorm2d,
rmsnorm=RmsNorm,
rmsnorm2d=RmsNorm2d,
simplenorm=SimpleNorm,
simplenorm2d=SimpleNorm2d,
frozenbatchnorm2d=FrozenBatchNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
Expand Down
68 changes: 62 additions & 6 deletions timm/layers/fast_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
has_apex_rmsnorm = False


has_torch_rms_norm = hasattr(F, 'rms_norm')

# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False # defaulting to False for now

Expand Down Expand Up @@ -75,7 +77,6 @@ def fast_group_norm(
if is_autocast_enabled(x.device.type):
# normally native AMP casts GN inputs to float32
# here we use the low precision autocast dtype
# FIXME what to do re CPU autocast?
dt = get_autocast_dtype(x.device.type)
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

Expand All @@ -101,7 +102,6 @@ def fast_layer_norm(
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
# FIXME what to do re CPU autocast?
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

with torch.amp.autocast(device_type=x.device.type, enabled=False):
Expand All @@ -115,15 +115,16 @@ def rms_norm(
eps: float = 1e-5,
):
norm_ndim = len(normalized_shape)
v = x.pow(2)
if torch.jit.is_scripting():
# ndim = len(x.shape)
# dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
# NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
assert norm_ndim == 1
v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
v = torch.mean(v, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
else:
dims = tuple(range(-1, -norm_ndim - 1, -1))
v = torch.var(x, dim=dims, keepdim=True)
v = torch.mean(v, dim=dims, keepdim=True)
x = x * torch.rsqrt(v + eps)
if weight is not None:
x = x * weight
Expand All @@ -146,5 +147,60 @@ def fast_rms_norm(
else:
return fused_rms_norm_affine(x, weight, normalized_shape, eps)

# fallback
return rms_norm(x, normalized_shape, weight, eps)
if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)

with torch.amp.autocast(device_type=x.device.type, enabled=False):
if has_torch_rms_norm:
x = F.rms_norm(x, normalized_shape, weight, eps)
else:
x = rms_norm(x, normalized_shape, weight, eps)

return x


def simple_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
):
norm_ndim = len(normalized_shape)
if torch.jit.is_scripting():
# ndim = len(x.shape)
# dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
# NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
assert norm_ndim == 1
v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
else:
dims = tuple(range(-1, -norm_ndim - 1, -1))
v = torch.var(x, dim=dims, keepdim=True)
x = x * torch.rsqrt(v + eps)
if weight is not None:
x = x * weight
return x


def fast_simple_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> torch.Tensor:
if torch.jit.is_scripting():
# this must be by itself, cannot merge with has_apex_rmsnorm
return simple_norm(x, normalized_shape, weight, eps)

if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)

with torch.amp.autocast(device_type=x.device.type, enabled=False):
x = simple_norm(x, normalized_shape, weight, eps)
return x

9 changes: 5 additions & 4 deletions timm/layers/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def __init__(

def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
fc1_mid = self.fc1.bias.shape[0] // 2
nn.init.ones_(self.fc1.bias[fc1_mid:])
nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
if self.fc1.bias is not None:
nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:])
nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6)

def forward(self, x):
x = self.fc1(x)
Expand Down Expand Up @@ -132,7 +132,8 @@ def __init__(

def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
nn.init.ones_(self.fc1_g.bias)
if self.fc1_g.bias is not None:
nn.init.ones_(self.fc1_g.bias)
nn.init.normal_(self.fc1_g.weight, std=1e-6)

def forward(self, x):
Expand Down
124 changes: 115 additions & 9 deletions timm/layers/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@
import torch.nn as nn
import torch.nn.functional as F

from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm

try:
from torch.nn.functional import rms_norm
except ImportError:
from .fast_norm import rms_norm


class GroupNorm(nn.GroupNorm):
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

def forward(self, x):
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
Expand All @@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
Expand All @@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class LayerNorm(nn.LayerNorm):
""" LayerNorm w/ fast norm option
"""
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
Expand All @@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial NCHW tensors """
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
Expand Down Expand Up @@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor:
class RmsNorm(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
Expand All @@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
Expand All @@ -150,17 +165,21 @@ def reset_parameters(self) -> None:
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
return x


class RmsNorm2d(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
Expand All @@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
Expand All @@ -187,6 +208,91 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x


class SimpleNorm(nn.Module):
""" SimpleNorm (x / std(x))
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
normalized_shape = channels
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter('weight', None)

self.reset_parameters()

def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
return x


class SimpleNorm2d(nn.Module):
""" SimpleNorm for NCHW tensors
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
normalized_shape = channels
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter('weight', None)

self.reset_parameters()

def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x
Loading
Loading