From 57f85542da8061a7c1c90832ddae1da419b75a58 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 29 May 2025 05:00:38 +0800 Subject: [PATCH 01/12] support gradient checkpoint in forward_intermediates --- timm/models/beit.py | 5 ++++- timm/models/byobnet.py | 7 +++++-- timm/models/cait.py | 7 +++++-- timm/models/davit.py | 7 +++++-- timm/models/efficientnet.py | 8 +++++--- timm/models/efficientvit_mit.py | 12 +++++++++--- timm/models/efficientvit_msra.py | 7 +++++-- timm/models/eva.py | 5 ++++- timm/models/hiera.py | 7 +++++-- timm/models/levit.py | 7 +++++-- timm/models/metaformer.py | 5 ++++- timm/models/mlp_mixer.py | 7 +++++-- timm/models/mobilenetv3.py | 8 +++++--- timm/models/nextvit.py | 1 - timm/models/repghost.py | 7 +++++-- timm/models/repvit.py | 7 +++++-- timm/models/resnetv2.py | 7 +++++-- timm/models/rexnet.py | 7 +++++-- timm/models/tiny_vit.py | 7 +++++-- timm/models/tnt.py | 5 ++++- timm/models/vision_transformer.py | 7 +++++-- timm/models/vision_transformer_relpos.py | 5 ++++- timm/models/vision_transformer_sam.py | 7 +++++-- timm/models/volo.py | 5 ++++- timm/models/xcit.py | 5 ++++- 25 files changed, 117 insertions(+), 45 deletions(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index 5123a60627..13661f0fe3 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -451,7 +451,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, shared_rel_pos_bias=rel_pos_bias) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, shared_rel_pos_bias=rel_pos_bias) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 764d5ad5eb..7bb92b39bf 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -44,7 +44,7 @@ ) from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq, named_apply from ._registry import generate_default_cfgs, register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] @@ -1384,7 +1384,10 @@ def forward_intermediates( stages = self.stages[:max_index] for stage in stages: feat_idx += 1 - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if not exclude_final_conv and feat_idx == last_idx: # default feature_info for this model uses final_conv as the last feature output (if present) x = self.final_conv(x) diff --git a/timm/models/cait.py b/timm/models/cait.py index 28e14ec756..2c500ec3dd 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -18,7 +18,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] @@ -373,7 +373,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not 
torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/davit.py b/timm/models/davit.py index f538ecca84..27098a8a6b 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -25,7 +25,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['DaVit'] @@ -671,7 +671,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if norm and feat_idx == last_idx: x_inter = self.norm_pre(x) # applying final norm to last intermediate diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index b5bc35c036..ebb4cf4a6b 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -210,9 +210,11 @@ def forward_intermediates( blocks = self.blocks else: blocks = self.blocks[:max_index] - for blk in blocks: - feat_idx += 1 - x = blk(x) + for feat_idx, blk in enumerate(blocks, start=1): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 27872310e0..68c84ba02c 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -19,7 +19,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_module -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -789,7 +789,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) @@ -943,7 +946,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 91caaa5a4d..bf05019032 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -18,7 +18,7 @@ from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -510,7 +510,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in 
take_indices: intermediates.append(x) diff --git a/timm/models/eva.py b/timm/models/eva.py index 166a07bb03..c6413f3e47 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -716,7 +716,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, rope=rot_pos_embed) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed) + else: + x = blk(x, rope=rot_pos_embed) if i in take_indices: intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 2c16a9d63e..fa9d6d2833 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -24,7 +24,7 @@ # -------------------------------------------------------- import math from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -719,7 +719,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: x_int = self.reroll(x, i, mask=mask) intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int) diff --git a/timm/models/levit.py b/timm/models/levit.py index 577fc5f2d7..a4c9ce628a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -34,7 +34,7 @@ from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['Levit'] @@ -671,7 +671,10 @@ def forward_intermediates( else: stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if self.use_conv: intermediates.append(x) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 490852cfe4..a632936ba2 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -631,7 +631,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and stage.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 25cde6a67c..e8ad6860cb 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -49,7 +49,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq +from ._manipulate import named_apply, checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['MixerBlock', 'MlpMixer'] # model_registry will add each entrypoint fn to this @@ -298,7 +298,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if 
self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 08dcb064fa..20e068e500 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -211,9 +211,11 @@ def forward_intermediates( blocks = self.blocks else: blocks = self.blocks[:max_index] - for blk in blocks: - feat_idx += 1 - x = blk(x) + for feat_idx, blk in enumerate(blocks, start=1): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(blk, x, flatten=True) + else: + x = blk(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 2f232e2990..9483510536 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -17,7 +17,6 @@ from timm.layers import ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model diff --git a/timm/models/repghost.py b/timm/models/repghost.py index 77fc35d59e..7059eb776f 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -17,7 +17,7 @@ from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['RepGhostNet'] @@ -336,7 +336,10 @@ def forward_intermediates( stages = self.blocks[:max_index + 1] for feat_idx, stage in enumerate(stages, start=1): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/repvit.py b/timm/models/repvit.py index ddcfed55c8..d1ed7d023a 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -23,7 +23,7 @@ from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['RepVit'] @@ -367,7 +367,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 5cc164ae1b..dd761181be 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -41,7 +41,7 @@ DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv +from ._manipulate import checkpoint, checkpoint_seq, named_apply, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['ResNetV2'] # model_registry will 
add each entrypoint fn to this @@ -585,7 +585,10 @@ def forward_intermediates( stages = self.stages[:max_index] for feat_idx, stage in enumerate(stages, start=1): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if feat_idx == last_idx: x_inter = self.norm(x) if norm else x diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index dd3cb4f32f..31993c898c 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -22,7 +22,7 @@ from ._builder import build_model_with_cfg from ._efficientnet_builder import efficientnet_init_weights from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['RexNet'] # model_registry will add each entrypoint fn to this @@ -271,7 +271,10 @@ def forward_intermediates( stages = self.features[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index d238fa5b2d..2733ca36f9 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -22,7 +22,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_module -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -570,7 +570,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stages, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index fa6e1fc9e7..0ecd8e72a4 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -386,7 +386,10 @@ def forward_intermediates( blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - pixel_embed, patch_embed = blk(pixel_embed, patch_embed) + if self.grad_checkpointing and not torch.jit.is_scripting(): + pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed) + else: + pixel_embed, patch_embed = blk(pixel_embed, patch_embed) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(patch_embed) if norm else patch_embed) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 3c7b9a2277..6cad3164fa 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -46,7 +46,7 @@ get_act_layer, get_norm_layer, LayerType from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._manipulate import named_apply, checkpoint, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this @@ -759,7 +759,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not 
torch.jit.is_scripting(): + x = checkpoint(blk. x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 030c24dc69..dcccba73ba 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -427,7 +427,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, shared_rel_pos=shared_rel_pos) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos=shared_rel_pos) + else: + x = blk(x, shared_rel_pos=shared_rel_pos) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 75bb12e56f..df70f4a251 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -24,7 +24,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model # model_registry will add each entrypoint fn to this @@ -579,7 +579,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: # make output BCHW if norm: diff --git a/timm/models/volo.py b/timm/models/volo.py index 46be778f67..2a3d01adea 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -749,7 +749,10 @@ def forward_intermediates( # add positional encoding after outlooker blocks x = x + self.pos_embed x = self.pos_drop(x) - x = block(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(block, x) + else: + x = block(x) if idx in take_indices: if norm and idx >= 2: x_inter = self.norm(x) diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 250749f1cf..d4ea7af9ef 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -478,7 +478,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, Hp, Wp) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, Hp, Wp) + else: + x = blk(x, Hp, Wp) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) From a80348a8c8935da4b0dd2fc37dc80961f59468ef Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 29 May 2025 05:13:50 +0800 Subject: [PATCH 02/12] support starnet and ghostnet --- timm/models/ghostnet.py | 7 +++++-- timm/models/starnet.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 2f1587015b..881b29322c 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -23,7 +23,7 @@ from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, 
checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['GhostNet'] @@ -727,7 +727,10 @@ def forward_intermediates( stages = self.blocks[:max_index + 1] for feat_idx, stage in enumerate(stages, start=1): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/starnet.py b/timm/models/starnet.py index bc140e00d1..2d7b0e4902 100644 --- a/timm/models/starnet.py +++ b/timm/models/starnet.py @@ -19,7 +19,7 @@ from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_ from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['StarNet'] @@ -198,7 +198,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if norm and feat_idx == last_idx: x_inter = self.norm(x) # applying final norm last intermediate From b0b28e29aab9190d994b5b73630b714a3daf828c Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 29 May 2025 05:17:16 +0800 Subject: [PATCH 03/12] fix metaformer and nextvit --- timm/models/metaformer.py | 6 +++--- timm/models/nextvit.py | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index a632936ba2..72163b13d2 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -41,7 +41,7 @@ use_fused_attn from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['MetaFormer'] @@ -631,8 +631,8 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - if self.grad_checkpointing and stage.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(stage, x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 9483510536..402a9d76ea 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -17,7 +17,7 @@ from timm.layers import ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['NextViT'] @@ -594,7 +594,10 @@ def forward_intermediates( stages = self.stages[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: if feat_idx == last_idx: x_inter = self.norm(x) if norm else x From c9f9c30dfad1b7327e00f350d87268fe23fe4180 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Fri, 13 Jun 2025 12:38:18 +0800 Subject: [PATCH 04/12] fix some model --- 
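Notes: this commit mostly adjusts how the new grad-checkpointing branches call into timm's helpers — switching several stage-level call sites from checkpoint() to checkpoint_seq(), wiring the flag into hieradet_sam2, tresnet, and the xcit class-attention blocks, and dropping unused imports. The underlying pattern is the same one used throughout the series and in the existing forward_features() methods. A minimal sketch of that pattern, assuming a module with a blocks ModuleList, a grad_checkpointing flag, and the checkpoint helper from timm.models._manipulate (the function name _run_blocks below is illustrative, not part of the patch):

    import torch
    import torch.nn as nn

    from timm.models._manipulate import checkpoint

    def _run_blocks(self: nn.Module, x: torch.Tensor) -> torch.Tensor:
        # Trade compute for memory: recompute each block's activations during
        # backward instead of storing them. TorchScript cannot script the
        # checkpoint call, so fall back to a plain forward when scripting.
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)
        return x

checkpoint_seq() is the variant used where an entire nn.Sequential stage is checkpointed segment by segment (byobnet, ghostnet, repghost, resnetv2 in this commit).
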
timm/models/byobnet.py | 4 +-- timm/models/efficientnet.py | 2 +- timm/models/efficientvit_mit.py | 4 +-- timm/models/ghostnet.py | 2 +- timm/models/hieradet_sam2.py | 19 +++++++++----- timm/models/inception_resnet_v2.py | 1 - timm/models/inception_v3.py | 1 - timm/models/nasnet.py | 2 -- timm/models/nextvit.py | 4 +-- timm/models/nfnet.py | 2 +- timm/models/pnasnet.py | 1 - timm/models/rdnet.py | 42 +++++++++++++++--------------- timm/models/repghost.py | 4 +-- timm/models/resnetv2.py | 4 +-- timm/models/starnet.py | 4 +-- timm/models/tresnet.py | 7 +++-- timm/models/xcit.py | 5 +++- 17 files changed, 57 insertions(+), 51 deletions(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 7bb92b39bf..111765031e 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -44,7 +44,7 @@ ) from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint, checkpoint_seq, named_apply +from ._manipulate import named_apply, checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] @@ -1385,7 +1385,7 @@ def forward_intermediates( for stage in stages: feat_idx += 1 if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stage, x) else: x = stage(x) if not exclude_final_conv and feat_idx == last_idx: diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index ebb4cf4a6b..800f4a3833 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -212,7 +212,7 @@ def forward_intermediates( blocks = self.blocks[:max_index] for feat_idx, blk in enumerate(blocks, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(blk, x) + x = checkpoint_seq(blk, x, flatten=True) else: x = blk(x) if feat_idx in take_indices: diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 68c84ba02c..f3ad836848 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -790,7 +790,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stages, x) else: x = stage(x) if feat_idx in take_indices: @@ -947,7 +947,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stages, x) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 881b29322c..bd4a02ac1e 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -728,7 +728,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stage, x, flatten=True) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 6cd2592a95..b9eb850fb3 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -1,12 +1,11 @@ import math from copy import deepcopy from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F -from torch.jit import Final 
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \ @@ -14,8 +13,8 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv -from ._registry import generate_default_cfgs, register_model, register_model_deprecations +from ._manipulate import named_apply, checkpoint +from ._registry import generate_default_cfgs, register_model def window_partition(x, window_size: Tuple[int, int]): @@ -471,7 +470,10 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) if i in take_indices: x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x intermediates.append(x_out) @@ -503,8 +505,11 @@ def prune_intermediate_layers( def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) # BHWC x = self._pos_embed(x) - for i, blk in enumerate(self.blocks): - x = blk(x) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) return x def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor: diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 7fdfee41ed..d691be7a8f 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -5,7 +5,6 @@ from functools import partial import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import create_classifier, ConvNormAct diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 8cb1a151df..a55521c3de 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -4,7 +4,6 @@ Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE """ from functools import partial -from typing import Optional import torch import torch.nn as nn diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 0bcc048568..db12d2116f 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -3,11 +3,9 @@ https://github.com/Cadene/pretrained-models.pytorch """ from functools import partial -from typing import Optional import torch import torch.nn as nn -import torch.nn.functional as F from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier from ._builder import build_model_with_cfg diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 402a9d76ea..9134ec8622 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -17,7 +17,7 @@ from timm.layers import ClassifierHead from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint, checkpoint_seq +from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model __all__ = ['NextViT'] @@ -595,7 +595,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stage, x) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 68e92128f3..b59b428c01 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ 
-215,7 +215,7 @@ def create_stem( if 'deep' in stem_type: if 'quad' in stem_type: # 4 deep conv stack as in NFNet-F models - assert not 'pool' in stem_type + assert 'pool' not in stem_type stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs) strides = (2, 1, 1, 2) stem_stride = 4 diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 20d17945b5..7f33aaeabb 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier from ._builder import build_model_with_cfg diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index a3a205fff6..5764b6ed82 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -281,6 +281,27 @@ def __init__( named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + @torch.jit.ignore + def group_matcher(self, coarse=False): + assert not coarse, "coarse grouping is not implemented for RDNet" + return dict( + stem=r'^stem', + blocks=r'^dense_stages\.(\d+)', + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.dense_stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head.fc + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + def forward_intermediates( self, x: torch.Tensor, @@ -350,14 +371,6 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - @torch.jit.ignore - def get_classifier(self) -> nn.Module: - return self.head.fc - - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): - self.num_classes = num_classes - self.head.reset(num_classes, global_pool) - def forward_features(self, x): x = self.stem(x) x = self.dense_stages(x) @@ -372,19 +385,6 @@ def forward(self, x): x = self.forward_head(x) return x - @torch.jit.ignore - def group_matcher(self, coarse=False): - assert not coarse, "coarse grouping is not implemented for RDNet" - return dict( - stem=r'^stem', - blocks=r'^dense_stages\.(\d+)', - ) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - for s in self.dense_stages: - s.grad_checkpointing = enable - def _init_weights(module, name=None, head_init_scale=1.0): if isinstance(module, nn.Conv2d): diff --git a/timm/models/repghost.py b/timm/models/repghost.py index 7059eb776f..1a453313ee 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -17,7 +17,7 @@ from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct from ._features import feature_take_indices -from ._manipulate import checkpoint, checkpoint_seq +from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['RepGhostNet'] @@ -337,7 +337,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stage, x, flatten=True) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index dd761181be..ad84c0f1f7 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -41,7 +41,7 @@ DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible from 
._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint, checkpoint_seq, named_apply, adapt_input_conv +from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['ResNetV2'] # model_registry will add each entrypoint fn to this @@ -586,7 +586,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stage, x, flatten=True) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/starnet.py b/timm/models/starnet.py index 2d7b0e4902..53a8641178 100644 --- a/timm/models/starnet.py +++ b/timm/models/starnet.py @@ -19,7 +19,7 @@ from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_ from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint, checkpoint_seq +from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['StarNet'] @@ -199,7 +199,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(stage, x) + x = checkpoint_seq(stages, x, flatten=True) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 0fb76fa40c..2c452e4707 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -15,7 +15,7 @@ from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import checkpoint_seq +from ._manipulate import checkpoint, checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations __all__ = ['TResNet'] # model_registry will add each entrypoint fn to this @@ -263,7 +263,10 @@ def forward_intermediates( stages = self.body[:max_index + 1] for feat_idx, stage in enumerate(stages): - x = stage(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(stage, x) + else: + x = stage(x) if feat_idx in take_indices: intermediates.append(x) diff --git a/timm/models/xcit.py b/timm/models/xcit.py index d4ea7af9ef..271578adf8 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -497,7 +497,10 @@ def forward_intermediates( # NOTE not supporting return of class tokens x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) for blk in self.cls_attn_blocks: - x = blk(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) x = self.norm(x) From 0638708731f7635d274d1884ccfc2f98cfc4221a Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Fri, 13 Jun 2025 12:38:50 +0800 Subject: [PATCH 05/12] rename naflexvit mask to attn_mask --- timm/models/naflexvit.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 9684b397ec..78ae582e81 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -20,13 +20,13 @@ import math from dataclasses import dataclass, fields, replace from functools import partial -from typing import Callable, Dict, List, Optional, Set, Tuple, Type, 
Union, Final, Any, Literal +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any import torch import torch.nn as nn import torch.nn.functional as F -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import ( AttentionPoolLatent, Mlp, @@ -34,14 +34,13 @@ get_act_layer, get_norm_layer, LayerNorm, - LayerType, _assert, ) from timm.models._builder import build_model_with_cfg from timm.models._features import feature_take_indices from timm.models._features_fx import register_notrace_function, register_notrace_module from timm.models._registry import register_model, generate_default_cfgs -from timm.models._manipulate import checkpoint_seq, named_apply +from timm.models._manipulate import checkpoint, checkpoint_seq, named_apply from .vision_transformer import Block, global_pool_nlc @@ -1054,7 +1053,7 @@ def forward_intermediates( output_dict: bool = False, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: """ Forward features that returns intermediates. @@ -1069,7 +1068,7 @@ def forward_intermediates( output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex - mask: Optional attention mask + attn_mask: Optional attention mask for masked attention Returns: A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') @@ -1093,8 +1092,8 @@ def forward_intermediates( H, W = self.embeds.dynamic_feat_size((height, width)) # Create attention mask if patch_type is provided and mask is not - if mask is None and patch_valid is not None: - mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) + if attn_mask is None and patch_valid is not None: + attn_mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) # Forward pass through embedding x = self.embeds(patches, patch_coord=patch_coord) @@ -1107,7 +1106,12 @@ def forward_intermediates( blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x, attn_mask=mask) + if attn_mask is not None: + x = blk(x, attn_mask=attn_mask) + elif self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk. 
x) + else: + x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) From cd1542aa3977004249c3d1cb682002d8a1dda33a Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Fri, 13 Jun 2025 12:50:44 +0800 Subject: [PATCH 06/12] remove unused import --- timm/models/crossvit.py | 7 +------ timm/models/dla.py | 1 - timm/models/efficientvit_mit.py | 2 +- timm/models/ghostnet.py | 2 +- timm/models/hrnet.py | 1 - timm/models/nfnet.py | 2 +- timm/models/selecsls.py | 1 - timm/models/shvit.py | 3 +-- timm/models/swin_transformer.py | 2 +- timm/models/swin_transformer_v2_cr.py | 2 +- timm/models/vision_transformer_hybrid.py | 5 ++--- 11 files changed, 9 insertions(+), 19 deletions(-) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index f3d52f8e49..0e1de2fadd 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -14,21 +14,16 @@ NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408 Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ # Copyright IBM All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 - -""" -Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - -""" from functools import partial from typing import List, Optional, Tuple import torch -import torch.hub import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD diff --git a/timm/models/dla.py b/timm/models/dla.py index 666acd9d9c..197060e4e6 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 86c7fdd171..8b35a04c87 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -19,7 +19,7 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_module -from ._manipulate import checkpoint, checkpoint_seq +from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 686225ad30..fff96e0c2c 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -23,7 +23,7 @@ from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct from ._features import feature_take_indices -from ._manipulate import checkpoint, checkpoint_seq +from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs __all__ = ['GhostNet'] diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 75b157d67d..92ee3511cf 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index b6857bdd96..68b8b1b6d6 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -19,7 +19,7 @@ from 
collections import OrderedDict from dataclasses import dataclass, replace from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple import torch import torch.nn as nn diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index fdfa16c318..dc19ff41d1 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_classifier diff --git a/timm/models/shvit.py b/timm/models/shvit.py index be3e206ee8..ce61217eb9 100644 --- a/timm/models/shvit.py +++ b/timm/models/shvit.py @@ -11,7 +11,6 @@ year={2024} } """ -import re from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch @@ -429,7 +428,7 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) state_dict = state_dict.get('model', state_dict) # out_dict = {} - # + # import re # replace_rules = [ # (re.compile(r'^blocks1\.'), 'stages.0.blocks.'), # (re.compile(r'^blocks2\.'), 'stages.1.blocks.'), diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 7eeae8316b..e17f16746c 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ - _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid + use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index c490fa23ca..6430b61223 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -29,7 +29,7 @@ # -------------------------------------------------------- import logging import math -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 4cf3a7664b..0ff4823497 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -13,15 +13,14 @@ Hacked together by / Copyright 2020, Ross Wightman """ -import math from functools import partial -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Tuple, Type, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, HybridEmbed +from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_ntuple, HybridEmbed from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model, register_model_deprecations From 869bac2515ccc6bd70e50275f5f29a75df01363c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 20 Jun 2025 11:40:27 -0700 Subject: [PATCH 07/12] Fix forward_intermediates() grad_checkpointing in vision_transformer.py --- timm/models/vision_transformer.py | 2 +- 1 file changed, 1 insertion(+), 
1 deletion(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index f0cc3de063..05e435e9e9 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -827,7 +827,7 @@ def forward_intermediates( if attn_mask is not None: x = blk(x, attn_mask=attn_mask) elif self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(blk. x) + x = checkpoint(blk, x) else: x = blk(x) if i in take_indices: From b6692ed5e7c6d944d782765d556ae620d68a19db Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 20 Jun 2025 13:49:15 -0700 Subject: [PATCH 08/12] Fix several grad checkpointing issues --- timm/models/efficientnet.py | 2 +- timm/models/fasternet.py | 2 +- timm/models/focalnet.py | 2 +- timm/models/gcvit.py | 2 +- timm/models/ghostnet.py | 2 +- timm/models/hgnet.py | 2 +- timm/models/hieradet_sam2.py | 1 + timm/models/mobilenetv3.py | 2 +- timm/models/mvitv2.py | 2 +- timm/models/pvt_v2.py | 2 +- timm/models/repghost.py | 2 +- timm/models/resnetv2.py | 20 ++++++++++---------- timm/models/rexnet.py | 2 +- timm/models/shvit.py | 2 +- timm/models/starnet.py | 4 ++-- timm/models/swiftformer.py | 2 +- timm/models/swin_transformer_v2.py | 2 +- timm/models/swin_transformer_v2_cr.py | 2 +- 18 files changed, 28 insertions(+), 27 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 232c888d35..7f2f5aa341 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -261,7 +261,7 @@ def forward_intermediates( blocks = self.blocks[:max_index] for feat_idx, blk in enumerate(blocks, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(blk, x, flatten=True) + x = checkpoint_seq(blk, x) else: x = blk(x) if feat_idx in take_indices: diff --git a/timm/models/fasternet.py b/timm/models/fasternet.py index d73f49a265..b9e4aed249 100644 --- a/timm/models/fasternet.py +++ b/timm/models/fasternet.py @@ -142,7 +142,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=True) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index aa3237925f..3c2bd75643 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -274,7 +274,7 @@ def forward(self, x): x = self.downsample(x) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(blk, x) + x = checkpoint(blk, x) else: x = blk(x) return x diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 214619de9b..367e5dfff5 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -361,7 +361,7 @@ def forward(self, x): global_query = self.global_norm(global_query.permute(0, 2, 3, 1)) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(blk, x) + x = checkpoint(blk, x, global_query) else: x = blk(x, global_query) x = self.norm(x) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index fff96e0c2c..126d638f43 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -728,7 +728,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(stage, x, flatten=True) + x = checkpoint_seq(stage, x) else: x = stage(x) if feat_idx in 
take_indices: diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index 212cbb58ff..5c49f9ddca 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -345,7 +345,7 @@ def __init__( def forward(self, x): x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=False) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index b9eb850fb3..fbd7ce28ec 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -288,6 +288,7 @@ def __init__( norm_layer = get_norm_layer(norm_layer) act_layer = get_act_layer(act_layer) assert len(stages) == len(window_spec) + self.grad_checkpointing = False self.num_classes = num_classes self.window_spec = window_spec self.output_fmt = 'NHWC' diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index f7b601a17a..eb87bb38d8 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -229,7 +229,7 @@ def forward_intermediates( blocks = self.blocks[:max_index] for feat_idx, blk in enumerate(blocks, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(blk, x, flatten=True) + x = checkpoint_seq(blk, x) else: x = blk(x) if feat_idx in take_indices: diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index c048a07277..01c4550ed8 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -681,7 +681,7 @@ def __init__( def forward(self, x, feat_size: List[int]): for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x, feat_size = checkpoint.checkpoint(blk, x, feat_size) + x, feat_size = checkpoint(blk, x, feat_size) else: x, feat_size = blk(x, feat_size) return x, feat_size diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index bb1baf6645..0259a1f64e 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -267,7 +267,7 @@ def forward(self, x): x = x.reshape(B, -1, C) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint.checkpoint(blk, x, feat_size) + x = checkpoint(blk, x, feat_size) else: x = blk(x, feat_size) x = self.norm(x) diff --git a/timm/models/repghost.py b/timm/models/repghost.py index 93462eef46..c5a7d93a4f 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -337,7 +337,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages, start=1): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(stage, x, flatten=True) + x = checkpoint_seq(stage, x) else: x = stage(x) if feat_idx in take_indices: diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 54c86112df..b7d4173321 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -397,6 +397,8 @@ def __init__( **block_kwargs: Any, ): super(ResNetStage, self).__init__() + self.grad_checkpointing = False + first_dilation = 1 if dilation in (1, 2) else 2 layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer) proj_layer = DownsampleAvg if avg_down else DownsampleConv @@ -431,7 +433,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Output tensor. 
""" - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) return x @@ -631,7 +636,8 @@ def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: """Enable or disable gradient checkpointing.""" - self.grad_checkpointing = enable + for s in self.stages: + s.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self) -> nn.Module: @@ -689,10 +695,7 @@ def forward_intermediates( stages = self.stages[:max_index] for feat_idx, stage in enumerate(stages, start=1): - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(stage, x, flatten=True) - else: - x = stage(x) + x = stage(x) if feat_idx in take_indices: if feat_idx == last_idx: x_inter = self.norm(x) if norm else x @@ -734,10 +737,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: Feature tensor. """ x = self.stem(x) - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.stages, x, flatten=True) - else: - x = self.stages(x) + x = self.stages(x) x = self.norm(x) return x diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index dc35b80f47..77b801db87 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -429,7 +429,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: """ x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.features, x, flatten=True) + x = checkpoint_seq(self.features, x) else: x = self.features(x) return x diff --git a/timm/models/shvit.py b/timm/models/shvit.py index ce61217eb9..c165f1a280 100644 --- a/timm/models/shvit.py +++ b/timm/models/shvit.py @@ -244,7 +244,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=True) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x diff --git a/timm/models/starnet.py b/timm/models/starnet.py index 3a9a0a9455..398eac4433 100644 --- a/timm/models/starnet.py +++ b/timm/models/starnet.py @@ -199,7 +199,7 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(stages, x, flatten=True) + x = checkpoint_seq(stages, x) else: x = stage(x) if feat_idx in take_indices: @@ -236,7 +236,7 @@ def prune_intermediate_layers( def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.stages, x, flatten=True) + x = checkpoint_seq(self.stages, x) else: x = self.stages(x) x = self.norm(x) diff --git a/timm/models/swiftformer.py b/timm/models/swiftformer.py index 5998c233fd..38df6f1638 100644 --- a/timm/models/swiftformer.py +++ b/timm/models/swiftformer.py @@ -304,7 +304,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x, flatten=True) + x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index f7b758aa8b..35c0daa8ac 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -619,7 +619,7 @@ def forward(self, x: 
From 1f9eb663715f40e4a5a78e57180f8f0240198ec2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 20 Jun 2025 13:56:45 -0700
Subject: [PATCH 09/12] Add basic grad checkpointing tests

---
 tests/test_models.py | 28 +++++++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index b4686a3efe..58dad77e27 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -186,6 +186,18 @@ def test_model_forward(model_name, batch_size):
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'

+    # Test that grad-checkpointing, if supported, doesn't cause model failures or change in output
+    try:
+        model.set_grad_checkpointing()
+    except:
+        # throws if not supported, that's fine
+        pass
+    else:
+        outputs2 = model(inputs)
+        if isinstance(outputs, tuple):
+            outputs2 = torch.cat(outputs2)
+        assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5), 'Output does not match'
+

 @pytest.mark.base
 @pytest.mark.timeout(timeout120)
@@ -529,6 +541,20 @@ def test_model_forward_intermediates(model_name, batch_size):
     output2 = model.forward_features(inpt)
     assert torch.allclose(output, output2)

+    # Test that grad-checkpointing, if supported
+    try:
+        model.set_grad_checkpointing()
+    except:
+        # throws if not supported, that's fine
+        pass
+    else:
+        output3, _ = model.forward_intermediates(
+            inpt,
+            output_fmt=output_fmt,
+        )
+        assert torch.allclose(output, output3, rtol=1e-4, atol=1e-5), 'Output does not match'
+
+
 def _create_fx_model(model, train=False):
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
@@ -717,4 +743,4 @@ def test_model_forward_torchscript_with_features_fx(model_name, batch_size):

     for tensor in outputs:
         assert tensor.shape[0] == batch_size
-        assert not torch.isnan(tensor).any(), 'Output included NaNs'
\ No newline at end of file
+        assert not torch.isnan(tensor).any(), 'Output included NaNs'
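At the user level, the behaviour these tests pin down can be reproduced in a few lines. This is a rough sketch only; the model name, input shape, and tolerances below are illustrative choices, not values taken from the test suite.

import torch
import timm

# Checkpointing trades compute for memory; it must not change the numbers.
model = timm.create_model('resnet50', pretrained=False).eval()
x = torch.randn(2, 3, 224, 224)

out_ref = model(x)
model.set_grad_checkpointing()  # raises on models that don't support it
out_ckpt = model(x)
assert torch.allclose(out_ref, out_ckpt, rtol=1e-4, atol=1e-5)

# forward_intermediates() should honour the same flag and still return
# the final features plus the intermediate feature maps.
final, intermediates = model.forward_intermediates(x)
print(final.shape, [t.shape for t in intermediates])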
From 725071e02026dd6e1a626264a18a4d5c51df7ca4 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 20 Jun 2025 14:23:52 -0700
Subject: [PATCH 10/12] More forward_intermediate specific grad checkpointing fixes

---
 timm/models/starnet.py  | 2 +-
 timm/models/tiny_vit.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/timm/models/starnet.py b/timm/models/starnet.py
index 398eac4433..9ed32a85d7 100644
--- a/timm/models/starnet.py
+++ b/timm/models/starnet.py
@@ -199,7 +199,7 @@ def forward_intermediates(
         for feat_idx, stage in enumerate(stages):
             if self.grad_checkpointing and not torch.jit.is_scripting():
-                x = checkpoint_seq(stages, x)
+                x = checkpoint_seq(stage, x)
             else:
                 x = stage(x)
             if feat_idx in take_indices:
diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py
index 9bc8dde9ec..39bacc850c 100644
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -571,7 +571,7 @@ def forward_intermediates(
         for feat_idx, stage in enumerate(stages):
             if self.grad_checkpointing and not torch.jit.is_scripting():
-                x = checkpoint(stages, x)
+                x = checkpoint(stage, x)
             else:
                 x = stage(x)
             if feat_idx in take_indices:

From b1a9a9e28af3a3867820cfafa6a2d571a684b1dc Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 20 Jun 2025 14:32:22 -0700
Subject: [PATCH 11/12] NaFlexVit and NextVit forward_intermediates grad checkpointing fixes

---
 timm/models/naflexvit.py | 2 +-
 timm/models/nextvit.py   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py
index c3f4b5abf7..5794fdf5e0 100644
--- a/timm/models/naflexvit.py
+++ b/timm/models/naflexvit.py
@@ -1257,7 +1257,7 @@ def forward_intermediates(
             if attn_mask is not None:
                 x = blk(x, attn_mask=attn_mask)
             elif self.grad_checkpointing and not torch.jit.is_scripting():
-                x = checkpoint(blk. x)
+                x = checkpoint(blk, x)
             else:
                 x = blk(x)
             if i in take_indices:
diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py
index 9134ec8622..402a9d76ea 100644
--- a/timm/models/nextvit.py
+++ b/timm/models/nextvit.py
@@ -17,7 +17,7 @@
 from timm.layers import ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['NextViT']
@@ -595,7 +595,7 @@ def forward_intermediates(
         for feat_idx, stage in enumerate(stages):
             if self.grad_checkpointing and not torch.jit.is_scripting():
-                x = checkpoint_seq(stage, x)
+                x = checkpoint(stage, x)
             else:
                 x = stage(x)
             if feat_idx in take_indices:

From b3a87738dce3989a951ac47671abd190e05610b3 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Sat, 21 Jun 2025 19:59:50 +0800
Subject: [PATCH 12/12] Delete grad_checkpointing from ResNetV2 class

---
 timm/models/resnetv2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index b7d4173321..0b78e7b44d 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -609,7 +609,6 @@ def __init__(
         )

         self.init_weights(zero_init_last=zero_init_last)
-        self.grad_checkpointing = False

     @torch.jit.ignore
     def init_weights(self, zero_init_last: bool = True) -> None: