
Commit fc5d705

dynamic_size -> dynamic_img_size, add dynamic_img_pad for padding option
1 parent 1f4512f commit fc5d705

5 files changed: +34 -14 lines changed


timm/layers/patch_embed.py

Lines changed: 8 additions & 2 deletions
@@ -26,6 +26,7 @@ class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
     """
     output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
 
     def __init__(
             self,
@@ -38,6 +39,7 @@ def __init__(
             output_fmt: Optional[str] = None,
             bias: bool = True,
             strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
     ):
         super().__init__()
         self.patch_size = to_2tuple(patch_size)
@@ -58,6 +60,7 @@ def __init__(
             self.flatten = flatten
             self.output_fmt = Format.NCHW
         self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
 
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -68,7 +71,7 @@ def forward(self, x):
             if self.strict_img_size:
                 _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
                 _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
-            else:
+            elif not self.dynamic_img_pad:
                 _assert(
                     H % self.patch_size[0] == 0,
                     f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
@@ -77,7 +80,10 @@ def forward(self, x):
                     W % self.patch_size[1] == 0,
                     f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
                 )
-
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
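For intuition, the sketch below (a standalone illustration, not part of the diff) applies the same right/bottom padding arithmetic that the new `dynamic_img_pad` branch runs before `self.proj`, rounding each spatial dim up to the next multiple of the patch size:

    import torch
    import torch.nn.functional as F

    patch_size = (16, 16)
    x = torch.randn(1, 3, 238, 332)  # H and W are not multiples of 16

    _, _, H, W = x.shape
    pad_h = (patch_size[0] - H % patch_size[0]) % patch_size[0]  # 2
    pad_w = (patch_size[1] - W % patch_size[1]) % patch_size[1]  # 4
    x = F.pad(x, (0, pad_w, 0, pad_h))  # pad right and bottom only

    print(x.shape)  # torch.Size([1, 3, 240, 336]) -> a 15 x 21 patch grid

The outer `% patch_size` keeps the pad at zero when the input is already divisible, so the option is a no-op for standard image sizes.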

timm/models/deit.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
     def _pos_embed(self, x):
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,

timm/models/eva.py

Lines changed: 7 additions & 5 deletions
@@ -367,7 +367,8 @@ def __init__(
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
-            dynamic_size: bool = False,
+            dynamic_img_size: bool = False,
+            dynamic_img_pad: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -407,18 +408,19 @@ def __init__(
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
-        self.dynamic_size = dynamic_size
+        self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
 
         embed_args = {}
-        if dynamic_size:
+        if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -442,7 +444,7 @@ def __init__(
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
+                feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -527,7 +529,7 @@ def reset_classifier(self, num_classes, global_pool=None):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             if self.pos_embed is not None:
                 pos_embed = resample_abs_pos_embed(
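The `_pos_embed` change above is only a rename, but it is the piece that makes `dynamic_img_size` useful: when enabled, the position embedding is re-interpolated to the runtime patch grid instead of staying at the trained resolution. A rough sketch of that call follows; the helper and its keyword names are from `timm.layers` as I understand them, so treat the exact signature as an assumption rather than the commit's code:

    import torch
    from timm.layers import resample_abs_pos_embed

    # Position embedding trained for a 14x14 grid (224px / patch 16) plus a class token.
    pos_embed = torch.randn(1, 1 + 14 * 14, 768)

    # At runtime the (possibly padded) input produced a 15x21 grid of patches.
    H, W = 15, 21
    pos_embed = resample_abs_pos_embed(
        pos_embed,
        new_size=(H, W),         # interpolate the grid portion to H x W
        num_prefix_tokens=1,     # leave the class-token embedding untouched
    )
    print(pos_embed.shape)       # torch.Size([1, 316, 768]) = 1 + 15 * 21 tokens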

timm/models/vision_transformer.py

Lines changed: 7 additions & 5 deletions
@@ -383,7 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
-    dynamic_size: Final[bool]
+    dynamic_img_size: Final[bool]
 
     def __init__(
             self,
@@ -401,9 +401,10 @@ def __init__(
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
-            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
+            dynamic_img_size: bool = False,
+            dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
             pos_drop_rate: float = 0.,
             patch_drop_rate: float = 0.,
@@ -454,11 +455,11 @@ def __init__(
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
-        self.dynamic_size = dynamic_size
+        self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
 
         embed_args = {}
-        if dynamic_size:
+        if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = embed_layer(
@@ -467,6 +468,7 @@ def __init__(
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -554,7 +556,7 @@ def reset_classifier(self, num_classes: int, global_pool=None):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x):
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,
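The `# flatten deferred until after pos embed` comment is the key to the NHWC detour here: with `dynamic_img_size` the patch embed keeps its 2D grid so `_pos_embed` can read `(H, W)`, resample the position embedding to that grid, and only then collapse to a token sequence. A minimal illustration of that flow (shapes are made up; this is not code from the diff):

    import torch

    x = torch.randn(2, 768, 13, 21)    # NCHW output of the patch projection conv
    x = x.permute(0, 2, 3, 1)          # NCHW -> NHWC, i.e. output_fmt='NHWC'
    B, H, W, C = x.shape               # grid shape available for pos-embed resampling
    # ... pos embed resampled to (H, W) and added here ...
    tokens = x.reshape(B, H * W, C)    # NHWC -> NLC token sequence
    print(tokens.shape)                # torch.Size([2, 273, 768])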

timm/models/vision_transformer_hybrid.py

Lines changed: 11 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
@@ -32,6 +33,7 @@ class HybridEmbed(nn.Module):
     Extract feature map from CNN, flatten, project to embedding dim.
     """
     output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
 
     def __init__(
             self,
@@ -45,6 +47,7 @@ def __init__(
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -71,7 +74,8 @@ def __init__(
                 feature_dim = self.backbone.feature_info.channels()[-1]
             else:
                 feature_dim = self.backbone.num_features
-        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+        if not dynamic_img_pad:
+            assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
         self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         if output_fmt is not None:
@@ -82,13 +86,19 @@ def __init__(
             self.flatten = flatten
             self.output_fmt = Format.NCHW
         self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
 
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
 
     def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
+        _, _, H, W = x.shape
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
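Putting the pieces together, usage after this commit would look roughly like the following; the model name and the fact that `create_model` forwards these kwargs to the constructor are assumptions on my part, not something the diff shows:

    import timm
    import torch

    model = timm.create_model(
        'vit_base_patch16_224',   # hypothetical choice; any ViT built on PatchEmbed should behave the same
        pretrained=False,
        dynamic_img_size=True,    # resample the pos embed to the input's patch grid at runtime
        dynamic_img_pad=True,     # pad H/W up to a patch-size multiple instead of asserting divisibility
    )

    x = torch.randn(1, 3, 250, 250)   # 250 is not a multiple of 16; padded to 256x256 internally
    out = model(x)
    print(out.shape)                  # torch.Size([1, 1000])

Without `dynamic_img_pad` the same call would trip the divisibility `_assert` in `PatchEmbed.forward`; without `dynamic_img_size` the patch embed keeps `strict_img_size=True` and a fixed pos-embed grid, so the two flags are typically enabled together for variable-size inputs.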
