@@ -18,6 +18,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
@@ -32,6 +33,7 @@ class HybridEmbed(nn.Module):
     Extract feature map from CNN, flatten, project to embedding dim.
     """
     output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
 
     def __init__(
             self,
@@ -45,6 +47,7 @@ def __init__(
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -71,7 +74,8 @@ def __init__(
             feature_dim = self.backbone.feature_info.channels()[-1]
         else:
             feature_dim = self.backbone.num_features
-        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+        if not dynamic_img_pad:
+            assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
         self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         if output_fmt is not None:
@@ -82,13 +86,19 @@ def __init__(
             self.flatten = flatten
             self.output_fmt = Format.NCHW
         self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
 
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
 
     def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
+        _, _, H, W = x.shape
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
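
Below the patch, a minimal standalone sketch of the padding arithmetic that dynamic_img_pad enables. The patch size and feature-map shape are made-up example values; only the two pad_* lines mirror the patched forward():

import torch
import torch.nn.functional as F

patch_size = (16, 16)
x = torch.randn(1, 3, 50, 75)  # hypothetical backbone feature map: N, C, H, W

_, _, H, W = x.shape
# Bottom/right padding that rounds H and W up to multiples of the patch size;
# the trailing % keeps the pad at 0 when the dimension is already divisible.
pad_h = (patch_size[0] - H % patch_size[0]) % patch_size[0]  # 14 -> 50 + 14 = 64
pad_w = (patch_size[1] - W % patch_size[1]) % patch_size[1]  # 5  -> 75 + 5  = 80
x = F.pad(x, (0, pad_w, 0, pad_h))
print(x.shape)  # torch.Size([1, 3, 64, 80])

Because F.pad pads the last dimension first, (0, pad_w, 0, pad_h) touches only the right and bottom edges, so the patchifying conv sees the original content at unchanged positions and simply gains extra zero-valued patches at the borders.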