5 changes: 4 additions & 1 deletion tests/test_models.py
@@ -53,7 +53,10 @@
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*',
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
'davit', 'rdnet', 'convnext', 'pit'
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
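For reference, a minimal sketch of the kind of check these names now opt into: calling the new forward_intermediates() API on one of the listed model families and verifying that NCHW feature maps come back. The function below is illustrative only; the real test name and body are not part of this diff.

```python
# Illustrative only: exercises forward_intermediates() the way the feature-intermediates
# tests presumably do for the newly listed model families.
import timm
import torch


def check_forward_intermediates(model_name: str = 'convnext_atto') -> None:
    model = timm.create_model(model_name, pretrained=False, num_classes=0).eval()
    x = torch.randn(1, 3, 224, 224)
    # intermediates_only=True returns just the per-stage feature maps
    feats = model.forward_intermediates(x, intermediates_only=True)
    assert all(f.ndim == 4 for f in feats)  # NCHW feature maps, one per stage
```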
24 changes: 12 additions & 12 deletions timm/models/convnext.py
@@ -452,29 +452,29 @@ def forward_intermediates(
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
take_indices, max_index = feature_take_indices(len(self.stages), indices)

# forward pass
feat_idx = 0 # stem is index 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)

last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for stage in stages:
feat_idx += 1
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
intermediates.append(x)
if norm and feat_idx == last_idx:
intermediates.append(self.norm_pre(x))
else:
intermediates.append(x)

if intermediates_only:
return intermediates

x = self.norm_pre(x)
if feat_idx == last_idx:
x = self.norm_pre(x)

return x, intermediates
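The change above drops the stem from the index space: indices passed to forward_intermediates() now refer to stages only, and the last-stage norm is handled inside the loop. A small sketch of the stage-only indexing, assuming feature_take_indices() keeps its current behaviour (the commented values are my reading of the helper, not output captured from this PR):

```python
# Sketch of the stage-only indexing assumed by the updated forward_intermediates().
from timm.models._features import feature_take_indices

num_stages = 4
take, max_idx = feature_take_indices(num_stages, 2)       # int n -> last n stages
# take == [2, 3], max_idx == 3
take, max_idx = feature_take_indices(num_stages, [0, 2])  # sequence -> those stage indices
# take == [0, 2], max_idx == 2; the stem is no longer index 0
```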

@@ -486,8 +486,8 @@ def prune_intermediate_layers(
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate stages not needed for requested intermediates
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
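A usage sketch of the pruning path after this change: truncate a ConvNeXt to the stages needed for two early intermediates, then query just those. The model name and shapes are illustrative, and I'm assuming prune_intermediate_layers() returns the resolved take indices, as it does in the other models touched by this PR.

```python
# Illustrative: prune a classifier down to a two-level feature extractor.
import timm
import torch

model = timm.create_model('convnext_tiny', pretrained=False)
kept = model.prune_intermediate_layers(indices=[0, 1], prune_norm=True, prune_head=True)
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=kept, intermediates_only=True)
# two NCHW maps from the first two stages (strides 4 and 8 for ConvNeXt-T)
print([tuple(f.shape) for f in feats])
```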
69 changes: 68 additions & 1 deletion timm/models/davit.py
@@ -12,7 +12,7 @@
# All rights reserved.
# This source code is licensed under the MIT license
from functools import partial
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -23,6 +23,7 @@
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@@ -636,6 +637,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.

Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
List of intermediate feature tensors if intermediates_only is True, otherwise a tuple of the final feature tensor and the list of intermediates.
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)

# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]

for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)

if intermediates_only:
return intermediates

if feat_idx == last_idx:
x = self.norm_pre(x)

return x, intermediates

def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate stages not needed for requested intermediates
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices

def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
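A usage sketch for the DaViT path added above, assuming a standard davit_tiny build; with norm=True the final norm is applied only to the last requested intermediate (and to the returned final features), per the code in this hunk.

```python
# Illustrative: tuple-return form, final features plus the last three stage outputs.
import timm
import torch

model = timm.create_model('davit_tiny', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
final, inters = model.forward_intermediates(x, indices=3, norm=True)
print(tuple(final.shape))                 # normed output of the last stage
print([tuple(t.shape) for t in inters])   # last three stage feature maps
```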
69 changes: 68 additions & 1 deletion timm/models/edgenext.py
@@ -9,7 +9,7 @@
"""
import math
from functools import partial
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -19,6 +19,7 @@
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -418,6 +419,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.

Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
List of intermediate feature tensors if intermediates_only is True, otherwise a tuple of the final feature tensor and the list of intermediates.
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)

# forward pass
x = self.stem(x)
last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]

for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)

if intermediates_only:
return intermediates

if feat_idx == last_idx:
x = self.norm_pre(x)

return x, intermediates

def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate stages not needed for requested intermediates
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices

def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
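A sketch of the stop_early behaviour in the EdgeNeXt version: outside of torchscript, the stage loop is truncated at the deepest requested index, so asking only for early features skips the later stages entirely. The model name and input size below are illustrative.

```python
# Illustrative: only the first stage is run when stop_early=True and indices=[0].
import timm
import torch

model = timm.create_model('edgenext_small', pretrained=False).eval()
x = torch.randn(1, 3, 256, 256)
feats = model.forward_intermediates(x, indices=[0], stop_early=True, intermediates_only=True)
print(tuple(feats[0].shape))  # single early feature map; deeper stages never execute
```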
70 changes: 69 additions & 1 deletion timm/models/efficientformer_v2.py
@@ -16,7 +16,7 @@
"""
import math
from functools import partial
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -25,6 +25,7 @@
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model

@@ -625,6 +626,73 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
def set_distilled_training(self, enable=True):
self.distilled_training = enable

def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.

Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
List of intermediate feature tensors if intermediates_only is True, otherwise a tuple of the final feature tensor and the list of intermediates.
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)

# forward pass
x = self.stem(x)

last_idx = len(self.stages) - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]

for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter)
else:
intermediates.append(x)

if intermediates_only:
return intermediates

if feat_idx == last_idx:
x = self.norm(x)

return x, intermediates

def prune_intermediate_layers(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate stages not needed for requested intermediates
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices

def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
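Finally, a sketch of pulling a multi-scale pyramid from the EfficientFormerV2 version, which applies self.norm (rather than a norm_pre) to the last stage. The variant name and the stride comment are assumptions based on typical timm configs, not taken from this diff.

```python
# Illustrative: three-level feature pyramid from an EfficientFormerV2 backbone.
import timm
import torch

backbone = timm.create_model('efficientformerv2_s0', pretrained=False, num_classes=0).eval()
x = torch.randn(1, 3, 224, 224)
pyramid = backbone.forward_intermediates(x, indices=[1, 2, 3], norm=True, intermediates_only=True)
for name, feat in zip(('C3', 'C4', 'C5'), pyramid):
    print(name, tuple(feat.shape))  # three maps, typically at strides 8, 16, 32
```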