From 38c5f3bc63d7032e71806f84ce0a51daa5559885 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 May 2025 10:33:09 -0700 Subject: [PATCH 1/5] A cleaned up beit3 remap onto vision_transformer.py vit --- timm/models/vision_transformer.py | 260 ++++++++++++++++++++++++++---- 1 file changed, 227 insertions(+), 33 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 3c7b9a2277..718c395dce 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -41,8 +41,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ - trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ +from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, LayerNorm, RmsNorm, PatchDropout, \ + SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ get_act_layer, get_norm_layer, LayerType from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -64,10 +64,11 @@ def __init__( num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, proj_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: Type[nn.Module] = LayerNorm, ) -> None: super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' @@ -79,6 +80,7 @@ def __init__( self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) @@ -102,6 +104,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) + x = self.norm(x) x = self.proj(x) x = self.proj_drop(x) return x @@ -130,13 +133,15 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., init_values: Optional[float] = None, drop_path: float = 0., act_layer: Type[nn.Module] = nn.GELU, - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, ) -> None: super().__init__() @@ -146,6 +151,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_attn_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -159,6 +165,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, ) @@ -179,13 +186,15 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., init_values: Optional[float] = None, drop_path: float = 0., act_layer: Type[nn.Module] = nn.GELU, - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: 
Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, ) -> None: super().__init__() @@ -196,6 +205,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_attn_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -208,6 +218,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, ) @@ -248,7 +259,7 @@ def __init__( init_values: Optional[float] = None, drop_path: float = 0., act_layer: Type[nn.Module] = nn.GELU, - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Optional[Type[nn.Module]] = None, ) -> None: super().__init__() @@ -343,7 +354,7 @@ def __init__( attn_drop: float = 0., drop_path: float = 0., act_layer: Type[nn.Module] = nn.GELU, - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, ) -> None: super().__init__() @@ -443,6 +454,8 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = True, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, init_values: Optional[float] = None, class_token: bool = True, @@ -505,7 +518,7 @@ def __init__( assert class_token or global_pool != 'token' assert pos_embed in ('', 'none', 'learn') use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm - norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + norm_layer = get_norm_layer(norm_layer) or LayerNorm embed_norm_layer = get_norm_layer(embed_norm_layer) act_layer = get_act_layer(act_layer) or nn.GELU @@ -563,6 +576,8 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_attn_norm=scale_attn_norm, + scale_mlp_norm=scale_mlp_norm, proj_bias=proj_bias, init_values=init_values, proj_drop=proj_drop_rate, @@ -1166,6 +1181,89 @@ def _convert_aimv2( return out_dict +def _convert_beit3( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, +) -> Dict[str, torch.Tensor]: + """Convert BEiT3 weights to standard VisionTransformer format.""" + import re + + if 'model' in state_dict: + state_dict = state_dict['model'] + + # Remove text and mask tokens (vision-only) + state_dict.pop('beit3.text_embed.weight', None) + state_dict.pop('beit3.vision_embed.mask_token', None) + + # First pass: Apply all key transformations except qkv fusion + intermediate_dict = {} + for k, v in state_dict.items(): + # Skip B branch weights (use only A branch) + if '.B.' 
in k: + continue + + # Apply all BEiT3 key transformations in one go + if 'vision_embed.cls_token' in k: + k = 'cls_token' + else: + k = k.replace('beit3.', '') + k = k.replace('embed_positions.', 'pos_embed.') + k = k.replace('vision_embed.', 'patch_embed.') + k = k.replace('encoder.', '') + k = k.replace('layers.', 'blocks.') + k = k.replace('ffn.', 'mlp.') + k = k.replace('ffn_layernorm.', 'norm.') + k = k.replace('self_attn.', 'attn.') + k = k.replace('self_attn_layer_norm.', 'norm1.') + k = k.replace('final_layer_norm.', 'norm2.') + k = k.replace('inner_attn_ln', 'norm') # Map inner attention LayerNorm to scale norm + k = k.replace('out_proj', 'proj') # Map out_proj to proj + k = k.replace('A.', '') # Remove A branch prefix + + # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2) + if k == 'pos_embed.weight': + # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim] + # We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches) + intermediate_dict['pos_embed'] = v[2:].unsqueeze(0) # Skip first 2 positions, add batch dim + else: + intermediate_dict[k] = v + + # Second pass: Handle qkv fusion + out_dict = {} + processed_qkv = set() + for k, v in intermediate_dict.items(): + # Handle attention projections - convert separate q,k,v to fused qkv + if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.(weight|bias)", k): + block_idx = re.search(r"blocks\.(\d+)", k).group(1) + param_type = re.search(r"\.(weight|bias)$", k).group(1) + + # Only process once per block per parameter type + block_param_key = f"{block_idx}_{param_type}" + if block_param_key in processed_qkv: + continue + + # Collect all three projections for this block + q_key = f"blocks.{block_idx}.attn.q_proj.{param_type}" + k_key = f"blocks.{block_idx}.attn.k_proj.{param_type}" + v_key = f"blocks.{block_idx}.attn.v_proj.{param_type}" + + if all(key in intermediate_dict for key in [q_key, k_key, v_key]): + qkv_tensor = torch.cat([ + intermediate_dict[q_key], + intermediate_dict[k_key], + intermediate_dict[v_key] + ], dim=0) + out_dict[f"blocks.{block_idx}.attn.qkv.{param_type}"] = qkv_tensor + processed_qkv.add(block_param_key) + continue + else: + assert False + else: + out_dict[k] = v + + return out_dict + + def checkpoint_filter_fn( state_dict: Dict[str, torch.Tensor], model: VisionTransformer, @@ -1186,6 +1284,9 @@ def checkpoint_filter_fn( state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.') elif "mask_token" in state_dict: state_dict = _convert_dinov2(state_dict, model) + elif any('beit3.' 
in k for k in state_dict.keys()): + # BEiT3 model - multimodal checkpoint with beit3.* prefix + state_dict = _convert_beit3(state_dict, model) elif "encoder" in state_dict: # IJEPA, vit in an 'encoder' submodule state_dict = state_dict['encoder'] @@ -2377,6 +2478,24 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 160, 160), crop_pct=0.95), 'test_vit4.r160_in1k': _cfg( input_size=(3, 160, 160), crop_pct=0.95), + + # BEiT3 models (remapped to VisionTransformer with scale_norm=True) + 'beit3_base_patch16_224.in22k_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_large_patch16_224.in22k_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_giant_patch14_224.untrained': _cfg( + url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_giant_patch14_336.untrained': _cfg( + url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), } _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]] @@ -2710,7 +2829,7 @@ def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTrans @register_model def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 8M - model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2720,7 +2839,7 @@ def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTra def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 40M model_args = dict( - patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2729,7 +2848,7 @@ def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTra @register_model def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 39M - model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2739,7 +2858,7 @@ def 
vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTra def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: # TinyCLIP 61M model_args = dict( - patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2750,7 +2869,7 @@ def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 224x224 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2761,7 +2880,7 @@ def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 256x256 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2772,7 +2891,7 @@ def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 384x384 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2783,7 +2902,7 @@ def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTrans """ ViT-B/32 CLIP image tower @ 448x448 """ model_args = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2793,7 +2912,8 @@ def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTrans def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/16 CLIP image tower """ - model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2803,7 +2923,8 @@ def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTrans def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/16 CLIP image tower @ 384x384 """ - model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=16, 
embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2813,7 +2934,8 @@ def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTrans def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16+) CLIP image tower @ 240x240 """ - model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=16, embed_dim=896, depth=12, num_heads=14, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_base_patch16_plus_clip_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2823,7 +2945,8 @@ def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> Vision def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower """ - model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2833,7 +2956,8 @@ def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTran def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 """ - model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2843,7 +2967,8 @@ def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTran def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower. 
""" - model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2853,7 +2978,8 @@ def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTrans def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 """ - model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2863,7 +2989,8 @@ def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTrans def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 """ - model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5)) model = _create_vision_transformer( 'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2875,7 +3002,9 @@ def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTran Pretrained weights from CLIP image tower. """ model_args = dict( - patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, + norm_layer=partial(LayerNorm, eps=1e-5), + ) model = _create_vision_transformer( 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2887,7 +3016,9 @@ def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionT Pretrained weights from CLIP image tower. 
""" model_args = dict( - patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, + norm_layer=partial(LayerNorm, eps=1e-5), + ) model = _create_vision_transformer( 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2899,7 +3030,8 @@ def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2911,7 +3043,8 @@ def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2923,7 +3056,8 @@ def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> """ model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2935,7 +3069,8 @@ def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> """ model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2947,7 +3082,8 @@ def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2959,7 +3095,8 @@ def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> V """ model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2971,7 +3108,8 @@ def vit_gigantic_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) """ model_args = dict( patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, - norm_layer=nn.LayerNorm, act_layer='quick_gelu') + norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu' + ) model = _create_vision_transformer( 
'vit_gigantic_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -4035,6 +4173,62 @@ def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer: return model +@register_model +def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Base model (ViT-Base size) with patch size 16x16. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg', + norm_layer=partial(LayerNorm, eps=1e-5) + ) + model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Large model (ViT-Large size) with patch size 16x16. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg', + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Giant model with patch size 14x14. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg', + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Giant model with patch size 14x14 and image size 336x336. + Remapped to VisionTransformer with scale_norm=True. 
+ """ + model_args = dict( + img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg', + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + register_model_deprecations(__name__, { 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k', From 2ca94a6ce4ef955b42df94783989822c0dbf6ee7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 May 2025 10:52:39 -0700 Subject: [PATCH 2/5] Compact _covert_beit3 fn --- timm/models/vision_transformer.py | 127 ++++++++++++------------------ 1 file changed, 52 insertions(+), 75 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 718c395dce..a843f13980 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1181,87 +1181,64 @@ def _convert_aimv2( return out_dict -def _convert_beit3( - state_dict: Dict[str, torch.Tensor], - model: VisionTransformer, -) -> Dict[str, torch.Tensor]: - """Convert BEiT3 weights to standard VisionTransformer format.""" +def _convert_beit3(state_dict: dict, model): + """ + Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict. + """ import re + state_dict = state_dict.get("model", state_dict) # unwrap if needed + + # Prune unused + for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"): + state_dict.pop(k, None) + + # Key renaming rules + rules = [ + (r"beit3\.", ""), + (r"vision_embed\.cls_token", "cls_token"), + (r"vision_embed\.", "patch_embed."), + (r"embed_positions\.", "pos_embed."), + (r"encoder\.", ""), + (r"layers\.", "blocks."), + (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."), + (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."), + (r"final_layer_norm\.", "norm2."), + (r"inner_attn_ln", "norm"), + (r"out_proj", "proj"), + (r"\.A\.", "."), + ] - if 'model' in state_dict: - state_dict = state_dict['model'] - - # Remove text and mask tokens (vision-only) - state_dict.pop('beit3.text_embed.weight', None) - state_dict.pop('beit3.vision_embed.mask_token', None) - - # First pass: Apply all key transformations except qkv fusion - intermediate_dict = {} + # First pass, rename keys + tmp = {} for k, v in state_dict.items(): - # Skip B branch weights (use only A branch) - if '.B.' in k: - continue - - # Apply all BEiT3 key transformations in one go - if 'vision_embed.cls_token' in k: - k = 'cls_token' + if ".B." 
in k: + continue # use branch-A only + for old, new in rules: + k = re.sub(old, new, k) + if k == "pos_embed.weight": + # strip first two positions, [1, N+1, D] + tmp["pos_embed"] = v[2:].unsqueeze(0) else: - k = k.replace('beit3.', '') - k = k.replace('embed_positions.', 'pos_embed.') - k = k.replace('vision_embed.', 'patch_embed.') - k = k.replace('encoder.', '') - k = k.replace('layers.', 'blocks.') - k = k.replace('ffn.', 'mlp.') - k = k.replace('ffn_layernorm.', 'norm.') - k = k.replace('self_attn.', 'attn.') - k = k.replace('self_attn_layer_norm.', 'norm1.') - k = k.replace('final_layer_norm.', 'norm2.') - k = k.replace('inner_attn_ln', 'norm') # Map inner attention LayerNorm to scale norm - k = k.replace('out_proj', 'proj') # Map out_proj to proj - k = k.replace('A.', '') # Remove A branch prefix - - # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2) - if k == 'pos_embed.weight': - # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim] - # We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches) - intermediate_dict['pos_embed'] = v[2:].unsqueeze(0) # Skip first 2 positions, add batch dim - else: - intermediate_dict[k] = v + tmp[k] = v + + # Second pass, fuse q, k, v + out, buf = {}, {} + pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$") + for k, v in tmp.items(): + m = pat.fullmatch(k) + if not m: # anything not q/k/v -> copy through + out[k] = v + continue - # Second pass: Handle qkv fusion - out_dict = {} - processed_qkv = set() - for k, v in intermediate_dict.items(): - # Handle attention projections - convert separate q,k,v to fused qkv - if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.(weight|bias)", k): - block_idx = re.search(r"blocks\.(\d+)", k).group(1) - param_type = re.search(r"\.(weight|bias)$", k).group(1) - - # Only process once per block per parameter type - block_param_key = f"{block_idx}_{param_type}" - if block_param_key in processed_qkv: - continue - - # Collect all three projections for this block - q_key = f"blocks.{block_idx}.attn.q_proj.{param_type}" - k_key = f"blocks.{block_idx}.attn.k_proj.{param_type}" - v_key = f"blocks.{block_idx}.attn.v_proj.{param_type}" - - if all(key in intermediate_dict for key in [q_key, k_key, v_key]): - qkv_tensor = torch.cat([ - intermediate_dict[q_key], - intermediate_dict[k_key], - intermediate_dict[v_key] - ], dim=0) - out_dict[f"blocks.{block_idx}.attn.qkv.{param_type}"] = qkv_tensor - processed_qkv.add(block_param_key) - continue - else: - assert False - else: - out_dict[k] = v + blk, which, kind = m.groups() # block idx, 'q'/'k'/'v', 'weight'/'bias' + stash = buf.setdefault((blk, kind), {}) # Gather by block & param type + stash[which] = v + if len(stash) == 3: # Have q, k, v -> concatenate + out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat( + [stash['q'], stash['k'], stash['v']], dim=0 + ) - return out_dict + return out def checkpoint_filter_fn( From 3a3d98bc38318c9cdaa0919f820b4dfa562f6f45 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 May 2025 11:34:19 -0700 Subject: [PATCH 3/5] Fix parallel blocks missing scale args and vitamin MLP --- timm/models/vision_transformer.py | 15 +++++++++++---- timm/models/vitamin.py | 3 ++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index a843f13980..fd4ed476eb 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -64,7 +64,7 @@ def __init__( num_heads: int = 8, qkv_bias: 
bool = False, qk_norm: bool = False, - scale_attn_norm: bool = False, + scale_norm: bool = False, proj_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., @@ -80,7 +80,7 @@ def __init__( self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity() + self.norm = norm_layer(dim) if scale_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) @@ -151,7 +151,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, - scale_attn_norm=scale_attn_norm, + scale_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -205,7 +205,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, - scale_attn_norm=scale_attn_norm, + scale_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -253,6 +253,8 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., @@ -264,6 +266,7 @@ def __init__( ) -> None: super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' + assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 @@ -348,6 +351,8 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, init_values: Optional[float] = None, proj_drop: float = 0., @@ -369,6 +374,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -383,6 +389,7 @@ def __init__( dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, )), diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index f8e2a9a1e0..7d023e0d31 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -259,11 +259,12 @@ def __init__( in_features, hidden_features, act_layer = 'gelu', + norm_layer = None, bias = True, drop = 0.0, ): super().__init__() - norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6) + norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6) self.norm = norm_layer(in_features) self.w0 = nn.Linear(in_features, hidden_features, bias=bias) From 1be79999933da1b2c0c7feb751d18ebcb0ecda09 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 30 May 2025 14:36:41 -0700 Subject: [PATCH 4/5] Upload beit3 weights to hub, add pretrain weights --- timm/models/vision_transformer.py | 35 ++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index fd4ed476eb..7bc7026c08 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -2463,23 +2463,43 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'test_vit4.r160_in1k': _cfg( input_size=(3, 160, 160), crop_pct=0.95), - # BEiT3 models (remapped to VisionTransformer with scale_norm=True) + # BEiT3 
models (remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True) 'beit3_base_patch16_224.in22k_ft_in1k': _cfg( - url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), - 'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg( - url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth', + 'beit3_base_patch16_224.indomain_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), 'beit3_large_patch16_224.in22k_ft_in1k': _cfg( - url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), - 'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg( - url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth', + 'beit3_large_patch16_224.indomain_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), 'beit3_giant_patch14_224.untrained': _cfg( url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), 'beit3_giant_patch14_336.untrained': _cfg( url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_base_patch16_224.pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), + 'beit3_base_patch16_224.indomain_pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), + 'beit3_large_patch16_224.pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), + 'beit3_large_patch16_224.indomain_pt': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0, + num_classes=0, + ), } _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]] @@ -3728,7 +3748,6 @@ def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> V return model - @register_model def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( From 211cf907210dc84c2923960c4bc96c3a099866e8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 30 May 2025 15:11:51 -0700 Subject: [PATCH 5/5] Imports getting unwieldy in vision_transformer.py --- timm/models/vision_transformer.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 7bc7026c08..594415493b 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -39,11 +39,30 @@ import torch.nn.functional as F from torch.jit import Final -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ +from timm.data import ( + IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, + IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, LayerNorm, RmsNorm, PatchDropout, \ - SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ - get_act_layer, 
get_norm_layer, LayerType +) +from timm.layers import ( + PatchEmbed, + Mlp, + DropPath, + AttentionPoolLatent, + LayerNorm, + RmsNorm, + PatchDropout, + SwiGLUPacked, + SwiGLU, + trunc_normal_, + lecun_normal_, + resample_patch_embed, + resample_abs_pos_embed, + use_fused_attn, + get_act_layer, + get_norm_layer, + LayerType, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
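
To sanity-check the remap end to end, here is a minimal usage sketch for the entrypoints added in this series (the beit3_base_patch16_224 entrypoint, its in22k_ft_in1k tag, and the attn.norm / mlp.norm scale norms all come from the diffs above; pulling pretrained weights assumes the converted checkpoints are actually published at the hf_hub_id='timm/' locations added in PATCH 4/5, otherwise pretrained=False still builds the architecture):

    import torch
    import timm

    # Build the remapped BEiT3-Base ViT. With pretrained=True the original BEiT3
    # checkpoint is routed through checkpoint_filter_fn -> _convert_beit3.
    model = timm.create_model('beit3_base_patch16_224.in22k_ft_in1k', pretrained=True)
    model.eval()

    # After conversion a block should look like a standard ViT block plus the
    # extra scale norms: fused qkv, inner_attn_ln remapped to attn.norm,
    # and ffn_layernorm remapped to mlp.norm.
    blk = model.blocks[0]
    print(blk.attn.qkv.weight.shape)   # torch.Size([2304, 768]) after q/k/v fusion
    print(type(blk.attn.norm).__name__, type(blk.mlp.norm).__name__)  # LayerNorm LayerNorm

    # Forward pass at the 224x224 input size declared in the default cfgs.
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model(x)
    print(out.shape)   # torch.Size([1, 1000]) for the in1k fine-tuned head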