
Commit 1f4512f

Support dynamic_resize in eva.py models
1 parent ea3519a commit 1f4512f

File tree: 2 files changed (+36 −14 lines)


timm/layers/pos_embed_sincos.py

Lines changed: 6 additions & 5 deletions
@@ -400,13 +400,12 @@ def __init__(
                     temperature=temperature,
                     step=1,
                 )
-            print(bands)
             self.register_buffer(
                 'bands',
                 bands,
                 persistent=False,
             )
-            self.embed = None
+            self.pos_embed = None
         else:
             # cache full sin/cos embeddings if shape provided up front
             embeds = build_rotary_pos_embed(
@@ -425,17 +424,19 @@ def __init__(
         )
 
     def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None:
+        if self.bands is not None and shape is not None:
             # rebuild embeddings every call, use if target shape changes
-            _assert(shape is not None, 'valid shape needed')
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
                 in_pixels=self.in_pixels,
+                ref_feat_shape=self.ref_feat_shape,
            )
             return torch.cat(embeds, -1)
-        else:
+        elif self.pos_embed is not None:
             return self.pos_embed
+        else:
+            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
 
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
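Not part of the commit, just an illustrative sketch of the two get_embed() paths changed above: with feat_shape=None only the frequency bands are cached and a target shape must be passed at call time, while a fixed feat_shape precomputes the embedding once. Constructor arguments mirror the RotaryEmbeddingCat(...) call in eva.py below; head_dim and the grid sizes are made-up values.

# Illustrative sketch (assumed values, not from this commit)
import torch
from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

head_dim = 64

# Lazy path: only bands cached, embedding rebuilt per call for the requested grid
rope_lazy = RotaryEmbeddingCat(head_dim, in_pixels=False, feat_shape=None, ref_feat_shape=(16, 16))
emb_14 = rope_lazy.get_embed(shape=(14, 14))
emb_18 = rope_lazy.get_embed(shape=(18, 18))

# Cached path: embedding precomputed for a fixed feature shape
rope_fixed = RotaryEmbeddingCat(head_dim, in_pixels=False, feat_shape=(14, 14), ref_feat_shape=(16, 16))
emb_fixed = rope_fixed.get_embed()

print(emb_14.shape, emb_18.shape, emb_fixed.shape)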

timm/models/eva.py

Lines changed: 30 additions & 9 deletions
@@ -367,6 +367,7 @@ def __init__(
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
+            dynamic_size: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -406,13 +407,19 @@ def __init__(
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False
 
+        embed_args = {}
+        if dynamic_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches
 
@@ -435,7 +442,7 @@ def __init__(
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=self.patch_embed.grid_size,
+                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -519,30 +526,44 @@ def reset_classifier(self, num_classes, global_pool=None):
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward_features(self, x):
-        x = self.patch_embed(x)
+    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            if self.pos_embed is not None:
+                pos_embed = resample_abs_pos_embed(
+                    self.pos_embed,
+                    (H, W),
+                    num_prefix_tokens=self.num_prefix_tokens,
+                )
+            else:
+                pos_embed = None
+            x = x.view(B, -1, C)
+            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
+        else:
+            pos_embed = self.pos_embed
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
 
         if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-
-        # apply abs position embedding
-        if self.pos_embed is not None:
-            x = x + self.pos_embed
+        if pos_embed is not None:
+            x = x + pos_embed
         x = self.pos_drop(x)
 
         # obtain shared rotary position embedding and apply patch dropout
-        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
         if self.patch_drop is not None:
             x, keep_indices = self.patch_drop(x)
             if rot_pos_embed is not None and keep_indices is not None:
                 rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
+        return x, rot_pos_embed
 
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
         for blk in self.blocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint(blk, x, rope=rot_pos_embed)
             else:
                 x = blk(x, rope=rot_pos_embed)
-
         x = self.norm(x)
         return x
 
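A minimal usage sketch (not part of this commit) of what the new dynamic_size flag enables, assuming the Eva constructor defaults shown in this diff; the img_size/patch_size values and the 288x288 input are illustrative.

# Minimal sketch (assumed sizes; constructor names taken from this diff)
import torch
from timm.models.eva import Eva

# With dynamic_size=True the patch embed is built with strict_img_size=False and NHWC
# output, so _pos_embed() can resample the abs pos embed and rebuild the rope embedding
# for whatever grid the input actually produces.
model = Eva(
    img_size=224,
    patch_size=16,
    dynamic_size=True,
    use_rot_pos_emb=True,
    ref_feat_shape=(14, 14),
)
model.eval()

with torch.no_grad():
    out_224 = model(torch.randn(1, 3, 224, 224))  # native 14x14 grid
    out_288 = model(torch.randn(1, 3, 288, 288))  # 18x18 grid, pos embeds resampled on the fly

print(out_224.shape, out_288.shape)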

0 commit comments
