Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,17 @@ def _convert_dinov2(state_dict, model):
return out_dict


def _convert_ijepa(state_dict, model):
out_dict = {}
for k, v in state_dict['encoder'].items():
if k.startswith('module.'):
k = k[7:]
if k.startswith('norm.'):
k = 'fc_norm.' + k[5:]
out_dict[k] = v
return out_dict


def checkpoint_filter_fn(
state_dict,
model,
Expand All @@ -896,7 +907,10 @@ def checkpoint_filter_fn(
return _convert_openai_clip(state_dict, model)

if "mask_token" in state_dict:
return _convert_dinov2(state_dict, model)
state_dict = _convert_dinov2(state_dict, model)

if "encoder" in state_dict:
state_dict = _convert_ijepa(state_dict, model)

for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
Expand Down Expand Up @@ -1437,6 +1451,27 @@ def _cfg(url='', **kwargs):
hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

'vit_huge_patch14_224_ijepa.in1k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_huge_patch14_224_ijepa.in22k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_huge_patch16_448_ijepa.in1k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_gigantic_patch16_224_ijepa.in22k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
})


Expand Down Expand Up @@ -2031,6 +2066,30 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vit_huge_patch14_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
"""
model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
model = _create_vision_transformer('vit_huge_patch14_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vit_huge_patch16_448_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
"""
model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
model = _create_vision_transformer('vit_huge_patch16_448_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
return model

@register_model
def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
"""
model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
model = _create_vision_transformer(
'vit_gigantic_patch16_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
return model

register_model_deprecations(__name__, {
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
Expand Down
3 changes: 0 additions & 3 deletions timm/models/vision_transformer_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,3 @@ def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
model = _create_vision_transformer(
'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model

# TODO:
# support any input size, now only 1024 x 1024 (pretrained)