@@ -438,6 +438,7 @@ def __init__(
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
+            final_norm: bool = True,
             fc_norm: Optional[bool] = None,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
@@ -471,7 +472,9 @@ def __init__(
             class_token: Use class token.
             no_embed_class: Don't include position embeddings for class (or reg) tokens.
             reg_tokens: Number of register tokens.
-            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
+            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
+            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
             attn_drop_rate: Attention dropout rate.
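Taken together, these three flags control where LayerNorm sits around the transformer stack. A rough constructor-level sketch of the three common configurations (illustrative arguments only, not tied to any particular pretrained weight):

    from timm.models.vision_transformer import VisionTransformer

    common = dict(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)

    # Classic ViT: no pre-norm, final norm applied after the last block, before pooling.
    vit_classic = VisionTransformer(**common)

    # CLIP-style ViT: extra norm after the patch/pos embeddings, final norm kept.
    vit_clip_like = VisionTransformer(**common, pre_norm=True)

    # InternViT-style ViT: transformer output fed to the head with no final norm at all.
    vit_no_final = VisionTransformer(**common, final_norm=False)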
@@ -554,7 +557,7 @@ def __init__(
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
-        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
 
         # Classifier Head
         if global_pool == 'map':
@@ -566,7 +569,7 @@ def __init__(
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
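For reference, the two changed assignments resolve like this; `use_fc_norm` is the internal flag derived earlier in `__init__` from `fc_norm` and `global_pool`. A minimal sketch mirroring the logic in this hunk:

    import torch.nn as nn

    def resolve_final_norms(norm_layer, embed_dim, final_norm, use_fc_norm):
        # At most one of the two norms is a real layer; final_norm=False disables both.
        norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
        fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
        return norm, fc_norm

    # final_norm=True,  use_fc_norm=False -> norm before pooling (classic ViT)
    # final_norm=True,  use_fc_norm=True  -> norm after pooling (default when global_pool == 'avg')
    # final_norm=False                    -> both Identity (e.g. vit_intern300m below)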
@@ -2051,6 +2054,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
         input_size=(3, 256, 256)),
 
+    'vit_intern300m_patch14_448.ogvl_dist': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
+    ),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
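The new entry registers the weights as a headless backbone: 448x448 input, crop_pct=1.0 (no center-crop shrink), num_classes=0. Assuming a recent timm where resolve_model_data_config is available, the matching eval preprocessing can be recovered roughly like this:

    import timm
    from timm.data import resolve_model_data_config, create_transform

    model = timm.create_model('vit_intern300m_patch14_448.ogvl_dist', pretrained=True)
    data_cfg = resolve_model_data_config(model)              # picks up the cfg fields above
    transform = create_transform(**data_cfg, is_training=False)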
@@ -2091,7 +2100,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
         _filter_fn = checkpoint_filter_fn
 
     # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
-    strict = True
+    strict = kwargs.pop('pretrained_strict', True)
     if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
         strict = False
 
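This turns the previously hard-coded strict flag into an escape hatch: a pretrained_strict kwarg passed through a vit entrypoint (and thus through timm.create_model) is popped here before the remaining kwargs reach the model constructor. A hedged usage sketch:

    import timm

    # Tolerate missing/unexpected keys in the checkpoint (e.g. a removed head or
    # pooling module) instead of failing in load_state_dict.
    model = timm.create_model(
        'vit_base_patch16_224',
        pretrained=True,
        pretrained_strict=False,
    )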
@@ -3298,6 +3307,17 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
     return model
 
 
+@register_model
+def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
+        init_values=0.1, final_norm=False, dynamic_img_size=True,
+    )
+    model = _create_vision_transformer(
+        'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test
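A quick usage sketch for the new registration, assuming the distilled weights are published under the hf_hub_id given in the pretrained cfg above:

    import timm
    import torch

    model = timm.create_model('vit_intern300m_patch14_448.ogvl_dist', pretrained=True).eval()

    with torch.inference_mode():
        x = torch.randn(1, 3, 448, 448)
        pooled = model(x)                    # head is Identity (num_classes=0): (1, 1024) features
        tokens = model.forward_features(x)   # unpooled class + patch tokens

        # dynamic_img_size=True allows other resolutions with on-the-fly pos-embed resize.
        pooled_336 = model(torch.randn(1, 3, 336, 336))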