From 7aeebf20e2eb2fd7a28177171a6f4b9f75d5679f Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Tue, 6 May 2025 01:29:55 +0800 Subject: [PATCH 1/4] add BEIT3 --- timm/models/__init__.py | 1 + timm/models/beit3.py | 492 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 493 insertions(+) create mode 100644 timm/models/beit3.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 238c1ccca5..557ee50236 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -1,4 +1,5 @@ from .beit import * +from .beit3 import * from .byoanet import * from .byobnet import * from .cait import * diff --git a/timm/models/beit3.py b/timm/models/beit3.py new file mode 100644 index 0000000000..5f41c08746 --- /dev/null +++ b/timm/models/beit3.py @@ -0,0 +1,492 @@ +""" BEiT3 +Paper: `Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks` + - https://arxiv.org/abs/2208.10442 + - https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_Image_as_a_Foreign_Language_BEiT_Pretraining_for_Vision_and_CVPR_2023_paper.pdf + +Model from official source: + - https://github.com/microsoft/unilm/tree/master/beit3 + - https://github.com/microsoft/torchscale/blob/main/torchscale/model/BEiT3.py + +@inproceedings{beit3, + title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks}, + author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2023} +} +@InProceedings{Wang_2023_CVPR, + author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu}, + title = {Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2023}, + pages = {19175-19186} +} + +Original implementation by Wenhui Wang et al., +adapted for timm by Ryan Hou and Ross Wightman. + +At this point only the 1k fine-tuned classification weights and model configs have been added, +see original source above for pre-training models and procedure. 
+ +Adapted from https://github.com/microsoft/torchscale/blob/main/torchscale/model/BEiT3.py, original copyright below +""" +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +import math +from functools import partial +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import PatchEmbed, Mlp, LayerNorm, DropPath, trunc_normal_, LayerType + +from ._builder import build_model_with_cfg +from ._features import feature_take_indices +from ._manipulate import checkpoint +from ._registry import generate_default_cfgs, register_model + +__all__ = ['BEiT3'] + + +class PositionalEmbedding(nn.Embedding): + """ + Reference from: + https://github.com/microsoft/torchscale/blob/main/torchscale/component/embedding.py#L99-L119 + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.embedding( + torch.arange(2, self.num_embeddings).long().unsqueeze(0).to(x.device), + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + +class Attention(nn.Module): + """ + Reference from: + https://github.com/microsoft/torchscale/blob/main/torchscale/component/multihead_attention.py#L20-L171 + """ + def __init__( + self, + dim: int, + num_heads: int, + drop_rate: float = 0., + norm_layer: LayerType = partial(LayerNorm, eps=1e-5) + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scaling = self.head_dim ** -0.5 + + self.k_proj = nn.Linear(dim, dim) + self.v_proj = nn.Linear(dim, dim) + self.q_proj = nn.Linear(dim, dim) + self.out_proj = nn.Linear(dim, dim) + self.inner_attn_ln = norm_layer(dim) + self.attn_drop = nn.Dropout(drop_rate) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + q *= self.scaling + + q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + q = q.reshape(B * self.num_heads, N, self.head_dim) + k = k.reshape(B * self.num_heads, N, self.head_dim) + v = v.reshape(B * self.num_heads, N, self.head_dim) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( + attn_weights + ) + attn_probs = self.attn_drop(attn_weights) + + attn = torch.bmm(attn_probs, v) + attn = attn.transpose(0, 1).reshape(N, B, C).transpose(0, 1) + attn = self.inner_attn_ln(attn) + attn = self.out_proj(attn) + return attn + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + drop_rate: float = 0., + drop_path: float = 0., + attn_drop: float = 0., + act_layer: LayerType = nn.GELU, + norm_layer: LayerType = partial(LayerNorm, eps=1e-5), + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, 
drop_rate=attn_drop, norm_layer=norm_layer) + self.attn_drop = nn.Dropout(drop_rate) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop_rate + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path(self.attn_drop(self.attn(self.norm1(x)))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class BEiT3(nn.Module): + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + act_layer: LayerType = nn.GELU, + norm_layer: LayerType = partial(LayerNorm, eps=1e-5), + head_init_scale: float = 0.001, + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models + self.num_prefix_tokens = 1 + self.grad_checkpointing = False + + # vision_embed + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + # encoder + self.pos_embed = PositionalEmbedding(num_patches + 3, embed_dim) + self.pos_drop = nn.Dropout(drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop_rate=drop_rate, + drop_path=dpr[i], + attn_drop=attn_drop_rate, + act_layer=act_layer, + norm_layer=norm_layer, + ) + for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] + + # class_head + use_fc_norm = self.global_pool == 'avg' + self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + trunc_normal_(self.cls_token, std=.02) + + self.fix_init_weight(depth) + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=.02) + self.head.weight.data.mul_(head_init_scale) + self.head.bias.data.mul_(head_init_scale) + + def fix_init_weight(self, depth: int): + init_scale = math.sqrt(math.log(depth * 2)) + for name, p in self.named_parameters(): + if ( + "fc1" in name + or "fc2" in name + or "out_proj" in name + or "v_proj" in name + ): + p.data.mul_(init_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {'pos_embed', 'cls_token'} + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = 
True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + matcher = dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))], + ) + return matcher + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if an int, if is a sequence, select by matching indices + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed(x) + x = self.pos_drop(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + + if reshape: + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.reset_classifier(0, '') + return take_indices + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed(x) + x = self.pos_drop(x) + + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) + x = self.norm(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'paper_ids': 'arXiv:2208.10442', + 'paper_name': 'Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks', + 'origin_url': 'https://github.com/microsoft/unilm/tree/master/beit3', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'beit3_base_patch16_224.in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth', + # hf_hub_id='timm/', + ), + 'beit3_base_patch16_224.indomain_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth', + # hf_hub_id='timm/', + ), + 'beit3_large_patch16_224.in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth', + # hf_hub_id='timm/', + ), + 'beit3_large_patch16_224.indomain_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth', + # hf_hub_id='timm/', + ), +}) + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: BEiT3, +) -> Dict[str, torch.Tensor]: + if 'model' in state_dict: + state_dict = state_dict['model'] + + if 'patch_embed.proj.weight' in state_dict: + return state_dict + + state_dict.pop('beit3.text_embed.weight') + state_dict.pop('beit3.vision_embed.mask_token') + + out_dict = {} + + for k, v in state_dict.items(): + if '.B.' 
in k: + continue + elif 'vision_embed.cls_token' in k: + k = 'cls_token' + else: + k = k.replace('beit3.', '') + k = k.replace('embed_positions.', 'pos_embed.') + k = k.replace('vision_embed.', 'patch_embed.') + k = k.replace('encoder.', '') + k = k.replace('layers.', 'blocks.') + k = k.replace('ffn.', 'mlp.') + k = k.replace('ffn_layernorm.', 'norm.') + k = k.replace('self_attn.', 'attn.') + k = k.replace('self_attn_layer_norm.', 'norm1.') + k = k.replace('final_layer_norm.', 'norm2.') + k = k.replace('A.', '') + + out_dict[k] = v + + return out_dict + + +def _create_beit3(variant: str, pretrained: bool, **kwargs: Any) -> BEiT3: + out_indices = kwargs.pop('out_indices', 3) + model = build_model_with_cfg( + BEiT3, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + return model + + +@register_model +def beit3_base_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4) + model = _create_beit3('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_large_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4) + model = _create_beit3('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model From afe4375e778a9c60c3a2d265f73039c55a8e02a7 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 12 May 2025 00:13:52 +0800 Subject: [PATCH 2/4] update BEiT3 --- tests/test_models.py | 4 ++-- timm/models/beit3.py | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index aa866ccdb6..eb95ad2228 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -56,13 +56,13 @@ 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', - 'davit', 'rdnet', 'convnext', 'pit' + 'davit', 'rdnet', 'convnext', 'pit', 'beit3', ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. 
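
For reference, a minimal usage sketch of the models registered in the first patch above. This is illustrative only, assuming the hosted in1k weights in default_cfgs resolve and that forward_intermediates() / features_only behave here as they do for other ViT-style timm models; pretrained=False is used so nothing depends on the download.

import torch
import timm

# Classification head: num_classes defaults to 1000, global_pool='avg'.
model = timm.create_model('beit3_base_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
logits = model(x)  # (1, 1000)

# Per-block features via the forward_intermediates() API added above; with the
# default NCHW output each entry is (1, 768, 14, 14) for the base model at 224 / patch 16.
feats = model.forward_intermediates(x, indices=4, intermediates_only=True)

# feature_cls='getter' in _create_beit3 also enables features_only extraction.
fx = timm.create_model('beit3_base_patch16_224', features_only=True, out_indices=(3, 7, 11))
fmaps = fx(x)  # list of three (1, 768, 14, 14) tensors
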
NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', + 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'beit3*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', ] diff --git a/timm/models/beit3.py b/timm/models/beit3.py index 5f41c08746..99b5a1ef4d 100644 --- a/timm/models/beit3.py +++ b/timm/models/beit3.py @@ -86,7 +86,7 @@ def __init__( dim: int, num_heads: int, drop_rate: float = 0., - norm_layer: LayerType = partial(LayerNorm, eps=1e-5) + norm_layer: LayerType = partial(LayerNorm, eps=1e-5), ): super().__init__() self.num_heads = num_heads @@ -122,7 +122,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attn_probs = self.attn_drop(attn_weights) attn = torch.bmm(attn_probs, v) - attn = attn.transpose(0, 1).reshape(N, B, C).transpose(0, 1) + attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2) + attn = attn.reshape(B, N, C) attn = self.inner_attn_ln(attn) attn = self.out_proj(attn) return attn @@ -403,7 +404,7 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: 'paper_ids': 'arXiv:2208.10442', 'paper_name': 'Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks', 'origin_url': 'https://github.com/microsoft/unilm/tree/master/beit3', - **kwargs + **kwargs, } @@ -427,10 +428,7 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: }) -def checkpoint_filter_fn( - state_dict: Dict[str, torch.Tensor], - model: BEiT3, -) -> Dict[str, torch.Tensor]: +def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]: if 'model' in state_dict: state_dict = state_dict['model'] From b5a814e4c14b571dc4a90e9221a1dee2ef4472b4 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 12 May 2025 00:24:15 +0800 Subject: [PATCH 3/4] add giant model param --- timm/models/beit3.py | 58 ++++++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/timm/models/beit3.py b/timm/models/beit3.py index 99b5a1ef4d..1bc52b3282 100644 --- a/timm/models/beit3.py +++ b/timm/models/beit3.py @@ -9,12 +9,14 @@ @inproceedings{beit3, title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks}, - author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei}, + author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal + and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, year={2023} } @InProceedings{Wang_2023_CVPR, - author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu}, + author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, + Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu}, title = {Image as a Foreign Language: BEiT 
Pretraining for Vision and Vision-Language Tasks}, booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, month = {June}, @@ -65,6 +67,7 @@ class PositionalEmbedding(nn.Embedding): https://github.com/microsoft/torchscale/blob/main/torchscale/component/embedding.py#L99-L119 """ def forward(self, x: torch.Tensor) -> torch.Tensor: + # being consistent with Fairseq, which starts from 2. return F.embedding( torch.arange(2, self.num_embeddings).long().unsqueeze(0).to(x.device), self.weight, @@ -108,22 +111,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: v = self.v_proj(x) q *= self.scaling + ## (B, N, C) >> (B, N, num_heads, head_dim) >> (B, num_heads, N, head_dim) q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + ## (B, num_heads, N, head_dim) >> (B * num_heads, N, head_dim) q = q.reshape(B * self.num_heads, N, self.head_dim) k = k.reshape(B * self.num_heads, N, self.head_dim) v = v.reshape(B * self.num_heads, N, self.head_dim) - attn_weights = torch.bmm(q, k.transpose(1, 2)) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( - attn_weights - ) + attn_weights = torch.bmm(q, k.transpose(1, 2)) # (B * num_heads, N, N) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights) attn_probs = self.attn_drop(attn_weights) + attn = torch.bmm(attn_probs, v) # (B * num_heads, N, head_dim) - attn = torch.bmm(attn_probs, v) - attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2) - attn = attn.reshape(B, N, C) + ## (B * num_heads N, head_dim) >> (B, N, num_heads * head_dim) == (B, N, C) + attn = attn.view(B, self.num_heads, N, self.head_dim).transpose(1, 2).reshape(B, N, C) attn = self.inner_attn_ln(attn) attn = self.out_proj(attn) return attn @@ -409,26 +413,28 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: default_cfgs = generate_default_cfgs({ - 'beit3_base_patch16_224.in1k': _cfg( + 'beit3_base_patch16_224.in22k_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth', # hf_hub_id='timm/', ), - 'beit3_base_patch16_224.indomain_in1k': _cfg( + 'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth', # hf_hub_id='timm/', ), - 'beit3_large_patch16_224.in1k': _cfg( + 'beit3_large_patch16_224.in22k_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth', # hf_hub_id='timm/', ), - 'beit3_large_patch16_224.indomain_in1k': _cfg( + 'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg( url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth', # hf_hub_id='timm/', ), + 'beit3_giant_patch14_224.untrained': _cfg(url=''), + 'beit3_giant_patch14_336.untrained': _cfg(url='', input_size=(3, 336, 336)), }) -def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> Dict[str, torch.Tensor]: +def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]: if 'model' in state_dict: state_dict = state_dict['model'] @@ -459,11 +465,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: BEiT3) -> D k = k.replace('A.', '') out_dict[k] = v - + return out_dict -def _create_beit3(variant: str, 
pretrained: bool, **kwargs: Any) -> BEiT3: +def _create_beit3(variant: str, pretrained: bool = False, **kwargs: Any) -> BEiT3: out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( BEiT3, variant, pretrained, @@ -488,3 +494,23 @@ def beit3_large_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3: patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4) model = _create_beit3('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model + + +@register_model +def beit3_giant_patch14_224(pretrained: bool = False, **kwargs: Any) -> BEiT3: + ## FFN inner hidden size = embed_dim * mlp_ratio + ## 6144 = int(1408 * 4.3637) + model_args = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637) + model = _create_beit3('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_336(pretrained: bool = False, **kwargs: Any) -> BEiT3: + ## FFN inner hidden size = embed_dim * mlp_ratio + ## 6144 = int(1408 * 4.3637) + model_args = dict( + img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637) + model = _create_beit3('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model From 55e52c45ef34123e87052e0531acd6360b37d7d8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 May 2025 09:50:17 -0700 Subject: [PATCH 4/4] Initial run through remapping beit3 -> vision_transformer.py --- timm/models/__init__.py | 2 +- timm/models/vision_transformer.py | 209 ++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 1 deletion(-) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 557ee50236..b356e12991 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -1,5 +1,5 @@ from .beit import * -from .beit3 import * +#from .beit3 import * from .byoanet import * from .byobnet import * from .cait import * diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 3c7b9a2277..6f5e144d15 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -64,6 +64,7 @@ def __init__( num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, proj_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., @@ -79,6 +80,7 @@ def __init__( self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) @@ -102,6 +104,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) + x = self.norm(x) x = self.proj(x) x = self.proj_drop(x) return x @@ -130,6 +133,8 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., @@ -146,6 +151,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_attn_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -159,6 +165,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + 
norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, ) @@ -179,6 +186,8 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = False, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., @@ -196,6 +205,7 @@ def __init__( num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_attn_norm=scale_attn_norm, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -208,6 +218,7 @@ def __init__( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, ) @@ -443,6 +454,8 @@ def __init__( mlp_ratio: float = 4., qkv_bias: bool = True, qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, proj_bias: bool = True, init_values: Optional[float] = None, class_token: bool = True, @@ -563,6 +576,8 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, + scale_attn_norm=scale_attn_norm, + scale_mlp_norm=scale_mlp_norm, proj_bias=proj_bias, init_values=init_values, proj_drop=proj_drop_rate, @@ -1166,6 +1181,127 @@ def _convert_aimv2( return out_dict +def _convert_beit3( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, +) -> Dict[str, torch.Tensor]: + """Convert BEiT3 weights to standard VisionTransformer format. + + First applies BEiT3's own filtering (from multimodal to vision-only BEiT3 format), + then converts from BEiT3 format to standard VisionTransformer format. + """ + import re + + # Step 1: Apply BEiT3's own checkpoint filtering logic + # (equivalent to beit3.checkpoint_filter_fn) + if 'model' in state_dict: + state_dict = state_dict['model'] + + # If already processed, skip BEiT3 filtering + if 'patch_embed.proj.weight' in state_dict: + intermediate_dict = state_dict + else: + # Remove text and mask tokens (vision-only) + state_dict.pop('beit3.text_embed.weight', None) + state_dict.pop('beit3.vision_embed.mask_token', None) + + intermediate_dict = {} + + for k, v in state_dict.items(): + # Skip B branch weights (use only A branch) + if '.B.' 
in k: + continue + elif 'vision_embed.cls_token' in k: + k = 'cls_token' + else: + # Apply BEiT3's key transformations + k = k.replace('beit3.', '') + k = k.replace('embed_positions.', 'pos_embed.') + k = k.replace('vision_embed.', 'patch_embed.') + k = k.replace('encoder.', '') + k = k.replace('layers.', 'blocks.') + k = k.replace('ffn.', 'mlp.') + k = k.replace('ffn_layernorm.', 'norm.') + k = k.replace('self_attn.', 'attn.') + k = k.replace('self_attn_layer_norm.', 'norm1.') + k = k.replace('final_layer_norm.', 'norm2.') + k = k.replace('A.', '') # Remove A branch prefix + + intermediate_dict[k] = v + + # Step 2: Convert from BEiT3 format to VisionTransformer format + out_dict = {} + + for k, v in intermediate_dict.items(): + # Handle attention projections - convert separate q,k,v to fused qkv + if re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.weight", k): + block_idx = re.search(r"blocks\.(\d+)", k).group(1) + proj_type = re.search(r"\.([qkv])_proj", k).group(1) + + # Collect all three projections for this block + q_key = f"blocks.{block_idx}.attn.q_proj.weight" + k_key = f"blocks.{block_idx}.attn.k_proj.weight" + v_key = f"blocks.{block_idx}.attn.v_proj.weight" + + if all(key in intermediate_dict for key in [q_key, k_key, v_key]): + # Only create qkv weight once when we encounter the first projection + if proj_type == 'q': + qkv_weight = torch.cat([ + intermediate_dict[q_key], + intermediate_dict[k_key], + intermediate_dict[v_key] + ], dim=0) + out_dict[f"blocks.{block_idx}.attn.qkv.weight"] = qkv_weight + # Skip k and v projections as they're handled with q + continue + else: + # Fallback if not all projections available + out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v + + # Handle attention projection biases + elif re.match(r"blocks\.(\d+)\.attn\.[qkv]_proj\.bias", k): + block_idx = re.search(r"blocks\.(\d+)", k).group(1) + proj_type = re.search(r"\.([qkv])_proj", k).group(1) + + q_key = f"blocks.{block_idx}.attn.q_proj.bias" + k_key = f"blocks.{block_idx}.attn.k_proj.bias" + v_key = f"blocks.{block_idx}.attn.v_proj.bias" + + if all(key in intermediate_dict for key in [q_key, k_key, v_key]): + if proj_type == 'q': + qkv_bias = torch.cat([ + intermediate_dict[q_key], + intermediate_dict[k_key], + intermediate_dict[v_key] + ], dim=0) + out_dict[f"blocks.{block_idx}.attn.qkv.bias"] = qkv_bias + continue + else: + out_dict[k.replace('q_proj', 'qkv').replace('k_proj', 'qkv').replace('v_proj', 'qkv')] = v + + # Map inner attention LayerNorm to scale norm + elif 'attn.inner_attn_ln' in k: + out_dict[k.replace('inner_attn_ln', 'norm')] = v + + # Map out_proj to proj + elif 'attn.out_proj' in k: + out_dict[k.replace('out_proj', 'proj')] = v + elif 'attn.proj' in k: + out_dict[k] = v + + # Handle positional embedding - skip first 2 positions (BEiT3 starts from index 2) + elif k == 'pos_embed.weight': + # BEiT3 pos_embed.weight has shape [num_patches + 3, embed_dim] + # We want [1, num_patches + 1, embed_dim] for standard ViT (cls token + patches) + out_dict['pos_embed'] = v[2:].unsqueeze(0) # Skip first 2 positions, add batch dim + + # Pass through other weights unchanged + else: + out_dict[k] = v + + return out_dict + + def checkpoint_filter_fn( state_dict: Dict[str, torch.Tensor], model: VisionTransformer, @@ -1186,6 +1322,9 @@ def checkpoint_filter_fn( state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.') elif "mask_token" in state_dict: state_dict = _convert_dinov2(state_dict, model) + elif any('beit3.' 
in k for k in state_dict.keys()): + # BEiT3 model - multimodal checkpoint with beit3.* prefix + state_dict = _convert_beit3(state_dict, model) elif "encoder" in state_dict: # IJEPA, vit in an 'encoder' submodule state_dict = state_dict['encoder'] @@ -2377,6 +2516,24 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 160, 160), crop_pct=0.95), 'test_vit4.r160_in1k': _cfg( input_size=(3, 160, 160), crop_pct=0.95), + + # BEiT3 models (remapped to VisionTransformer with scale_norm=True) + 'beit3_base_patch16_224.in22k_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_base_patch16_224.in22k_indomain_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_large_patch16_224.in22k_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_large_patch16_224.in22k_indomain_ft_in1k': _cfg( + url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_giant_patch14_224.untrained': _cfg( + url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), + 'beit3_giant_patch14_336.untrained': _cfg( + url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0), } _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]] @@ -4035,6 +4192,58 @@ def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer: return model +@register_model +def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Base model (ViT-Base size) with patch size 16x16. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg' + ) + model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Large model (ViT-Large size) with patch size 16x16. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg' + ) + model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Giant model with patch size 14x14. + Remapped to VisionTransformer with scale_norm=True. 
+ """ + model_args = dict( + patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg' + ) + model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ BEiT3 Giant model with patch size 14x14 and image size 336x336. + Remapped to VisionTransformer with scale_norm=True. + """ + model_args = dict( + img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637, + scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg' + ) + model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + register_model_deprecations(__name__, { 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
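
The _convert_beit3() remap above rests on fusing BEiT3's separate q/k/v projections into the single qkv Linear used by the standard Attention, by concatenating weights and biases along dim 0, and on taking pos_embed.weight[2:] (unsqueezed to a batch dim) as cls + patch positions. A minimal standalone sketch checking the fusion step; toy sizes only, not part of the patch:

import torch
import torch.nn as nn

# Toy width; the real models use dim = 768 / 1024 / 1408.
dim = 64
x = torch.randn(2, 5, dim)

q_proj, k_proj, v_proj = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# Fuse the three projections the same way _convert_beit3 concatenates the checkpoint tensors.
qkv = nn.Linear(dim, dim * 3)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    qkv.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

# Splitting the fused output back into thirds must reproduce the separate projections.
q, k, v = qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)

Any residual numeric differences after the full remap would come from details such as BEiT3 computing softmax in float32, not from the weight fusion itself.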