
import torch
import torch.nn as nn
-import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
-    create_conv2d, make_divisible
+    create_conv2d, get_act_layer, make_divisible, to_ntuple
from .registry import register_model


@@ -40,14 +39,13 @@ def _cfg(url='', **kwargs):


default_cfgs = dict(
-    convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
-    convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
-    convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
-    convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
-
    # timm specific variants
-    convnext_atto=_cfg(url=''),
-    convnext_atto_ols=_cfg(url=''),
+    convnext_atto=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_atto_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    convnext_femto=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
@@ -70,16 +68,34 @@ def _cfg(url='', **kwargs):
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

+    convnext_tiny=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_small=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_base=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_large=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
    convnext_tiny_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    convnext_small_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    convnext_base_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    convnext_large_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    convnext_xlarge_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    convnext_tiny_384_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
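Side note on the config fields added above: `test_input_size` and `test_crop_pct` let evaluation scripts use a larger resolution and crop fraction than the 224px training default. A minimal sketch of inspecting them through the public timm API (assuming a timm build that includes this change):

    import timm

    model = timm.create_model('convnext_atto', pretrained=False)
    cfg = model.default_cfg
    print(cfg['input_size'])           # train-time resolution, (3, 224, 224)
    print(cfg.get('test_input_size'))  # eval resolution added here, (3, 288, 288)
    print(cfg.get('test_crop_pct'))    # eval crop fraction, 0.95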
@@ -121,37 +137,39 @@ class ConvNeXtBlock(nn.Module):
    is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.

    Args:
-        dim (int): Number of input channels.
+        in_chs (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(
            self,
-            dim,
-            dim_out=None,
+            in_chs,
+            out_chs=None,
+            kernel_size=7,
            stride=1,
            dilation=1,
            mlp_ratio=4,
            conv_mlp=False,
            conv_bias=True,
            ls_init_value=1e-6,
+            act_layer='gelu',
            norm_layer=None,
-            act_layer=nn.GELU,
            drop_path=0.,
    ):
        super().__init__()
-        dim_out = dim_out or dim
+        out_chs = out_chs or in_chs
+        act_layer = get_act_layer(act_layer)
        if not norm_layer:
            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
        mlp_layer = ConvMlp if conv_mlp else Mlp
        self.use_conv_mlp = conv_mlp

        self.conv_dw = create_conv2d(
-            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
-        self.norm = norm_layer(dim_out)
-        self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
+            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+        self.norm = norm_layer(out_chs)
+        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
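As a concrete usage sketch of the renamed block interface (the import path is an assumption; in-tree the class lives in `timm.models.convnext`):

    import torch
    from timm.models.convnext import ConvNeXtBlock

    x = torch.randn(2, 64, 56, 56)  # NCHW feature map

    # Default path: channels-last LayerNorm + Linear MLP.
    blk = ConvNeXtBlock(in_chs=64, kernel_size=7, conv_mlp=False)
    # Alternate path: channels-first LayerNorm2d + 1x1-conv MLP,
    # per the docstring sometimes faster with AMP on GPU.
    blk_cf = ConvNeXtBlock(in_chs=64, kernel_size=7, conv_mlp=True)

    # Both variants preserve shape at stride 1 with out_chs unset.
    assert blk(x).shape == blk_cf(x).shape == x.shape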
@@ -178,13 +196,15 @@ def __init__(
            self,
            in_chs,
            out_chs,
+            kernel_size=7,
            stride=2,
            depth=2,
            dilation=(1, 1),
            drop_path_rates=None,
            ls_init_value=1.0,
            conv_mlp=False,
            conv_bias=True,
+            act_layer='gelu',
            norm_layer=None,
            norm_layer_cl=None
    ):
@@ -208,13 +228,15 @@ def __init__(
        stage_blocks = []
        for i in range(depth):
            stage_blocks.append(ConvNeXtBlock(
-                dim=in_chs,
-                dim_out=out_chs,
+                in_chs=in_chs,
+                out_chs=out_chs,
+                kernel_size=kernel_size,
                dilation=dilation[1],
                drop_path=drop_path_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
+                act_layer=act_layer,
                norm_layer=norm_layer if conv_mlp else norm_layer_cl
            ))
            in_chs = out_chs
@@ -252,19 +274,22 @@ def __init__(
            output_stride=32,
            depths=(3, 3, 9, 3),
            dims=(96, 192, 384, 768),
+            kernel_sizes=7,
            ls_init_value=1e-6,
            stem_type='patch',
            patch_size=4,
            head_init_scale=1.,
            head_norm_first=False,
            conv_mlp=False,
            conv_bias=True,
+            act_layer='gelu',
            norm_layer=None,
            drop_rate=0.,
            drop_path_rate=0.,
    ):
        super().__init__()
        assert output_stride in (8, 16, 32)
+        kernel_sizes = to_ntuple(4)(kernel_sizes)
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
        norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
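The `to_ntuple(4)` call added above broadcasts a scalar kernel size to one value per stage while passing 4-tuples through unchanged, so both call styles are accepted:

    from timm.models.layers import to_ntuple

    to_ntuple(4)(7)             # -> (7, 7, 7, 7): same kernel for all stages
    to_ntuple(4)((3, 5, 7, 9))  # -> (3, 5, 7, 9): per-stage kernel sizes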
@@ -312,13 +337,15 @@ def __init__(
            stages.append(ConvNeXtStage(
                prev_chs,
                out_chs,
+                kernel_size=kernel_sizes[i],
                stride=stride,
                dilation=(first_dilation, dilation),
                depth=depths[i],
                drop_path_rates=dp_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
+                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl
            ))
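Taken together, `kernel_sizes` and `act_layer` now thread from the ConvNeXt constructor through each ConvNeXtStage into every ConvNeXtBlock. A hedged end-to-end sketch (relies on timm's `create_model` forwarding extra kwargs to the model constructor; the kwarg names come from this diff):

    import torch
    import timm

    # Stock configuration: 7x7 depthwise kernels, GELU activation.
    m = timm.create_model('convnext_atto', pretrained=False)

    # Experimental variant enabled by this diff: smaller early-stage
    # kernels and a SiLU activation resolved via get_act_layer('silu').
    m2 = timm.create_model(
        'convnext_atto', pretrained=False,
        kernel_sizes=(3, 5, 7, 7), act_layer='silu')

    x = torch.randn(1, 3, 224, 224)
    assert m(x).shape == m2(x).shape == (1, 1000)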