@@ -428,6 +428,7 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
+            repr_size=False,
     ) -> None:
         """
         Args:
@@ -536,6 +537,14 @@ def __init__(
             )
         else:
             self.attn_pool = None
+        if repr_size:
+            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
+            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
+            embed_dim = repr_size
+            print(self.repr)
+        else:
+            self.repr = nn.Identity()
+
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -752,6 +761,7 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
+        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
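(A minimal sketch of what the new pre-logits path does, assuming a model built with repr_size=True; the standalone module and tensor shapes below are illustrative, not part of the diff.)

import torch
import torch.nn as nn

embed_dim = 384
repr_size = embed_dim  # repr_size=True resolves to embed_dim in __init__

# The added layer mirrors the original ViT "representation_size" pre-logits head:
# Linear + Tanh applied to the pooled features, before fc_norm / head_drop / head.
repr_layer = nn.Sequential(nn.Linear(embed_dim, repr_size), nn.Tanh())

pooled = torch.randn(2, embed_dim)   # pooled class-token or GAP features
pre_logits = repr_layer(pooled)      # what forward_head now feeds into fc_norm
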
@@ -1790,23 +1800,40 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),

-    'vit_wee_patch16_reg1_gap_256': _cfg(
+    'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_little_patch16_reg4_gap_256': _cfg(
-        #file='',
+    'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='./vit_pwee-in1k-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_little_patch16-in1k-8a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap1-in1k-20231118-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap4-in1k-20231115-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_mp_patch16_reg4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg1_gap_256': _cfg(
-        #file='vit_medium_gap1-in1k-20231118-8.pth',
+    'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap1-in1k-20231121-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg4_gap_256': _cfg(
-        #file='vit_medium_gap4-in1k-20231115-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg1_gap_256': _cfg(
-        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap4-in1k-20231106-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256': _cfg(
-        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_betwixt_gap4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
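(With the sbb configs now carrying pretrained tags, creation goes through the usual timm API; a hedged example below, using pretrained=False since the file= paths above point at local checkpoints.)

import timm

# The tag after the dot selects which pretrained cfg variant is used.
model = timm.create_model('vit_betwixt_patch16_reg4_gap_256.sbb_in1k', pretrained=False)

# The .sbb_in12k tags set num_classes=11821 for the ImageNet-12k label set;
# num_classes=0 would drop the classifier for feature extraction.
backbone = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb_in12k', pretrained=False, num_classes=0)
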
@@ -1906,6 +1933,14 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Small (ViT-S/16)
+    """
+    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
@@ -2755,10 +2790,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
27552790def vit_wee_patch16_reg1_gap_256 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
27562791 model_args = dict (
27572792 patch_size = 16 , embed_dim = 256 , depth = 14 , num_heads = 4 , init_values = 1e-5 , mlp_ratio = 5 ,
2793+ class_token = False , no_embed_class = True , reg_tokens = 1 , global_pool = 'avg' ,
2794+ )
2795+ model = _create_vision_transformer (
2796+ 'vit_wee_patch16_reg1_gap_256' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2797+ return model
2798+
2799+
2800+ @register_model
2801+ def vit_pwee_patch16_reg1_gap_256 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
2802+ model_args = dict (
2803+ patch_size = 16 , embed_dim = 256 , depth = 16 , num_heads = 4 , init_values = 1e-5 , mlp_ratio = 5 ,
27582804 class_token = False , no_embed_class = True , reg_tokens = 1 , global_pool = 'avg' , block_fn = ParallelScalingBlock ,
27592805 )
27602806 model = _create_vision_transformer (
2761- 'vit_medium_patch16_reg1_gap_256 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2807+ 'vit_pwee_patch16_reg1_gap_256 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
27622808 return model
27632809
27642810
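(The wee/pwee pair differ only in block type and depth: 'wee' keeps standard blocks at depth 14, while 'pwee' uses ParallelScalingBlock at depth 16 with the same width; a quick comparison, assuming the registrations above are in place.)

import timm

wee = timm.create_model('vit_wee_patch16_reg1_gap_256', pretrained=False)
pwee = timm.create_model('vit_pwee_patch16_reg1_gap_256', pretrained=False)
print(sum(p.numel() for p in wee.parameters()), sum(p.numel() for p in pwee.parameters()))
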
@@ -2769,7 +2815,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
         class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -2795,6 +2841,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(