@@ -367,6 +367,7 @@ def __init__(
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
+            dynamic_size: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -406,13 +407,19 @@ def __init__(
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False

+        embed_args = {}
+        if dynamic_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches

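Aside on the non-strict patch embed above (not part of the diff itself): with `strict_img_size=False` the fixed `img_size` check is skipped, and `output_fmt='NHWC'` keeps the patch grid unflattened, so the spatial shape is still known when the position embeddings are resampled later. A minimal sketch of that behaviour, assuming `timm.layers.PatchEmbed` exposes these kwargs in this tree:

```python
import torch
from timm.layers import PatchEmbed

# Non-strict NHWC patch embed: flattening is deferred so the (H, W) grid
# survives until the position embeddings have been resampled.
patch_embed = PatchEmbed(
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    strict_img_size=False,  # accept inputs that differ from img_size
    output_fmt='NHWC',      # keep B x H x W x C instead of B x N x C
)

x = torch.randn(2, 3, 288, 288)  # not the configured 224x224
feats = patch_embed(x)
print(feats.shape)  # expected: torch.Size([2, 18, 18, 768])
```
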
@@ -435,7 +442,7 @@ def __init__(
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=self.patch_embed.grid_size,
+                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
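With `feat_shape=None` the rotary module no longer precomputes a sin/cos table for a fixed grid at construction time; the embedding is instead generated per forward pass from the actual grid via `get_embed(shape=(H, W))`, which is exactly what the new `_pos_embed` below does. A small illustration of that lazy path, using an assumed head dim of 64 and the 14x14 reference grid of a 224/16 model:

```python
import torch
from timm.layers import RotaryEmbeddingCat

head_dim = 64  # assumed embed_dim // num_heads, for illustration only

# feat_shape=None: no fixed-size table is cached at init; ref_feat_shape keeps
# the frequency bands anchored to the pretraining grid (14x14 for 224 / 16).
rope = RotaryEmbeddingCat(
    head_dim,
    in_pixels=False,
    feat_shape=None,
    ref_feat_shape=(14, 14),
)

# The embedding is built lazily for whatever grid the current batch produced.
emb_14 = rope.get_embed(shape=(14, 14))
emb_18 = rope.get_embed(shape=(18, 18))
print(emb_14.shape, emb_18.shape)  # the 18x18 grid covers 324 tokens vs 196
```
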
@@ -519,30 +526,44 @@ def reset_classifier(self, num_classes, global_pool=None):
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

-    def forward_features(self, x):
-        x = self.patch_embed(x)
+    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            if self.pos_embed is not None:
+                pos_embed = resample_abs_pos_embed(
+                    self.pos_embed,
+                    (H, W),
+                    num_prefix_tokens=self.num_prefix_tokens,
+                )
+            else:
+                pos_embed = None
+            x = x.view(B, -1, C)
+            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
+        else:
+            pos_embed = self.pos_embed
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-
-        # apply abs position embedding
-        if self.pos_embed is not None:
-            x = x + self.pos_embed
+        if pos_embed is not None:
+            x = x + pos_embed
         x = self.pos_drop(x)

         # obtain shared rotary position embedding and apply patch dropout
-        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
         if self.patch_drop is not None:
             x, keep_indices = self.patch_drop(x)
             if rot_pos_embed is not None and keep_indices is not None:
                 rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
+        return x, rot_pos_embed

+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
         for blk in self.blocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint(blk, x, rope=rot_pos_embed)
             else:
                 x = blk(x, rope=rot_pos_embed)
-
         x = self.norm(x)
         return x

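End-to-end, `forward_features` now defers all position-embedding work to `_pos_embed`: in dynamic mode it resamples the learned absolute embedding to the incoming grid with `resample_abs_pos_embed`, flattens NHWC to NLC, and requests a rotary embedding matching that grid. A sketch of the intended usage, under the assumptions that `timm.create_model` forwards the new `dynamic_size` kwarg to the Eva constructor and that the `eva02_tiny_patch14_224` entrypoint is available in this tree:

```python
import torch
import timm

# dynamic_size=True is assumed to reach the Eva constructor via **kwargs;
# the model name below is likewise an assumption for illustration.
model = timm.create_model(
    'eva02_tiny_patch14_224',
    pretrained=False,
    dynamic_size=True,
).eval()

with torch.no_grad():
    out_224 = model(torch.randn(1, 3, 224, 224))  # native 16x16 patch grid
    out_336 = model(torch.randn(1, 3, 336, 336))  # 24x24 grid, pos embeds resampled on the fly

print(out_224.shape, out_336.shape)  # both (1, num_classes)
```

Input sizes still need to be multiples of the patch size; the change relaxes the fixed-grid assumption, not the divisibility requirement.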