@@ -428,7 +428,6 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
-            repr_size=False,
     ) -> None:
         """
         Args:
@@ -537,14 +536,6 @@ def __init__(
             )
         else:
             self.attn_pool = None
-        if repr_size:
-            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
-            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
-            embed_dim = repr_size
-            print(self.repr)
-        else:
-            self.repr = nn.Identity()
-
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -761,7 +752,6 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
-        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
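For anyone relying on the removed `repr_size` option: it inserted a Linear + Tanh "pre-logits" projection on the pooled token just before `fc_norm`, similar to the representation layer in the original ViT. A minimal standalone sketch of the same behavior (the `PreLogitsHead` name is hypothetical, not part of timm):

```python
import torch
import torch.nn as nn


class PreLogitsHead(nn.Module):
    """Hypothetical stand-in for the removed `repr_size` branch: Linear + Tanh
    applied to the pooled token before the final norm / classifier."""

    def __init__(self, embed_dim: int, repr_size: int):
        super().__init__()
        self.repr = nn.Sequential(nn.Linear(embed_dim, repr_size), nn.Tanh())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.repr(x)


pooled = torch.randn(2, 384)                        # pooled class/avg token for a batch of 2
head = PreLogitsHead(embed_dim=384, repr_size=384)
print(head(pooled).shape)                           # torch.Size([2, 384])
```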
@@ -1804,35 +1794,45 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='./vit_pwee-in1k-8.pth',
+        #file='./vit_pwee-in1k-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_little_patch16-in1k-8a.pth',
+        #file='vit_little_patch16-in1k-8a.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='vit_medium_gap1-in1k-20231118-8.pth',
+        #file='vit_medium_gap1-in1k-20231118-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_medium_gap4-in1k-20231115-8.pth',
+        #file='vit_medium_gap4-in1k-20231115-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
-        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        #file='vit_mp_patch16_reg4-in1k-5a.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
-        file='vit_mp_patch16_reg4-in12k-8.pth',
+        #file='vit_mp_patch16_reg4-in12k-8.pth',
+        hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
-        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        #file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
-        file='vit_betwixt_gap4-in12k-8.pth',
+        #file='vit_betwixt_gap4-in12k-8.pth',
+        hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
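With the local `file=` entries commented out and `hf_hub_id='timm/'` added, these pretrained cfgs now resolve their weights from the Hugging Face Hub. A rough usage sketch, assuming the weights have been published on the Hub and `huggingface_hub` is installed:

```python
import timm
import torch

# Weights are fetched from the Hugging Face Hub via the cfg's hf_hub_id
# rather than from a local checkpoint file.
model = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k', pretrained=True)
model.eval()

x = torch.randn(1, 3, 256, 256)   # cfg input_size is (3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)               # torch.Size([1, 1000]) for the in1k fine-tuned head
```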
@@ -1933,14 +1933,6 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


-@register_model
-def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
-    """ ViT-Small (ViT-S/16)
-    """
-    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
-    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)