
Commit f9a24fa

Merge pull request #1846 from seefun/master
add I-JEPA pretrained weight for ViT
2 parents: 2d597b1 + c3f24a5

File tree

2 files changed: +60, -4 lines


timm/models/vision_transformer.py

Lines changed: 60 additions & 1 deletion
```diff
@@ -879,6 +879,17 @@ def _convert_dinov2(state_dict, model):
     return out_dict
 
 
+def _convert_ijepa(state_dict, model):
+    out_dict = {}
+    for k, v in state_dict['encoder'].items():
+        if k.startswith('module.'):
+            k = k[7:]
+        if k.startswith('norm.'):
+            k = 'fc_norm.' + k[5:]
+        out_dict[k] = v
+    return out_dict
+
+
 def checkpoint_filter_fn(
     state_dict,
     model,
```
```diff
@@ -896,7 +907,10 @@ def checkpoint_filter_fn(
         return _convert_openai_clip(state_dict, model)
 
     if "mask_token" in state_dict:
-        return _convert_dinov2(state_dict, model)
+        state_dict = _convert_dinov2(state_dict, model)
+
+    if "encoder" in state_dict:
+        state_dict = _convert_ijepa(state_dict, model)
 
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
```
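Note the control-flow change here: the DINOv2 branch no longer returns early, and both conversions assign back to `state_dict`, so every checkpoint format now falls through to the shared fix-ups further down (patch and position embedding resizing, etc.). A simplified, runnable sketch of the new dispatch order, with the real converters stubbed out:

```python
# Simplified sketch of the dispatch order after this change; the real
# checkpoint_filter_fn also resizes patch/position embeddings further down.
def _stub_convert_dinov2(sd):
    out = dict(sd)
    out.pop('mask_token', None)  # stand-in for the real DINOv2 remapping
    return out

def _stub_convert_ijepa(sd):
    return dict(sd['encoder'])   # stand-in for the real I-JEPA remapping

def filter_fn(state_dict):
    if "mask_token" in state_dict:
        state_dict = _stub_convert_dinov2(state_dict)  # falls through now
    if "encoder" in state_dict:
        state_dict = _stub_convert_ijepa(state_dict)
    # ...shared key fix-ups run here for every checkpoint format...
    return state_dict

print(filter_fn({'encoder': {'fc_norm.weight': None}}))
# {'fc_norm.weight': None}
```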
```diff
@@ -1437,6 +1451,27 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
+    'vit_huge_patch14_224_ijepa.in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_huge_patch14_224_ijepa.in22k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_huge_patch16_448_ijepa.in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_gigantic_patch16_224_ijepa.in22k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
     })
 
 
```
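The four new default_cfgs entries pull weights straight from the official FAIR download URLs; the `hf_hub_id` lines are left commented out, presumably until the weights are mirrored on the Hugging Face Hub, and `num_classes=0` flags them as head-less feature-extraction weights. A hedged usage sketch, assuming this commit of timm and network access to dl.fbaipublicfiles.com:

```python
import timm
import torch

# The '.in1k' tag selects the IN1K-vit.h.14-300e checkpoint configured above.
# Pass num_classes=0 explicitly: the pretrained weights carry no classifier head.
model = timm.create_model('vit_huge_patch14_224_ijepa.in1k', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 1280])
```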

```diff
@@ -2031,6 +2066,30 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
         'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+@register_model
+def vit_huge_patch14_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
+    """
+    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
+    model = _create_vision_transformer('vit_huge_patch14_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+@register_model
+def vit_huge_patch16_448_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
+    """
+    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
+    model = _create_vision_transformer('vit_huge_patch16_448_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+@register_model
+def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
+    """
+    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
+    model = _create_vision_transformer(
+        'vit_gigantic_patch16_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
 
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
```
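All three registrations are plain VisionTransformer configs. The two Huge variants drop the class token and average-pool the patch tokens (`class_token=False`, `global_pool='avg'`), so the pooled output is a 1280-dim embedding, and the H/16 variant is fixed at 448 px input. The gigantic model's odd-looking `mlp_ratio=64/13` is exact: 1664 * 64 / 13 = 8192, the usual big-G MLP hidden width. A quick shape-checking sketch at random init (nothing is downloaded):

```python
import timm
import torch

# mlp_ratio=64/13 on the gigantic variant is exact big-G arithmetic:
assert int(1664 * 64 / 13) == 8192

model = timm.create_model('vit_huge_patch14_224_ijepa', num_classes=0)
model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    tokens = model.forward_features(x)  # (224/14)^2 = 256 patch tokens, no class token
    pooled = model(x)                   # average-pooled over the tokens
print(tokens.shape)  # torch.Size([1, 256, 1280])
print(pooled.shape)  # torch.Size([1, 1280])
```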

timm/models/vision_transformer_sam.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -605,6 +605,3 @@ def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     model = _create_vision_transformer(
         'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
-
-# TODO:
-# support any input size, now only 1024 x 1024 (pretrained)
```
