@@ -38,7 +38,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+    resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -383,6 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
+    dynamic_size: Final[bool]
 
     def __init__(
             self,
@@ -400,6 +401,7 @@ def __init__(
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
             drop_rate: float = 0.,
@@ -452,14 +454,23 @@ def __init__(
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False
 
+        embed_args = {}
+        if dynamic_size:
+            embed_args.update(dict(
+                strict_img_size=False,
+                flatten=False,  # flatten deferred until after pos embed
+                output_fmt='NHWC',
+            ))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -546,18 +557,24 @@ def reset_classifier(self, num_classes: int, global_pool=None):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
-            x = x + self.pos_embed
+            x = x + pos_embed
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         else:
             # original timm, JAX, and deit vit impl
             # pos_embed has entry for class token, concat then add
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-            x = x + self.pos_embed
+            x = x + pos_embed
         return self.pos_drop(x)
 
     def _intermediate_layers(
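For reference, the resampling step in `_pos_embed` boils down to interpolating the grid portion of the learned absolute position embedding to the current patch grid; deferring the flatten via `flatten=False` / `output_fmt='NHWC'` in `embed_args` is what keeps the `(H, W)` shape available at that point. Below is a rough, self-contained sketch of the idea, not timm's actual `resample_abs_pos_embed` (which handles more options); the function name and defaults here are illustrative:

```python
import torch
import torch.nn.functional as F

def resample_pos_embed_sketch(pos_embed, new_size, old_size, num_prefix_tokens=1):
    """Interpolate a learned absolute position embedding to a new patch grid."""
    # Split off class/prefix token embeddings; only the spatial grid is resampled.
    prefix, grid = pos_embed[:, :num_prefix_tokens], pos_embed[:, num_prefix_tokens:]
    C = grid.shape[-1]
    # Un-flatten to (1, C, old_h, old_w), interpolate, then flatten back to (1, N, C).
    grid = grid.reshape(1, old_size[0], old_size[1], C).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_size, mode='bicubic', antialias=True)
    grid = grid.permute(0, 2, 3, 1).reshape(1, -1, C)
    return torch.cat([prefix, grid], dim=1)

# e.g. a 14x14-trained embedding resampled for a 21x21 grid:
old = torch.randn(1, 1 + 14 * 14, 384)
new = resample_pos_embed_sketch(old, (21, 21), (14, 14))  # -> (1, 1 + 21 * 21, 384)
```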
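Assuming the rest of the PR wires the NHWC output through the forward path, usage might look like the following; everything except `dynamic_size` is an illustrative hyperparameter choice, not taken from this diff:

```python
import torch
from timm.models.vision_transformer import VisionTransformer

# Hypothetical config; only the dynamic_size flag comes from this diff.
model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=384, depth=12, num_heads=6,
    dynamic_size=True,  # NHWC patch embed + on-the-fly pos embed resampling
)
model.eval()

with torch.no_grad():
    out_224 = model(torch.randn(1, 3, 224, 224))  # native 14x14 patch grid
    out_336 = model(torch.randn(1, 3, 336, 336))  # pos embed resampled to 21x21
```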