
Commit fdd8c7c

Initial impl of dynamic resize for existing vit models (incl vit-resnet hybrids)
1 parent 38c474e commit fdd8c7c


3 files changed (+39, -7 lines)

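The change enables one set of pretrained ViT weights to run at multiple input resolutions: with the new flag, the patch embed keeps its NHWC grid until the position embedding has been resampled to match the input. A minimal usage sketch, assuming dynamic_size is forwarded through timm.create_model like the other VisionTransformer kwargs (that forwarding is an assumption, not part of this diff):

import torch
import timm

# Assumption: create_model passes dynamic_size through to VisionTransformer.__init__
model = timm.create_model('vit_base_patch16_224', pretrained=True, dynamic_size=True)
model.eval()

# With strict_img_size=False on the patch embed, inputs need only be divisible
# by the patch size; the abs pos embed is resampled on the fly in _pos_embed().
for size in (224, 288, 384):
    with torch.no_grad():
        out = model(torch.randn(1, 3, size, size))
    print(size, tuple(out.shape))  # classifier output stays (1, 1000)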

timm/layers/pos_embed.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def resample_abs_pos_embed(
     if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb
 
-    if not old_size:
+    if old_size is None:
         hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
         old_size = hw, hw
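The switch from "if not old_size:" to "if old_size is None:" makes the fallback explicit: only a genuinely omitted old_size triggers the square-grid inference from the token count, not any falsy value. A small sketch of that inference path (shapes assumed, relying on the function's default of one prefix token):

import torch
from timm.layers import resample_abs_pos_embed

# ViT-B/16 at 224px: a 14x14 patch grid plus one class token (assumed shapes).
pos_embed = torch.randn(1, 14 * 14 + 1, 768)

# old_size omitted -> inferred as sqrt(197 - 1) = 14 in the old_size-is-None branch.
resized = resample_abs_pos_embed(pos_embed, new_size=(20, 20))
print(tuple(resized.shape))  # (1, 20 * 20 + 1, 768) == (1, 401, 768)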

timm/models/vision_transformer.py

Lines changed: 20 additions & 3 deletions
@@ -38,7 +38,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+    resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -383,6 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
+    dynamic_size: Final[bool]
 
     def __init__(
             self,
@@ -400,6 +401,7 @@ def __init__(
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
             drop_rate: float = 0.,
@@ -452,14 +454,23 @@ def __init__(
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False
 
+        embed_args = {}
+        if dynamic_size:
+            embed_args.update(dict(
+                strict_img_size=False,
+                flatten=False,  # flatten deferred until after pos embed
+                output_fmt='NHWC',
+            ))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches
 
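With dynamic_size=True, the embed layer is asked to skip its image-size assert and return an unflattened channels-last grid, so _pos_embed below can still read H and W off the tensor. A sketch of what those embed_args produce, assuming timm's PatchEmbed semantics for strict_img_size / flatten / output_fmt:

import torch
from timm.layers import PatchEmbed

# Same kwargs that embed_args injects when dynamic_size=True:
embed = PatchEmbed(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    strict_img_size=False,  # accept input sizes other than img_size
    flatten=False,          # keep the spatial grid
    output_fmt='NHWC',      # channels last, as the dynamic _pos_embed expects
)
x = embed(torch.randn(1, 3, 320, 320))
print(tuple(x.shape))  # (1, 20, 20, 768) rather than the usual flattened NLC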

@@ -546,18 +557,24 @@ def reset_classifier(self, num_classes: int, global_pool=None):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
-            x = x + self.pos_embed
+            x = x + pos_embed
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         else:
             # original timm, JAX, and deit vit impl
             # pos_embed has entry for class token, concat then add
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-            x = x + self.pos_embed
+            x = x + pos_embed
         return self.pos_drop(x)
 
     def _intermediate_layers(
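The dynamic branch above is the crux of the commit: because tokens arrive as an NHWC grid, the H and W needed to resample the pretrained table are read straight off the tensor before the grid is flattened to NLC. The same steps in isolation (shapes assumed):

import torch
from timm.layers import resample_abs_pos_embed

B, H, W, C = 2, 20, 20, 768                  # e.g. a 320px input with 16px patches
x = torch.randn(B, H, W, C)                  # NHWC tokens from patch_embed
pretrained = torch.randn(1, 14 * 14 + 1, C)  # 14x14 grid + class token at 224px

pos_embed = resample_abs_pos_embed(pretrained, (H, W))  # -> (1, 20 * 20 + 1, C)
x = x.view(B, -1, C)                         # flatten only after resampling
# ...then concat the class token and add pos_embed exactly as in the static path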

timm/models/vision_transformer_hybrid.py

Lines changed: 18 additions & 3 deletions
@@ -14,13 +14,13 @@
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
+from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -40,6 +40,9 @@ def __init__(
             in_chans=3,
             embed_dim=768,
             bias=True,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            strict_img_size: bool = True,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -69,14 +72,26 @@ def __init__(
             assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
             self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            # flatten spatial dim and transpose to channels last, kept for bwd compat
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size
+
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
 
     def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
         x = self.proj(x)
-        x = x.flatten(2).transpose(1, 2)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
         return x
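HybridEmbed's forward now mirrors PatchEmbed: passing an output_fmt overrides the legacy flatten-and-transpose so the CNN feature grid can leave in channels-last form. A sketch of the two output paths on a dummy projected feature map (illustrative shapes only):

import torch
from timm.layers import Format, nchw_to

x = torch.randn(2, 768, 20, 20)  # projected hybrid features, NCHW

# Legacy path (flatten=True): NCHW -> NLC for the transformer blocks.
nlc = x.flatten(2).transpose(1, 2)
print(tuple(nlc.shape))   # (2, 400, 768)

# New path (output_fmt='NHWC'): keep the grid, just move channels last,
# which is what the dynamic-size VisionTransformer expects from its embed layer.
nhwc = nchw_to(x, Format.NHWC)
print(tuple(nhwc.shape))  # (2, 20, 20, 768)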
