From ee528a976bf1d5544a05ba4b84f215f711a124c5 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 4 May 2025 23:49:43 +0800 Subject: [PATCH 01/15] support tresnet --- timm/models/tresnet.py | 61 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 37c37e2fdc..59f2507e46 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -7,13 +7,14 @@ """ from collections import OrderedDict from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -228,6 +229,64 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.body) - 1, indices) + print(take_indices, max_index) + + # forward pass + x = self.body[0](x) # s2d + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]] + else: + stages = self.body[1:max_index + 2] + print(len(stages)) + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
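Note: the `forward_intermediates()` contract added in the hunk above matches the one already used across timm's ViT family. A minimal usage sketch, assuming the registered `tresnet_m` model name and random weights:

    import timm
    import torch

    model = timm.create_model('tresnet_m', pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    # final features plus every stage output
    final, feats = model.forward_intermediates(x)
    # only the last two stage outputs; stop_early lets eager mode skip trailing stages
    feats = model.forward_intermediates(x, indices=2, stop_early=True, intermediates_only=True)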
+ """ + take_indices, max_index = feature_take_indices(len(self.body) - 1, indices) + self.body = self.body[1:max_index + 2] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): if self.grad_checkpointing and not torch.jit.is_scripting(): x = self.body.s2d(x) From 2ece990ffaa84b9730387285f84e924d60d25bea Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 4 May 2025 23:49:58 +0800 Subject: [PATCH 02/15] support vovnet --- timm/models/vovnet.py | 64 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 08e6d0b6c3..0b48a34c14 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from typing import List, Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,6 +20,7 @@ from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \ create_attn, create_norm_act_layer from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -264,6 +265,67 @@ def reset_classifier(self, num_classes, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + x = self.stem[:-1](x) + if feat_idx in take_indices: + intermediates.append(x) + + x = self.stem[-1](x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
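Note: every hunk in this series leans on `feature_take_indices()`. Its contract, inferred from the call sites here (a sketch, not the helper's actual implementation):

    def feature_take_indices(num_stages, indices):
        # None -> all stages; int n -> last n stages; sequence -> those indices
        if indices is None:
            indices = num_stages
        if isinstance(indices, int):
            take = list(range(num_stages - indices, num_stages))
        else:
            take = [i if i >= 0 else num_stages + i for i in indices]
        return take, max(take)

    feature_take_indices(5, None)    # ([0, 1, 2, 3, 4], 4)
    feature_take_indices(5, 2)       # ([3, 4], 4)
    feature_take_indices(5, [0, 2])  # ([0, 2], 2)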
+ """ + take_indices, max_index = feature_take_indices(5, indices) + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) return self.stages(x) From 23dd92aa524d8f99e94ab8dd428feacd12936dba Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 5 May 2025 01:35:02 +0800 Subject: [PATCH 03/15] support tiny_vit --- timm/models/tiny_vit.py | 59 ++++++++++++++++++++++++++++++++++++++++- timm/models/tresnet.py | 2 -- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 12a5ef2f16..d238fa5b2d 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -10,7 +10,7 @@ import itertools from functools import partial -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,6 +20,7 @@ from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -536,6 +537,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 59f2507e46..28e9028570 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -253,7 +253,6 @@ def forward_intermediates( assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] take_indices, max_index = feature_take_indices(len(self.body) - 1, indices) - print(take_indices, max_index) # forward pass x = self.body[0](x) # s2d @@ -261,7 +260,6 @@ def forward_intermediates( stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]] else: stages = self.body[1:max_index + 2] - print(len(stages)) for feat_idx, stage in enumerate(stages): x = stage(x) From 880b76191432fa0758e6d32ac7494423d1d15765 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 5 May 2025 04:21:12 +0800 Subject: [PATCH 04/15] support rexnet, resnetv2, repvit and repghostnet --- timm/models/repghost.py | 69 ++++++++++++++++++++++++++++++++++++++++- timm/models/repvit.py | 63 +++++++++++++++++++++++++++++++++++-- timm/models/resnetv2.py | 67 ++++++++++++++++++++++++++++++++++++++- timm/models/rexnet.py | 64 +++++++++++++++++++++++++++++++++++++- 4 files changed, 257 insertions(+), 6 deletions(-) diff --git a/timm/models/repghost.py b/timm/models/repghost.py index 4b802d79b6..77fc35d59e 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -6,7 +6,7 @@ """ import copy from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ from timm.layers import SelectAdaptivePool2d, Linear, make_divisible from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -294,6 +295,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
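Note: RepGhost's feature boundaries do not line up one-to-one with `self.blocks`, so the lines just below remap feature indices to block indices via the `feature_info` module names ('blocks.N'), with the stem pinned at feature index 0. A worked sketch with hypothetical module names:

    # hypothetical feature_info module names, stem first
    modules = ['conv_stem', 'blocks.1', 'blocks.3', 'blocks.5', 'blocks.7']
    stage_ends = [-1] + [int(m.split('.')[-1]) for m in modules[1:]]  # [-1, 1, 3, 5, 7]
    take, max_i = [3, 4], 4                     # e.g. from feature_take_indices(5, 2)
    take = [stage_ends[i] + 1 for i in take]    # [6, 8]: loop positions (stem shifts blocks by 1)
    max_index = stage_ends[max_i]               # 7: last block index that must execute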
+ intermediates = [] + stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i]+1 for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + feat_idx = 0 + x = self.conv_stem(x) + if feat_idx in take_indices: + intermediates.append(x) + x = self.bn1(x) + x = self.act1(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.blocks + else: + stages = self.blocks[:max_index + 1] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) diff --git a/timm/models/repvit.py b/timm/models/repvit.py index 7dcb2cd939..ddcfed55c8 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -14,9 +14,7 @@ Adapted from official impl at https://github.com/jameslahm/RepViT """ - -__all__ = ['RepVit'] -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -24,9 +22,12 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs +__all__ = ['RepVit'] + class ConvNorm(nn.Sequential): def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): @@ -333,6 +334,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, def set_distilled_training(self, enable=True): self.head.distilled_training = enable + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
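Note: `prune_intermediate_layers()` as added throughout this series turns a classifier into a truncated feature extractor, and the indices it returns can be fed straight back into `forward_intermediates()`. A sketch, model name assumed:

    model = timm.create_model('repvit_m1', pretrained=False)  # any model from this series
    take = model.prune_intermediate_layers(indices=[0, 1], prune_head=True)
    feats = model.forward_intermediates(
        torch.randn(1, 3, 224, 224), indices=take, intermediates_only=True)
    print([f.shape for f in feats])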
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index d7d2905bd9..1bac794c5e 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -31,7 +31,7 @@ from collections import OrderedDict # pylint: disable=g-importing-member from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,6 +40,7 @@ from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \ DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -543,6 +544,70 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
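Note: ResNetV2 is one of the models here with a trailing `self.norm` after the stages, which is why the remainder of this hunk re-applies the norm before returning and exposes `prune_norm`. A sketch of the interaction, model name assumed:

    model = timm.create_model('resnetv2_50', pretrained=False)
    model.prune_intermediate_layers(indices=2, prune_norm=True, prune_head=True)
    # stages are truncated, norm and head replaced by identities; what remains is
    # exactly the computation needed for the requested intermediates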
+ intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + x = self.stem(x) + if feat_idx in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(5, indices) + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 9971728c24..dd3cb4f32f 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -12,7 +12,7 @@ from functools import partial from math import ceil -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -21,6 +21,7 @@ from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule from ._builder import build_model_with_cfg from ._efficientnet_builder import efficientnet_init_weights +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -234,6 +235,67 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
+ intermediates = [] + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.features + else: + stages = self.features[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): From bac56a61413c4e4553572aed066133a913b7e94a Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 5 May 2025 04:22:11 +0800 Subject: [PATCH 05/15] fix tresnet and rdnet --- timm/models/rdnet.py | 1 + timm/models/tresnet.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index 393dc97bbc..b55cc33c2a 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -355,6 +355,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): def forward_features(self, x): x = self.stem(x) x = self.dense_stages(x) + x = self.norm_pre(x) return x def forward_head(self, x, pre_logits: bool = False): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 28e9028570..0fb76fa40c 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -252,14 +252,15 @@ def forward_intermediates( """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] - take_indices, max_index = feature_take_indices(len(self.body) - 1, indices) - + stage_ends = [1, 2, 3, 4, 5] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] # forward pass - x = self.body[0](x) # s2d if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript - stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]] + stages = self.body else: - stages = self.body[1:max_index + 2] + stages = self.body[:max_index + 1] for feat_idx, stage in enumerate(stages): x = stage(x) @@ -279,8 +280,10 @@ def prune_intermediate_layers( ): """ Prune layers not required for specified intermediates. 
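Note: the fix above works because `self.body` is a Sequential whose index 0 is the SpaceToDepth stem, with the five feature stages at indices 1-5; mapping feature indices through `stage_ends` keeps `feat_idx` aligned with body positions. Worked through:

    stage_ends = [1, 2, 3, 4, 5]             # body[0] is the s2d stem
    take, max_i = [0, 3], 3                  # e.g. from feature_take_indices(5, [0, 3])
    take = [stage_ends[i] for i in take]     # [1, 4]: body module indices
    max_index = stage_ends[max_i]            # 4 -> run self.body[:5]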
""" - take_indices, max_index = feature_take_indices(len(self.body) - 1, indices) - self.body = self.body[1:max_index + 2] # truncate blocks w/ stem as idx 0 + stage_ends = [1, 2, 3, 4, 5] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_head: self.reset_classifier(0, '') return take_indices From 85be962f449e1d9fee70852183ae4643c469adea Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 5 May 2025 05:30:57 +0800 Subject: [PATCH 06/15] support mambaout, metaformer, nest, nextvit, pvt_v2 --- timm/models/mambaout.py | 65 ++++++++++++++++++++++++++++++++++++++- timm/models/metaformer.py | 59 ++++++++++++++++++++++++++++++++++- timm/models/nest.py | 63 +++++++++++++++++++++++++++++++++++++ timm/models/nextvit.py | 63 ++++++++++++++++++++++++++++++++++++- timm/models/pvt_v2.py | 59 ++++++++++++++++++++++++++++++++++- 5 files changed, 305 insertions(+), 4 deletions(-) diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index f53a9cdfc2..8eac6e7bf2 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -6,7 +6,7 @@ InceptionNeXt (https://github.com/sail-sg/inceptionnext) """ from collections import OrderedDict -from typing import Optional +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -14,6 +14,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -417,6 +418,68 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.' 
+ channel_first = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if channel_first: + # reshape to BCHW output format + intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates] + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 23ef37242d..490852cfe4 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -28,7 +28,7 @@ from collections import OrderedDict from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,6 +40,7 @@ from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \ use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -597,6 +598,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): final = nn.Identity() self.head.fc = final + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_head(self, x: Tensor, pre_logits: bool = False): # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( x = self.head.global_pool(x) diff --git a/timm/models/nest.py b/timm/models/nest.py index 1d9c752105..9ee504632d 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -19,6 +19,7 @@ import logging import math from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -28,6 +29,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -420,6 +422,67 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.global_pool, self.head = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.levels), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.levels + else: + stages = self.levels[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + # Layer norm done over channel dim only (to NHWC and back) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
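Note: NesT's `self.norm` is a plain LayerNorm over the channel dim, hence the NHWC round trip in the hunk above. The same step in isolation:

    ln = torch.nn.LayerNorm(128)
    x = torch.randn(2, 128, 14, 14)                    # NCHW
    y = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # normalize C, back to NCHW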
+ """ + take_indices, max_index = feature_take_indices(len(self.levels), indices) + self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.levels(x) diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 01e63fce8a..5d6ec9724d 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -6,7 +6,7 @@ """ # Copyright (c) ByteDance Inc. All rights reserved. from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -16,6 +16,7 @@ from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn from timm.layers import ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -560,6 +561,66 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 6977350ad1..8cd42fe8a1 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -16,7 +16,7 @@ """ import math -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint from ._registry import register_model, generate_default_cfgs @@ -386,6 +387,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.stages(x) From 46433adb288f66ef5213cb677f230dee2b491ea0 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 5 May 2025 23:11:59 +0800 Subject: [PATCH 07/15] support more models inception_next, hgnet, gcvit, focalnet, inception_v4 --- timm/models/focalnet.py | 69 ++++++++++++++++++++++++++++++++++- timm/models/gcvit.py | 68 ++++++++++++++++++++++++++++++++-- timm/models/hgnet.py | 59 +++++++++++++++++++++++++++++- timm/models/inception_next.py | 59 +++++++++++++++++++++++++++++- timm/models/inception_v4.py | 62 +++++++++++++++++++++++++++++++ 5 files changed, 310 insertions(+), 7 deletions(-) diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 5608facb92..51ab4d0803 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -18,7 +18,7 @@ # Written by Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- from functools import partial -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -26,6 +26,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint from ._registry import generate_default_cfgs, register_model @@ -458,6 +459,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
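Note: in the FocalNet hunk just below, `last_idx = len(self.layers)` while `feat_idx` enumerates from 0, so `feat_idx == last_idx` can never be true and the final `self.norm` is silently skipped; the stem-offset variants elsewhere in this series (ResNetV2, RDNet) enumerate from 1, where `len(...)` is correct. With zero-based enumeration the guard presumably wants:

    last_idx = len(self.layers) - 1   # final stage index when enumerate() starts at 0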
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.layers), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.layers + else: + stages = self.layers[:max_index + 1] + + last_idx = len(self.layers) + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.layers), indices) + self.layers = self.layers[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.layers(x) diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index b31b5768bd..c862dc4a20 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -30,6 +30,7 @@ from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint from ._registry import register_model, generate_default_cfgs @@ -397,7 +398,7 @@ def __init__( act_layer = get_act_layer(act_layer) norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) - + self.feature_info = [] img_size = to_2tuple(img_size) feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4 self.global_pool = global_pool @@ -441,6 +442,7 @@ def __init__( norm_layer=norm_layer, norm_layer_cl=norm_layer_cl, )) + self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')] self.stages = nn.Sequential(*stages) # Classifier head @@ -494,6 +496,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): global_pool = self.head.global_pool.pool_type self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.stem(x) x = self.stages(x) @@ -509,9 +567,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _create_gcvit(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs) + model = build_model_with_cfg( + GlobalContextVit, variant, pretrained, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs + ) return model diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index ea0a92d955..3e44c9dc9c 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -6,7 +6,7 @@ PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py """ -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -15,6 +15,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs from ._manipulate import checkpoint_seq @@ -508,6 +509,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, 'avg') + return take_indices + def forward_features(self, x): x = self.stem(x) return self.stages(x) diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index 3c4906aa05..2fcf123ffa 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -4,7 +4,7 @@ """ from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -12,6 +12,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -349,6 +350,62 @@ def set_grad_checkpointing(self, enable=True): def no_weight_decay(self): return set() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, 'avg') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index a435533fd4..315328a21e 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -3,6 +3,7 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) """ from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -10,6 +11,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import create_classifier, ConvNormAct from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs __all__ = ['InceptionV4'] @@ -285,6 +287,66 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.features + else: + stages = self.features[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): return self.features(x) From 45c4d44a023c104ac86bf6bf113f021726cb3dcd Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Mon, 5 May 2025 23:15:39 +0800 Subject: [PATCH 08/15] fix norm at last feat_idx --- timm/models/mambaout.py | 1 - timm/models/maxxvit.py | 3 ++- timm/models/nest.py | 12 +++++++++--- timm/models/nextvit.py | 10 ++++++++-- timm/models/rdnet.py | 5 +++-- timm/models/resnetv2.py | 11 ++++++++--- 6 files changed, 30 insertions(+), 12 deletions(-) diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index 8eac6e7bf2..71d12fe672 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -479,7 +479,6 @@ def prune_intermediate_layers( self.reset_classifier(0, '') return take_indices - def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index e4375b34e5..b7d4e7e44c 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1302,7 +1302,8 @@ def forward_intermediates( if intermediates_only: return intermediates - x = self.norm(x) + if feat_idx == last_idx: + x = self.norm(x) return x, intermediates diff --git a/timm/models/nest.py b/timm/models/nest.py index 9ee504632d..8b4ce5edbd 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -449,6 +449,7 @@ def forward_intermediates( # forward pass x = self.patch_embed(x) + last_idx = self.num_blocks - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.levels else: @@ -457,13 +458,18 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + if norm and feat_idx == last_idx: + x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + intermediates.append(x_inter) + else: + intermediates.append(x) if intermediates_only: return intermediates - # Layer norm done over channel dim only (to NHWC and back) - x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if feat_idx == last_idx: + # Layer norm done over channel dim only (to NHWC and back) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x, intermediates diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 5d6ec9724d..2f232e2990 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -588,6 +588,7 @@ def forward_intermediates( # forward pass x = self.stem(x) + last_idx = len(self.stages) - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.stages else: @@ -596,12 +597,17 @@ def forward_intermediates( for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter) + else: + intermediates.append(x) if intermediates_only: return intermediates - x = self.norm(x) + if feat_idx == last_idx: + x = self.norm(x) return x, intermediates diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index b55cc33c2a..3c556e37bd 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -309,7 
 x = self.stem(x)
 if feat_idx in take_indices:
 intermediates.append(x)
-
+ last_idx = len(self.dense_stages)
 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
 dense_stages = self.dense_stages
 else:
@@ -324,7 +324,8 @@
 if intermediates_only:
 return intermediates
 
- x = self.norm_pre(x)
+ if feat_idx == last_idx:
+ x = self.norm_pre(x)
 
 return x, intermediates
 
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 1bac794c5e..1cc3b86443 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -574,7 +574,7 @@ def forward_intermediates(
 x = self.stem(x)
 if feat_idx in take_indices:
 intermediates.append(x)
-
+ last_idx = len(self.stages)
 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
 stages = self.stages
 else:
@@ -583,12 +583,17 @@ def forward_intermediates(
 for feat_idx, stage in enumerate(stages, start=1):
 x = stage(x)
 if feat_idx in take_indices:
- intermediates.append(x)
+ if feat_idx == last_idx:
+ x_inter = self.norm(x) if norm else x
+ intermediates.append(x_inter)
+ else:
+ intermediates.append(x)
 
 if intermediates_only:
 return intermediates
 
- x = self.norm(x)
+ if feat_idx == last_idx:
+ x = self.norm(x)
 
 return x, intermediates
 
From c0b1183d1a25a4d99f8edef23a7661a2a171cbe5 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Tue, 6 May 2025 00:24:57 +0800
Subject: [PATCH 09/15] support efficientvit, edgenext, davit

---
 timm/models/davit.py | 69 +++++++++++++++++-
 timm/models/edgenext.py | 69 +++++++++++++++++-
 timm/models/efficientvit_mit.py | 117 ++++++++++++++++++++++++++++++-
 timm/models/efficientvit_msra.py | 60 +++++++++++++++-
 4 files changed, 311 insertions(+), 4 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 650098880c..f538ecca84 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -12,7 +12,7 @@
 # All rights reserved.
 # This source code is licensed under the MIT license
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -23,6 +23,7 @@
 from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -636,6 +637,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
 self.num_classes = num_classes
 self.head.reset(num_classes, global_pool)
 
+ def forward_intermediates(
+ self,
+ x: torch.Tensor,
+ indices: Optional[Union[int, List[int]]] = None,
+ norm: bool = False,
+ stop_early: bool = False,
+ output_fmt: str = 'NCHW',
+ intermediates_only: bool = False,
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+ """ Forward features that returns intermediates.
+
+ Args:
+ x: Input image tensor
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
+ norm: Apply norm layer to compatible intermediates
+ stop_early: Stop iterating over blocks when last desired intermediate hit
+ output_fmt: Shape of intermediate feature outputs
+ intermediates_only: Only return intermediate features
+ Returns:
+
+ """
+ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+ intermediates = []
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+ # forward pass
+ x = self.stem(x)
+ last_idx = len(self.stages) - 1
+ if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
+ stages = self.stages
+ else:
+ stages = self.stages[:max_index + 1]
+
+ for feat_idx, stage in enumerate(stages):
+ x = stage(x)
+ if feat_idx in take_indices:
+ if norm and feat_idx == last_idx:
+ x_inter = self.norm_pre(x) # applying final norm to last intermediate
+ else:
+ x_inter = x
+ intermediates.append(x_inter)
+
+ if intermediates_only:
+ return intermediates
+
+ if feat_idx == last_idx:
+ x = self.norm_pre(x)
+
+ return x, intermediates
+
+ def prune_intermediate_layers(
+ self,
+ indices: Union[int, List[int]] = 1,
+ prune_norm: bool = False,
+ prune_head: bool = True,
+ ):
+ """ Prune layers not required for specified intermediates.
+ """
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
+ self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
+ if prune_norm:
+ self.norm_pre = nn.Identity()
+ if prune_head:
+ self.reset_classifier(0, '')
+ return take_indices
+
 def forward_features(self, x):
 x = self.stem(x)
 if self.grad_checkpointing and not torch.jit.is_scripting():
diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py
index d768b1dc33..e21be9713b 100644
--- a/timm/models/edgenext.py
+++ b/timm/models/edgenext.py
@@ -9,7 +9,7 @@
 """
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -19,6 +19,7 @@
 from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
 NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -418,6 +419,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
 self.num_classes = num_classes
 self.head.reset(num_classes, global_pool)
 
+ def forward_intermediates(
+ self,
+ x: torch.Tensor,
+ indices: Optional[Union[int, List[int]]] = None,
+ norm: bool = False,
+ stop_early: bool = False,
+ output_fmt: str = 'NCHW',
+ intermediates_only: bool = False,
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+ """ Forward features that returns intermediates.
+
+ Args:
+ x: Input image tensor
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
+ norm: Apply norm layer to compatible intermediates
+ stop_early: Stop iterating over blocks when last desired intermediate hit
+ output_fmt: Shape of intermediate feature outputs
+ intermediates_only: Only return intermediate features
+ Returns:
+
+ """
+ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+ intermediates = []
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+ # forward pass
+ x = self.stem(x)
+ last_idx = len(self.stages) - 1
+ if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
+ stages = self.stages
+ else:
+ stages = self.stages[:max_index + 1]
+
+ for feat_idx, stage in enumerate(stages):
+ x = stage(x)
+ if feat_idx in take_indices:
+ if norm and feat_idx == last_idx:
+ x_inter = self.norm_pre(x) # applying final norm to last intermediate
+ else:
+ x_inter = x
+ intermediates.append(x_inter)
+
+ if intermediates_only:
+ return intermediates
+
+ if feat_idx == last_idx:
+ x = self.norm_pre(x)
+
+ return x, intermediates
+
+ def prune_intermediate_layers(
+ self,
+ indices: Union[int, List[int]] = 1,
+ prune_norm: bool = False,
+ prune_head: bool = True,
+ ):
+ """ Prune layers not required for specified intermediates.
+ """
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
+ self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
+ if prune_norm:
+ self.norm_pre = nn.Identity()
+ if prune_head:
+ self.reset_classifier(0, '')
+ return take_indices
+
 def forward_features(self, x):
 x = self.stem(x)
 x = self.stages(x)
diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
index 34be806b1e..27872310e0 100644
--- a/timm/models/efficientvit_mit.py
+++ b/timm/models/efficientvit_mit.py
@@ -7,7 +7,7 @@
 """
 __all__ = ['EfficientVit', 'EfficientVitLarge']
 
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 from functools import partial
 
 import torch
@@ -17,6 +17,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -754,6 +755,63 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
 self.num_classes = num_classes
 self.head.reset(num_classes, global_pool)
 
+ def forward_intermediates(
+ self,
+ x: torch.Tensor,
+ indices: Optional[Union[int, List[int]]] = None,
+ norm: bool = False,
+ stop_early: bool = False,
+ output_fmt: str = 'NCHW',
+ intermediates_only: bool = False,
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+ """ Forward features that returns intermediates.
+
+ Args:
+ x: Input image tensor
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
+ norm: Apply norm layer to compatible intermediates
+ stop_early: Stop iterating over blocks when last desired intermediate hit
+ output_fmt: Shape of intermediate feature outputs
+ intermediates_only: Only return intermediate features
+ Returns:
+
+ """
+ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+ intermediates = []
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+ # forward pass
+ x = self.stem(x)
+
+ if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
+ stages = self.stages
+ else:
+ stages = self.stages[:max_index + 1]
+
+ for feat_idx, stage in enumerate(stages):
+ x = stage(x)
+ if feat_idx in take_indices:
+ intermediates.append(x)
+
+ if intermediates_only:
+ return intermediates
+
+ return x, intermediates
+
+ def prune_intermediate_layers(
+ self,
+ indices: Union[int, List[int]] = 1,
+ prune_norm: bool = False,
+ prune_head: bool = True,
+ ):
+ """ Prune layers not required for specified intermediates.
+ """
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
+ self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
+ if prune_head:
+ self.reset_classifier(0, '')
+ return take_indices
+
 def forward_features(self, x):
 x = self.stem(x)
 if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -851,6 +909,63 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
 self.num_classes = num_classes
 self.head.reset(num_classes, global_pool)
 
+ def forward_intermediates(
+ self,
+ x: torch.Tensor,
+ indices: Optional[Union[int, List[int]]] = None,
+ norm: bool = False,
+ stop_early: bool = False,
+ output_fmt: str = 'NCHW',
+ intermediates_only: bool = False,
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+ """ Forward features that returns intermediates.
+
+ Args:
+ x: Input image tensor
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
+ norm: Apply norm layer to compatible intermediates
+ stop_early: Stop iterating over blocks when last desired intermediate hit
+ output_fmt: Shape of intermediate feature outputs
+ intermediates_only: Only return intermediate features
+ Returns:
+
+ """
+ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+ intermediates = []
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+ # forward pass
+ x = self.stem(x)
+
+ if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
+ stages = self.stages
+ else:
+ stages = self.stages[:max_index + 1]
+
+ for feat_idx, stage in enumerate(stages):
+ x = stage(x)
+ if feat_idx in take_indices:
+ intermediates.append(x)
+
+ if intermediates_only:
+ return intermediates
+
+ return x, intermediates
+
+ def prune_intermediate_layers(
+ self,
+ indices: Union[int, List[int]] = 1,
+ prune_norm: bool = False,
+ prune_head: bool = True,
+ ):
+ """ Prune layers not required for specified intermediates.
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 7e5c09a475..91caaa5a4d 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -9,7 +9,7 @@ __all__ = ['EfficientVitMsra'] import itertools from collections import OrderedDict -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -17,6 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -475,6 +476,63 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head = NormLinear( self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): From d25805960e356af460534020b2a822f15c354f60 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Tue, 6 May 2025 00:25:21 +0800 Subject: [PATCH 10/15] support efficientformer_v2 --- timm/models/efficientformer_v2.py | 70 ++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index dcf6499537..5bdc473fc1 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -16,7 +16,7 @@ """ import math from functools import partial -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,6 +25,7 @@ from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -625,6 +626,73 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): def set_distilled_training(self, enable=True): self.distilled_training = enable + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + + last_idx = len(self.stages) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter) + else: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) From 39a6c3027c5511bad9dc11ec80ff5749c3250853 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Tue, 6 May 2025 00:25:53 +0800 Subject: [PATCH 11/15] update FEAT_INTER_FILTERS list --- tests/test_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 3ba3615db4..ba0a1fc82e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -53,7 +53,10 @@ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', - 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*' + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', + 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', + 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', + 'davit', ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. From 748821e9f9c96a3f058a86bfa151a6e0c5058f79 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Tue, 6 May 2025 00:56:36 +0800 Subject: [PATCH 12/15] fix final norm only apply at last indice --- timm/models/rdnet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index 3c556e37bd..246030af12 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -318,8 +318,11 @@ def forward_intermediates( feat_idx += 1 x = stage(x) if feat_idx in take_indices: - # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled - intermediates.append(x) + if norm and feat_idx == last_idx: + x_inter = self.norm_pre(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) if intermediates_only: return intermediates From 2af810f956df61f67dec84a88d6370414cb97a50 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Tue, 6 May 2025 03:43:49 +0800 Subject: [PATCH 13/15] fix nest type error --- timm/models/nest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/nest.py b/timm/models/nest.py index 8b4ce5edbd..9a423a9776 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -449,7 +449,7 @@ def forward_intermediates( # forward pass x = self.patch_embed(x) - last_idx = self.num_blocks - 1 + last_idx = len(self.num_blocks) - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.levels else: From 941ec01f9706e67cd1f36a87581ea70130ab8159 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 8 May 2025 00:57:18 +0800 Subject: [PATCH 14/15] update some model --- timm/models/convnext.py | 20 ++++++------ timm/models/focalnet.py | 2 +- 
 timm/models/mvitv2.py | 10 +++---
 timm/models/pit.py | 68 ++++++++++++++++++++++++++++++++++++++++-
 timm/models/rdnet.py | 22 +++++++------
 timm/models/resnetv2.py | 8 +++--
 timm/models/xcit.py | 3 +-
 7 files changed, 104 insertions(+), 29 deletions(-)

diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index 47f2bf87c1..2f445118dd 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -452,29 +452,29 @@ def forward_intermediates(
 """
 assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
 intermediates = []
- take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+ take_indices, max_index = feature_take_indices(len(self.stages), indices)
 
 # forward pass
- feat_idx = 0 # stem is index 0
 x = self.stem(x)
- if feat_idx in take_indices:
- intermediates.append(x)
+ last_idx = len(self.stages) - 1
 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
 stages = self.stages
 else:
- stages = self.stages[:max_index]
- for stage in stages:
- feat_idx += 1
+ stages = self.stages[:max_index + 1]
+ for feat_idx, stage in enumerate(stages):
 x = stage(x)
 if feat_idx in take_indices:
- # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
- intermediates.append(x)
+ if norm and feat_idx == last_idx:
+ intermediates.append(self.norm_pre(x))
+ else:
+ intermediates.append(x)
 
 if intermediates_only:
 return intermediates
 
- x = self.norm_pre(x)
+ if feat_idx == last_idx:
+ x = self.norm_pre(x)
 
 return x, intermediates
 
diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py
index 51ab4d0803..ec7cd1cff1 100644
--- a/timm/models/focalnet.py
+++ b/timm/models/focalnet.py
@@ -491,7 +491,7 @@ def forward_intermediates(
 else:
 stages = self.layers[:max_index + 1]
 
- last_idx = len(self.layers)
+ last_idx = len(self.layers) - 1
 for feat_idx, stage in enumerate(stages):
 x = stage(x)
 if feat_idx in take_indices:
diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py
index f790fd0d13..c048a07277 100644
--- a/timm/models/mvitv2.py
+++ b/timm/models/mvitv2.py
@@ -870,10 +870,11 @@ def forward_intermediates(
 if self.pos_embed is not None:
 x = x + self.pos_embed
 
- for i, stage in enumerate(self.stages):
+ last_idx = len(self.stages) - 1
+ for feat_idx, stage in enumerate(self.stages):
 x, feat_size = stage(x, feat_size)
- if i in take_indices:
- if norm and i == (len(self.stages) - 1):
+ if feat_idx in take_indices:
+ if norm and feat_idx == last_idx:
 x_inter = self.norm(x) # applying final norm last intermediate
 else:
 x_inter = x
@@ -887,7 +888,8 @@ def forward_intermediates(
 if intermediates_only:
 return intermediates
 
- x = self.norm(x)
+ if feat_idx == last_idx:
+ x = self.norm(x)
 
 return x, intermediates
 
diff --git a/timm/models/pit.py b/timm/models/pit.py
index 3a1090b89f..109cfaf815 100644
--- a/timm/models/pit.py
+++ b/timm/models/pit.py
@@ -14,7 +14,7 @@
 import math
 import re
 from functools import partial
-from typing import Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
@@ -22,6 +22,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, to_2tuple
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Block
 
@@ -254,6 +255,71 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
 if self.head_dist is not None:
 self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
+ def forward_intermediates(
+ self,
+ x: torch.Tensor,
+ indices: Optional[Union[int, List[int]]] = None,
+ norm: bool = False,
+ stop_early: bool = False,
+ output_fmt: str = 'NCHW',
+ intermediates_only: bool = False,
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+ """ Forward features that returns intermediates.
+
+ Args:
+ x: Input image tensor
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
+ norm: Apply norm layer to compatible intermediates
+ stop_early: Stop iterating over blocks when last desired intermediate hit
+ output_fmt: Shape of intermediate feature outputs
+ intermediates_only: Only return intermediate features
+ Returns:
+
+ """
+ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+ intermediates = []
+ take_indices, max_index = feature_take_indices(len(self.transformers), indices)
+
+ # forward pass
+ x = self.patch_embed(x)
+ x = self.pos_drop(x + self.pos_embed)
+ cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+
+ last_idx = len(self.transformers) - 1
+ if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
+ stages = self.transformers
+ else:
+ stages = self.transformers[:max_index + 1]
+
+ for feat_idx, stage in enumerate(stages):
+ x, cls_tokens = stage((x, cls_tokens))
+ if feat_idx in take_indices:
+ intermediates.append(x)
+
+ if intermediates_only:
+ return intermediates
+
+ if feat_idx == last_idx:
+ cls_tokens = self.norm(cls_tokens)
+
+ return cls_tokens, intermediates
+
+ def prune_intermediate_layers(
+ self,
+ indices: Union[int, List[int]] = 1,
+ prune_norm: bool = False,
+ prune_head: bool = True,
+ ):
+ """ Prune layers not required for specified intermediates.
+ """
+ take_indices, max_index = feature_take_indices(len(self.transformers), indices)
+ self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
+ if prune_norm:
+ self.norm = nn.Identity()
+ if prune_head:
+ self.reset_classifier(0, '')
+ return take_indices
+
 def forward_features(self, x):
 x = self.patch_embed(x)
 x = self.pos_drop(x + self.pos_embed)
diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py
index 246030af12..a3a205fff6 100644
--- a/timm/models/rdnet.py
+++ b/timm/models/rdnet.py
@@ -302,20 +302,20 @@ def forward_intermediates(
 """
 assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
 intermediates = []
- take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
+ stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
+ take_indices, max_index = feature_take_indices(len(stage_ends), indices)
+ take_indices = [stage_ends[i] for i in take_indices]
+ max_index = stage_ends[max_index]
 
 # forward pass
- feat_idx = 0 # stem is index 0
 x = self.stem(x)
- if feat_idx in take_indices:
- intermediates.append(x)
- last_idx = len(self.dense_stages)
+
+ last_idx = len(self.dense_stages) - 1
 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
 dense_stages = self.dense_stages
 else:
- dense_stages = self.dense_stages[:max_index]
- for stage in dense_stages:
- feat_idx += 1
+ dense_stages = self.dense_stages[:max_index + 1]
+ for feat_idx, stage in enumerate(dense_stages):
 x = stage(x)
 if feat_idx in take_indices:
 if norm and feat_idx == last_idx:
@@ -340,8 +340,10 @@ def prune_intermediate_layers(
 ):
 """ Prune layers not required for specified intermediates.
""" - take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices) - self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0 + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_norm: self.norm_pre = nn.Identity() if prune_head: diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 1cc3b86443..5cc164ae1b 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -571,9 +571,13 @@ def forward_intermediates( # forward pass feat_idx = 0 - x = self.stem(x) + H, W = x.shape[-2:] + for stem in self.stem: + x = stem(x) + if x.shape[-2:] == (H //2, W //2): + x_down = x if feat_idx in take_indices: - intermediates.append(x) + intermediates.append(x_down) last_idx = len(self.stages) if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.stages diff --git a/timm/models/xcit.py b/timm/models/xcit.py index e6cf87b789..250749f1cf 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -494,7 +494,8 @@ def forward_intermediates( # NOTE not supporting return of class tokens x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) for blk in self.cls_attn_blocks: - x = blk(x) + x = blk(x) + x = self.norm(x) return x, intermediates From 20c016e27dca24b9a954611e468877b4430f5772 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 8 May 2025 02:02:54 +0800 Subject: [PATCH 15/15] fix pit & add to test --- tests/test_models.py | 2 +- timm/models/convnext.py | 4 ++-- timm/models/pit.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index bb2a92ed91..aa866ccdb6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -56,7 +56,7 @@ 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', - 'davit', + 'davit', 'rdnet', 'convnext', 'pit' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 2f445118dd..e2eb48d37f 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -486,8 +486,8 @@ def prune_intermediate_layers( ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) - self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_norm: self.norm_pre = nn.Identity() if prune_head: diff --git a/timm/models/pit.py b/timm/models/pit.py index 109cfaf815..1d5386a921 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -313,7 +313,7 @@ def prune_intermediate_layers( """ Prune layers not required for specified intermediates. 
""" take_indices, max_index = feature_take_indices(len(self.transformers), indices) - self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_norm: self.norm = nn.Identity() if prune_head: @@ -380,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs): variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices), + feature_cfg=dict(feature_cls='hook', out_indices=out_indices), **kwargs, ) return model