diff --git a/tests/test_models.py b/tests/test_models.py index 0b7303c548..d4c39b391f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -53,13 +53,13 @@ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', - 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2' + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', + 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', ] @@ -72,11 +72,11 @@ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*', - '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560'] - NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*'] + '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*'] + NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*'] else: EXCLUDE_FILTERS = ['*enormous*'] - NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*'] + NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*'] EXCLUDE_JIT_FILTERS = ['hiera_*'] diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index f631e86831..5ec03219e8 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -34,7 +34,7 @@ from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d from .padding import get_padding, get_same_padding, pad_same diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py index 74e893d8fc..75262b5eca 100644 --- a/timm/layers/create_norm.py +++ b/timm/layers/create_norm.py @@ -10,7 +10,7 @@ import torch.nn as nn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d from torchvision.ops.misc import FrozenBatchNorm2d _NORM_MAP = dict( @@ -23,6 +23,8 @@ layernorm2d=LayerNorm2d, rmsnorm=RmsNorm, rmsnorm2d=RmsNorm2d, + simplenorm=SimpleNorm, + simplenorm2d=SimpleNorm2d, 
frozenbatchnorm2d=FrozenBatchNorm2d, ) _NORM_TYPES = {m for n, m in _NORM_MAP.items()} diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py index 3bbb0b4f61..e7cbbb9495 100644 --- a/timm/layers/fast_norm.py +++ b/timm/layers/fast_norm.py @@ -24,6 +24,8 @@ has_apex_rmsnorm = False +has_torch_rms_norm = hasattr(F, 'rms_norm') + # fast (ie lower precision LN) can be disabled with this flag if issues crop up _USE_FAST_NORM = False # defaulting to False for now @@ -75,7 +77,6 @@ def fast_group_norm( if is_autocast_enabled(x.device.type): # normally native AMP casts GN inputs to float32 # here we use the low precision autocast dtype - # FIXME what to do re CPU autocast? dt = get_autocast_dtype(x.device.type) x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None @@ -101,7 +102,6 @@ def fast_layer_norm( # normally native AMP casts LN inputs to float32 # apex LN does not, this is behaving like Apex dt = get_autocast_dtype(x.device.type) - # FIXME what to do re CPU autocast? x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None with torch.amp.autocast(device_type=x.device.type, enabled=False): @@ -115,15 +115,16 @@ def rms_norm( eps: float = 1e-5, ): norm_ndim = len(normalized_shape) + v = x.pow(2) if torch.jit.is_scripting(): # ndim = len(x.shape) # dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around assert norm_ndim == 1 - v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True + v = torch.mean(v, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True else: dims = tuple(range(-1, -norm_ndim - 1, -1)) - v = torch.var(x, dim=dims, keepdim=True) + v = torch.mean(v, dim=dims, keepdim=True) x = x * torch.rsqrt(v + eps) if weight is not None: x = x * weight @@ -146,5 +147,60 @@ def fast_rms_norm( else: return fused_rms_norm_affine(x, weight, normalized_shape, eps) - # fallback - return rms_norm(x, normalized_shape, weight, eps) + if is_autocast_enabled(x.device.type): + # normally native AMP casts LN inputs to float32 + # apex LN does not, this is behaving like Apex + dt = get_autocast_dtype(x.device.type) + x, weight = x.to(dt), weight.to(dt) + + with torch.amp.autocast(device_type=x.device.type, enabled=False): + if has_torch_rms_norm: + x = F.rms_norm(x, normalized_shape, weight, eps) + else: + x = rms_norm(x, normalized_shape, weight, eps) + + return x + + +def simple_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + eps: float = 1e-5, +): + norm_ndim = len(normalized_shape) + if torch.jit.is_scripting(): + # ndim = len(x.shape) + # dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x + # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around + assert norm_ndim == 1 + v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True + else: + dims = tuple(range(-1, -norm_ndim - 1, -1)) + v = torch.var(x, dim=dims, keepdim=True) + x = x * torch.rsqrt(v + eps) + if weight is not None: + x = x * weight + return x + + +def fast_simple_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> torch.Tensor: + if torch.jit.is_scripting(): + # this must be by itself, cannot merge with has_apex_rmsnorm + return simple_norm(x, normalized_shape, weight, eps) + + if is_autocast_enabled(x.device.type): 
+ # normally native AMP casts LN inputs to float32 + # apex LN does not, this is behaving like Apex + dt = get_autocast_dtype(x.device.type) + x, weight = x.to(dt), weight.to(dt) + + with torch.amp.autocast(device_type=x.device.type, enabled=False): + x = simple_norm(x, normalized_shape, weight, eps) + return x + diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py index 11d9eeca32..d1e6774cc9 100644 --- a/timm/layers/mlp.py +++ b/timm/layers/mlp.py @@ -83,9 +83,9 @@ def __init__( def init_weights(self): # override init of fc1 w/ gate portion set to weight near zero, bias=1 - fc1_mid = self.fc1.bias.shape[0] // 2 - nn.init.ones_(self.fc1.bias[fc1_mid:]) - nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + if self.fc1.bias is not None: + nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:]) + nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6) def forward(self, x): x = self.fc1(x) @@ -132,7 +132,8 @@ def __init__( def init_weights(self): # override init of fc1 w/ gate portion set to weight near zero, bias=1 - nn.init.ones_(self.fc1_g.bias) + if self.fc1_g.bias is not None: + nn.init.ones_(self.fc1_g.bias) nn.init.normal_(self.fc1_g.weight, std=1e-6) def forward(self, x): diff --git a/timm/layers/norm.py b/timm/layers/norm.py index e9f9c27d8c..f718750dc5 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -11,17 +11,24 @@ import torch.nn as nn import torch.nn.functional as F -from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm + +try: + from torch.nn.functional import rms_norm +except ImportError: + from .fast_norm import rms_norm class GroupNorm(nn.GroupNorm): + _fast_norm: torch.jit.Final[bool] + def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN super().__init__(num_groups, num_channels, eps=eps, affine=affine) - self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x): - if self.fast_norm: + if self._fast_norm: return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) else: return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) @@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm): """ Group Normalization with 1 group. 
Input: tensor in shape [B, C, *] """ + _fast_norm: torch.jit.Final[bool] def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs) - self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.fast_norm: + if self._fast_norm: return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) else: return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) @@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LayerNorm(nn.LayerNorm): """ LayerNorm w/ fast norm option """ + _fast_norm: torch.jit.Final[bool] + def __init__(self, num_channels, eps=1e-6, affine=True): super().__init__(num_channels, eps=eps, elementwise_affine=affine) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) @@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LayerNorm2d(nn.LayerNorm): """ LayerNorm for channels of '2D' spatial NCHW tensors """ + _fast_norm: torch.jit.Final[bool] + def __init__(self, num_channels, eps=1e-6, affine=True): super().__init__(num_channels, eps=eps, elementwise_affine=affine) self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) @@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor: class RmsNorm(nn.Module): """ RmsNorm w/ fast (apex) norm if available """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool + _fast_norm: bool def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] self.eps = eps self.elementwise_affine = affine + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) else: @@ -150,17 +165,21 @@ def reset_parameters(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE fast norm fallback needs our rms norm impl, so both paths through here. # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed. - x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) + if self._fast_norm: + x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) + else: + x = rms_norm(x, self.normalized_shape, self.weight, self.eps) return x class RmsNorm2d(nn.Module): """ RmsNorm w/ fast (apex) norm if available """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] normalized_shape: Tuple[int, ...] 
eps: float elementwise_affine: bool + _fast_norm: bool def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] self.eps = eps self.elementwise_affine = affine + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + if self.elementwise_affine: self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) else: @@ -187,6 +208,91 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) # NOTE fast norm fallback needs our rms norm impl, so both paths through here. # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed. - x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) + if self._fast_norm: + x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) + else: + x = rms_norm(x, self.normalized_shape, self.weight, self.eps) + x = x.permute(0, 3, 1, 2) + return x + + +class SimpleNorm(nn.Module): + """ SimpleNorm (x / std(x)) + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + _fast_norm: bool + + def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._fast_norm: + x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps) + else: + x = simple_norm(x, self.normalized_shape, self.weight, self.eps) + return x + + +class SimpleNorm2d(nn.Module): + """ SimpleNorm for NCHW tensors + """ + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm'] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + _fast_norm: bool + + def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + normalized_shape = channels + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = affine + self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + else: + self.register_parameter('weight', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) + if self._fast_norm: + x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps) + else: + x = simple_norm(x, self.normalized_shape, self.weight, self.eps) x = x.permute(0, 3, 1, 2) return x diff --git a/timm/models/convnext.py b/timm/models/convnext.py index a6d1999bde..25a767ca6d 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -46,6 +46,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \ LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple +from timm.layers import SimpleNorm2d, SimpleNorm from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -233,6 +234,34 @@ def forward(self, x): x = self.blocks(x) return x +# map of norm layers with NCHW (2D) and channels last variants +_NORM_MAP = { + 'layernorm': (LayerNorm2d, LayerNorm), + 'layernorm2d': (LayerNorm2d, LayerNorm), + 'simplenorm': (SimpleNorm2d, SimpleNorm), + 'simplenorm2d': (SimpleNorm2d, SimpleNorm), + 'rmsnorm': (RmsNorm2d, RmsNorm), + 'rmsnorm2d': (RmsNorm2d, RmsNorm), +} + + +def _get_norm_layers(norm_layer: Union[Callable, str], conv_mlp: bool, norm_eps: float): + norm_layer = norm_layer or 'layernorm' + if norm_layer in _NORM_MAP: + norm_layer_cl = _NORM_MAP[norm_layer][0] if conv_mlp else _NORM_MAP[norm_layer][1] + norm_layer = _NORM_MAP[norm_layer][0] + if norm_eps is not None: + norm_layer = partial(norm_layer, eps=norm_eps) + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + else: + assert conv_mlp, \ + 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' + norm_layer = get_norm_layer(norm_layer) + norm_layer_cl = norm_layer + if norm_eps is not None: + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + return norm_layer, norm_layer_cl + class ConvNeXt(nn.Module): r""" ConvNeXt @@ -289,20 +318,7 @@ def __init__( super().__init__() assert output_stride in (8, 16, 32) kernel_sizes = to_ntuple(4)(kernel_sizes) - use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm') - if norm_layer is None or use_rms: - norm_layer = RmsNorm2d if use_rms else LayerNorm2d - norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm) - if norm_eps is not None: - norm_layer = partial(norm_layer, eps=norm_eps) - 
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) - else: - assert conv_mlp,\ - 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' - norm_layer = get_norm_layer(norm_layer) - norm_layer_cl = norm_layer - if norm_eps is not None: - norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps) act_layer = get_act_layer(act_layer) self.num_classes = num_classes @@ -975,7 +991,7 @@ def _cfgv2(url='', **kwargs): @register_model def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt: # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M - model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d') + model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm') model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -984,7 +1000,7 @@ def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt: def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt: # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M model_args = dict( - depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act') + depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm', stem_type='overlap_act') model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 6bc93dd1c6..2368d353b3 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -42,7 +42,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \ +from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ get_act_layer, get_norm_layer, LayerType from ._builder import build_model_with_cfg @@ -65,6 +65,7 @@ def __init__( num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, + proj_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., norm_layer: nn.Module = nn.LayerNorm, @@ -80,7 +81,7 @@ def __init__( self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -130,6 +131,7 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., init_values: Optional[float] = None, @@ -145,6 +147,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, @@ -157,6 +160,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + bias=proj_bias, 
drop=proj_drop, ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() @@ -176,6 +180,7 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., init_values: Optional[float] = None, @@ -192,6 +197,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, @@ -203,6 +209,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + bias=proj_bias, drop=proj_drop, ) self.norm2 = norm_layer(dim) @@ -236,6 +243,7 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., init_values: Optional[float] = None, @@ -266,11 +274,11 @@ def __init__( self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.attn_out_proj = nn.Linear(dim, dim) + self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias) self.mlp_drop = nn.Dropout(proj_drop) self.mlp_act = act_layer() - self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim) + self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias) self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -330,6 +338,7 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + proj_bias: bool = True, init_values: Optional[float] = None, proj_drop: float = 0., attn_drop: float = 0., @@ -350,6 +359,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, @@ -363,6 +373,7 @@ def __init__( dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + bias=proj_bias, drop=proj_drop, )), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), @@ -433,6 +444,7 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = True, qk_norm: bool = False, + proj_bias: bool = True, init_values: Optional[float] = None, class_token: bool = True, pos_embed: str = 'learn', @@ -452,6 +464,7 @@ def __init__( weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', fix_init: bool = False, embed_layer: Callable = PatchEmbed, + embed_norm_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, act_layer: Optional[LayerType] = None, block_fn: Type[nn.Module] = Block, @@ -483,6 +496,7 @@ def __init__( weight_init: Weight initialization scheme. fix_init: Apply weight initialization fix (scaling w/ layer index). embed_layer: Patch embedding layer. + embed_norm_layer: Normalization layer to use / override in patch embed module. norm_layer: Normalization layer. act_layer: MLP activation layer. block_fn: Transformer block layer. 
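The new `proj_bias` and `embed_norm_layer` arguments pair with the existing `norm_layer` / `mlp_layer` / `act_layer` hooks. Below is a minimal sketch (assuming this PR is applied) that mirrors the `aimv2_large_patch14_224` registration added further down in this file:

# Minimal sketch, assumes this PR: an AIMv2-style config built directly from VisionTransformer.
from functools import partial

import torch
from timm.layers import RmsNorm, SwiGLU
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=8,
    mlp_ratio=2.75, mlp_layer=SwiGLU, act_layer='silu',
    qkv_bias=False, proj_bias=False,              # proj_bias (new): drops bias on attn output proj and MLP linears
    class_token=False, global_pool='avg', fc_norm=False,
    norm_layer=partial(RmsNorm, eps=1e-5),
    embed_norm_layer=partial(RmsNorm, eps=1e-5),  # embed_norm_layer (new): norm override inside PatchEmbed
    num_classes=0,
)
features = model(torch.randn(1, 3, 224, 224))     # pooled features, shape (1, 1024)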
@@ -493,6 +507,7 @@ def __init__( assert pos_embed in ('', 'none', 'learn') use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + embed_norm_layer = get_norm_layer(embed_norm_layer) act_layer = get_act_layer(act_layer) or nn.GELU self.num_classes = num_classes @@ -510,6 +525,8 @@ def __init__( if dynamic_img_size: # flatten deferred until after pos embed embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + if embed_norm_layer is not None: + embed_args['norm_layer'] = embed_norm_layer self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, @@ -547,6 +564,7 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, + proj_bias=proj_bias, init_values=init_values, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, @@ -1128,6 +1146,24 @@ def _convert_dinov2( return out_dict +def _convert_aimv2( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, +) -> Dict[str, torch.Tensor]: + out_dict = {} + for k, v in state_dict.items(): + k = k.replace('norm_1', 'norm1') + k = k.replace('norm_2', 'norm2') + k = k.replace('preprocessor.patchifier.', 'patch_embed.') + k = k.replace('preprocessor.pos_embed', 'pos_embed') + k = k.replace('trunk.', '') + k = k.replace('post_trunk_norm.', 'norm.') + k = k.replace('mlp.fc1', 'mlp.fc1_g') + k = k.replace('mlp.fc3', 'mlp.fc1_x') + out_dict[k] = v + return out_dict + + def checkpoint_filter_fn( state_dict: Dict[str, torch.Tensor], model: VisionTransformer, @@ -1159,6 +1195,8 @@ def checkpoint_filter_fn( # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) out_dict['head.weight'] = state_dict['visual.head.proj.weight'] out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) + elif 'preprocessor.patchifier.proj.weight' in state_dict: + state_dict = _convert_aimv2(state_dict, model) if prefix: # filter on & remove prefix string from keys @@ -1637,18 +1675,26 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'vit_base_patch16_clip_224.dfn2b': _cfg( hf_hub_id='timm/', + license='apple-ascl', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.dfn2b_s39b': _cfg( + hf_hub_id='timm/', + license='apple-ascl', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), 'vit_large_patch14_clip_224.dfn2b': _cfg( hf_hub_id='timm/', + license='apple-ascl', notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), 'vit_huge_patch14_clip_224.dfn5b': _cfg( hf_hub_id='timm/', + license='apple-ascl', notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_huge_patch14_clip_378.dfn5b': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + license='apple-ascl', notes=('natively QuickGELU, use quickgelu model variant for original results',), crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024), @@ -2119,6 +2165,63 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 448, 448), crop_pct=1.0, num_classes=0, ), + 'aimv2_large_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_large_patch14_224.apple_pt_dist': _cfg( + 
hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_huge_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_1b_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_3b_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_large_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_large_patch14_336.apple_pt_dist': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_huge_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_1b_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_3b_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_large_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + 'aimv2_huge_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + 'aimv2_1b_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + 'aimv2_3b_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + 'test_vit.r160_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 160, 160), crop_pct=0.95), @@ -2128,6 +2231,8 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'test_vit3.r160_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 160, 160), crop_pct=0.95), + 'test_vit4.r160_in1k': _cfg( + input_size=(3, 160, 160), crop_pct=0.95), } _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]] @@ -3390,6 +3495,175 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTran return model +@register_model +def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Large AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> 
VisionTransformer: + """ ViT Huge AIM-v2 model + """ + + model_args = dict( + patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 1B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 3B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Large AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Huge AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 1B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 3B AIM-v2 model + """ + model_args = 
dict( + patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Large AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Huge AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 1B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 3B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT Test @@ -3421,6 +3695,19 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer: return model +@register_model +def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Test + """ + model_args = dict( + patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3, + class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True, + norm_layer='rmsnorm', + ) + model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + register_model_deprecations(__name__, { 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', 'vit_small_patch32_224_in21k': 
'vit_small_patch32_224.augreg_in21k', diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 18635f6038..f8e2a9a1e0 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -259,16 +259,17 @@ def __init__( in_features, hidden_features, act_layer = 'gelu', + bias = True, drop = 0.0, ): super().__init__() norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6) self.norm = norm_layer(in_features) - self.w0 = nn.Linear(in_features, hidden_features) + self.w0 = nn.Linear(in_features, hidden_features, bias=bias) self.act = create_act_layer(act_layer) - self.w1 = nn.Linear(in_features, hidden_features) - self.w2 = nn.Linear(hidden_features, in_features) + self.w1 = nn.Linear(in_features, hidden_features, bias=bias) + self.w2 = nn.Linear(hidden_features, in_features, bias=bias) def forward(self, x): x = self.norm(x)
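To round out the picture, a short usage sketch of the other additions: the `SimpleNorm` / `SimpleNorm2d` layers (scale by the reciprocal standard deviation of the features, no mean subtraction of the input), the `'simplenorm'` key now resolvable via `get_norm_layer`, and the registered `aimv2_*` models. This is a hedged illustration assuming this PR is applied; pretrained weight availability follows the `default_cfgs` entries above.

# Usage sketch, assumes this PR is applied.
import torch
import timm
from timm.layers import SimpleNorm, SimpleNorm2d, get_norm_layer

y = SimpleNorm(384)(torch.randn(2, 196, 384))       # last-dim norm: x * rsqrt(var(x) + eps) * weight
y2d = SimpleNorm2d(64)(torch.randn(2, 64, 56, 56))  # NCHW variant, normalizes over the channel dim

SimpleNormCls = get_norm_layer('simplenorm')        # resolves through the new _NORM_MAP entries
norm = SimpleNormCls(64)

# convnext_zepto_rms* now default to norm_layer='simplenorm'; aimv2 ViTs register like any other timm model.
cnx = timm.create_model('convnext_zepto_rms', pretrained=False)
vit = timm.create_model('aimv2_large_patch14_224', pretrained=False, num_classes=0)
feats = vit(torch.randn(1, 3, 224, 224))            # (1, 1024) pooled features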