@@ -94,7 +94,7 @@ def forward(self, x):
 
 
 class PatchMerging(nn.Module):
-    def __init__(self, input_resolution, dim, out_dim, activation):
+    def __init__(self, input_resolution, dim, out_dim, activation, in_fmt='BCHW'):
         super().__init__()
         self.input_resolution = input_resolution
         self.dim = dim
@@ -104,18 +104,21 @@ def __init__(self, input_resolution, dim, out_dim, activation):
         self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim)
         self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0)
         self.output_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
+        self.in_fmt = in_fmt
+        assert self.in_fmt in ['BCHW', 'BLC']
 
     def forward(self, x):
-        if x.ndim == 3:
+        if self.in_fmt == 'BLC':
+            # (B, H * W, C) -> (B, C, H, W)
             H, W = self.input_resolution
-            B = len(x)
-            # (B, C, H, W)
+            B = x.shape[0]
             x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
         x = self.conv1(x)
         x = self.act(x)
         x = self.conv2(x)
         x = self.act(x)
         x = self.conv3(x)
+        # (B, C, H, W) -> (B, H * W, C)
         x = x.flatten(2).transpose(1, 2)
         return x
 
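The explicit in_fmt flag replaces the old x.ndim == 3 check, so the caller now declares whether tokens or feature maps are coming in rather than the layer guessing from rank. A minimal standalone sketch of the layout round-trip the merge layer performs (the concrete shapes are illustrative assumptions, not values from the diff):

import torch

B, H, W, C = 2, 14, 14, 192
x_blc = torch.randn(B, H * W, C)                     # 'BLC': (B, L, C) token sequence
x_bchw = x_blc.view(B, H, W, C).permute(0, 3, 1, 2)  # -> 'BCHW' (B, C, H, W) for the strided convs
# ... conv1 / conv2 / conv3 would run here ...
x_out = x_bchw.flatten(2).transpose(1, 2)            # (B, C, H, W) -> (B, H * W, C)
assert x_out.shape == (B, H * W, C)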
@@ -369,6 +372,7 @@ class TinyVitStage(nn.Module):
         local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
         activation: the activation function. Default: nn.GELU
         out_dim: the output dimension of the layer. Default: dim
+        in_fmt: input format ('BCHW' or 'BLC'). Default: 'BCHW'
     """
 
     def __init__(
@@ -385,6 +389,7 @@ def __init__(
             local_conv_size=3,
             activation=nn.GELU,
             out_dim=None,
+            in_fmt='BCHW',
     ):
 
         super().__init__()
@@ -396,7 +401,7 @@ def __init__(
         # patch merging layer
         if downsample is not None:
             self.downsample = downsample(
-                input_resolution, dim=input_dim, out_dim=self.out_dim, activation=activation)
+                input_resolution, dim=input_dim, out_dim=self.out_dim, activation=activation, in_fmt=in_fmt)
             input_resolution = self.downsample.output_resolution
         else:
             self.downsample = nn.Identity()
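For reference, a hypothetical direct construction of the merge layer with the new flag; the dimensions are made up, and this assumes PatchMerging and ConvNorm from this module are in scope:

import torch
from torch import nn

merge = PatchMerging(input_resolution=(28, 28), dim=192, out_dim=256,
                     activation=nn.GELU, in_fmt='BLC')
tokens = torch.randn(2, 28 * 28, 192)  # (B, L, C) from a preceding transformer stage
out = merge(tokens)                    # (B, 14 * 14, 256): conv2's stride 2 halves H and W

With in_fmt='BCHW' the same module would instead accept a (2, 192, 28, 28) feature map and skip the initial view/permute.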
@@ -483,6 +488,10 @@ def __init__(
             else:
                 out_dim = embed_dims[stage_idx]
             drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])]
+            if stage_idx == 1:
+                in_fmt = 'BCHW'
+            else:
+                in_fmt = 'BLC'
             stage = TinyVitStage(
                 num_heads=num_heads[stage_idx],
                 window_size=window_sizes[stage_idx],
@@ -496,6 +505,7 @@ def __init__(
                 downsample=PatchMerging,
                 out_dim=out_dim,
                 activation=activation,
+                in_fmt=in_fmt,
             )
             input_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
             stride *= 2
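The per-stage selection appears to encode what each stage actually receives: stage 1 consumes the (B, C, H, W) feature map produced by the convolutional stem/stage 0, while every later stage consumes the (B, L, C) token sequence emitted by the preceding transformer stage (PatchMerging itself always returns tokens via the final flatten/transpose). The four added lines are equivalent to the more compact form:

# equivalent one-liner for the added branch
in_fmt = 'BCHW' if stage_idx == 1 else 'BLC'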