diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 3c7b9a2277..594415493b 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -39,11 +39,30 @@ import torch.nn.functional as F
 from torch.jit import Final
 
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
+from timm.data import (
+    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD,
+    IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+)
+from timm.layers import (
+    PatchEmbed,
+    Mlp,
+    DropPath,
+    AttentionPoolLatent,
+    LayerNorm,
+    RmsNorm,
+    PatchDropout,
+    SwiGLUPacked,
+    SwiGLU,
+    trunc_normal_,
+    lecun_normal_,
+    resample_patch_embed,
+    resample_abs_pos_embed,
+    use_fused_attn,
+    get_act_layer,
+    get_norm_layer,
+    LayerType,
+)
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -64,10 +83,11 @@ def __init__(
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_norm: bool = False,
             proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
-            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = LayerNorm,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -79,6 +99,7 @@ def __init__(
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.norm = norm_layer(dim) if scale_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -102,6 +123,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = attn @ v
 
         x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.norm(x)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -130,13 +152,15 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
             act_layer: Type[nn.Module] = nn.GELU,
-            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = LayerNorm,
             mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
         super().__init__()
@@ -146,6 +170,7 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            scale_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -159,6 +184,7 @@ def __init__(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            norm_layer=norm_layer if scale_mlp_norm else None,
             bias=proj_bias,
             drop=proj_drop,
         )
@@ -179,13 +205,15 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
             act_layer: Type[nn.Module] = nn.GELU,
-            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            norm_layer: Type[nn.Module] =
LayerNorm, mlp_layer: Type[nn.Module] = Mlp, ) -> None: super().__init__() @@ -196,6 +224,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -208,6 +237,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, ) @@ -242,17 +272,20 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., init_values: Optional[float] = None, drop_path: float = 0., act_layer: Type[nn.Module] = nn.GELU, - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Optional[Type[nn.Module]] = None, ) -> None: super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' + assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 @@ -337,13 +370,15 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, init_values: Optional[float] = None, proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., act_layer: Type[nn.Module] = nn.GELU, - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, ) -> None: super().__init__() @@ -358,6 +393,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -372,6 +408,7 @@ def __init__( dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, )), @@ -443,6 +480,8 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = True, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, init_values: Optional[float] = None, class_token: bool = True, @@ -505,7 +544,7 @@ def __init__( assert class_token or global_pool != 'token' assert pos_embed in ('', 'none', 'learn') use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm - norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + norm_layer = get_norm_layer(norm_layer) or LayerNorm embed_norm_layer = get_norm_layer(embed_norm_layer) act_layer = get_act_layer(act_layer) or nn.GELU @@ -563,6 +602,8 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_attn_norm=scale_attn_norm, + scale_mlp_norm=scale_mlp_norm, proj_bias=proj_bias, init_values=init_values, proj_drop=proj_drop_rate, @@ -1166,6 +1207,66 @@ def _convert_aimv2( return out_dict +def _convert_beit3(state_dict: dict, model): + """ + Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict. 
+ """ + import re + state_dict = state_dict.get("model", state_dict) # unwrap if needed + + # Prune unused + for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"): + state_dict.pop(k, None) + + # Key renaming rules + rules = [ + (r"beit3\.", ""), + (r"vision_embed\.cls_token", "cls_token"), + (r"vision_embed\.", "patch_embed."), + (r"embed_positions\.", "pos_embed."), + (r"encoder\.", ""), + (r"layers\.", "blocks."), + (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."), + (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."), + (r"final_layer_norm\.", "norm2."), + (r"inner_attn_ln", "norm"), + (r"out_proj", "proj"), + (r"\.A\.", "."), + ] + + # First pass, rename keys + tmp = {} + for k, v in state_dict.items(): + if ".B." in k: + continue # use branch-A only + for old, new in rules: + k = re.sub(old, new, k) + if k == "pos_embed.weight": + # strip first two positions, [1, N+1, D] + tmp["pos_embed"] = v[2:].unsqueeze(0) + else: + tmp[k] = v + + # Second pass, fuse q, k, v + out, buf = {}, {} + pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$") + for k, v in tmp.items(): + m = pat.fullmatch(k) + if not m: # anything not q/k/v -> copy through + out[k] = v + continue + + blk, which, kind = m.groups() # block idx, 'q'/'k'/'v', 'weight'/'bias' + stash = buf.setdefault((blk, kind), {}) # Gather by block & param type + stash[which] = v + if len(stash) == 3: # Have q, k, v -> concatenate + out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat( + [stash['q'], stash['k'], stash['v']], dim=0 + ) + + return out + + def checkpoint_filter_fn( state_dict: Dict[str, torch.Tensor], model: VisionTransformer, @@ -1186,6 +1287,9 @@ def checkpoint_filter_fn( state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.') elif "mask_token" in state_dict: state_dict = _convert_dinov2(state_dict, model) + elif any('beit3.' 
in k for k in state_dict.keys()): + # BEiT3 model - multimodal checkpoint with beit3.* prefix + state_dict = _convert_beit3(state_dict, model) elif "encoder" in state_dict: # IJEPA, vit in an 'encoder' submodule state_dict = state_dict['encoder'] @@ -2377,6 +2481,44 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 160, 160), crop_pct=0.95), 'test_vit4.r160_in1k': _cfg( input_size=(3, 160, 160), crop_pct=0.95), + + # BEiT3 models (remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True) + 'beit3_base_patch16_224.in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_base_patch16_224.indomain_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_large_patch16_224.in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_large_patch16_224.indomain_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_giant_patch14_224.untrained': _cfg( + url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_giant_patch14_336.untrained': _cfg( + url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_base_patch16_224.pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), + 'beit3_base_patch16_224.indomain_pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), + 'beit3_large_patch16_224.pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), + 'beit3_large_patch16_224.indomain_pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), } _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]] @@ -2710,7 +2852,7 @@ def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTrans @register_model def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 8M - model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2720,7 +2862,7 @@ def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTra def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 40M model_args = dict( - patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2729,7 +2871,7 @@ def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTra @register_model def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 39M - model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, 
norm_layer=nn.LayerNorm) + model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2739,7 +2881,7 @@ def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTra def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 61M model_args = dict( - patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2750,7 +2892,7 @@ def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 224x224 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2761,7 +2903,7 @@ def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 256x256 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2772,7 +2914,7 @@ def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 384x384 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2783,7 +2925,7 @@ def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 448x448 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2793,7 +2935,8 @@ def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTrans def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/16 CLIP image tower """ - model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2803,7 +2946,8 @@ def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) 
-> VisionTrans def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/16 CLIP image tower @ 384x384 """ - model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2813,7 +2957,8 @@ def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTrans def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16+) CLIP image tower @ 240x240 """ - model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=16, embed_dim=896, depth=12, num_heads=14, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch16_plus_clip_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2823,7 +2968,8 @@ def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> Vision def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower """ - model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2833,7 +2979,8 @@ def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTran def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 """ - model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2843,7 +2990,8 @@ def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTran def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower. 
""" - model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2853,7 +3001,8 @@ def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTrans def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 """ - model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2863,7 +3012,8 @@ def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTrans def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 """ - model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2875,7 +3025,9 @@ def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTran Pretrained weights from CLIP image tower. """ model_args = dict( - patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, + norm_layer=partial(LayerNorm, eps=1e-5), + ) model = _create_vision_transformer( 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2887,7 +3039,9 @@ def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionT Pretrained weights from CLIP image tower. 
""" model_args = dict( - patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, + norm_layer=partial(LayerNorm, eps=1e-5), + ) model = _create_vision_transformer( 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2899,7 +3053,8 @@ def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2911,7 +3066,8 @@ def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2923,7 +3079,8 @@ def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> """ model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2935,7 +3092,8 @@ def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> """ model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2947,7 +3105,8 @@ def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2959,7 +3118,8 @@ def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2971,7 +3131,8 @@ def vit_gigantic_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) """ model_args = dict( patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 
'vit_gigantic_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -3606,7 +3767,6 @@ def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> V
     return model
 
 
-
 @register_model
 def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
@@ -4035,6 +4195,62 @@ def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ BEiT3 Base model (ViT-Base size) with patch size 16x16.
+    Remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
+        norm_layer=partial(LayerNorm, eps=1e-5),
+    )
+    model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ BEiT3 Large model (ViT-Large size) with patch size 16x16.
+    Remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
+        norm_layer=partial(LayerNorm, eps=1e-5),
+    )
+    model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ BEiT3 Giant model with patch size 14x14.
+    Remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True.
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
+        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
+        norm_layer=partial(LayerNorm, eps=1e-5),
+    )
+    model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ BEiT3 Giant model with patch size 14x14 and image size 336x336.
+    Remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True.
+    """
+    model_args = dict(
+        img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
+        scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
+        norm_layer=partial(LayerNorm, eps=1e-5),
+    )
+    model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py
index f8e2a9a1e0..7d023e0d31 100644
--- a/timm/models/vitamin.py
+++ b/timm/models/vitamin.py
@@ -259,11 +259,12 @@ def __init__(
             in_features,
             hidden_features,
             act_layer = 'gelu',
+            norm_layer = None,
             bias = True,
             drop = 0.0,
     ):
         super().__init__()
-        norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
+        norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6)
 
         self.norm = norm_layer(in_features)
         self.w0 = nn.Linear(in_features, hidden_features, bias=bias)
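Usage note: as a quick sanity check of the new registrations above, something like the following should work once this branch is installed. This is a minimal sketch, not part of the diff: it uses pretrained=False so nothing is fetched from the hub, and the output shape assumes the default 1000-class head of VisionTransformer.

import torch
import timm

# Instantiate one of the BEiT3 variants registered in this diff (random init, no hub download).
model = timm.create_model('beit3_base_patch16_224', pretrained=False)
model.eval()

# Run a dummy forward pass at the native 224x224 resolution.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

print(out.shape)  # expected: torch.Size([1, 1000]) with the default classifier head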