From e83e2516f83bb49dbb7b2afc232caf68a62f4c7e Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Mon, 5 May 2025 15:36:12 +0000
Subject: [PATCH 1/9] MobileNetV5

Adds a TSV describing the full params from an Orbax checkpoint
Switching to RMS Norm
Fixes for correct export
Tinkering with 'mobilenetv5' details, fixing some issues with msfa
A few tweaks and comments to example MNV5 impl
Update RmsNorm2d modules to use own 2d eager kernel instead of torch rms_norm w/ permute
Fix propagation of act_layer to RmsNormAct*, use ConvNormAct for stem instead of just Conv2d
Fixes from weights conversion
Plumbing norm_layer through to MultiQueryAttention2d
Impl forward_features for Transformers compatibility
Adding forward_* APIs to MobileNetV5Encoder
Cleanup
Cleanup, model entrypoint rename
Large variant is redundant with 300m
Update input size for configs
Fix stem conv layer name
Fix: always norm in MSFA
Always call final MSFA norm layer
Remove some FIXME, fix MSFA docstring. Remove use_layer_scale and rely on values == None, not used currently in any case.
---
 timm/layers/create_norm_act.py      |  18 +-
 timm/layers/fast_norm.py            |  48 +-
 timm/layers/norm.py                 |  17 +-
 timm/layers/norm_act.py             |  74 ++-
 timm/models/__init__.py             |   2 +
 timm/models/_efficientnet_blocks.py |   5 +-
 timm/models/mobilenetv5.py          | 811 ++++++++++++++++++++++++++++
 7 files changed, 963 insertions(+), 12 deletions(-)
 create mode 100644 timm/models/mobilenetv5.py

diff --git a/timm/layers/create_norm_act.py b/timm/layers/create_norm_act.py
index 1f9f180590..e7805e7993 100644
--- a/timm/layers/create_norm_act.py
+++ b/timm/layers/create_norm_act.py
@@ -11,7 +11,7 @@
 
 from .evo_norm import *
 from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
-from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
+from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, RmsNormAct, RmsNormAct2d
 from .inplace_abn import InplaceAbn
 
 _NORM_ACT_MAP = dict(
@@ -34,11 +34,21 @@
     frntlu=FilterResponseNormTlu2d,
     inplaceabn=InplaceAbn,
     iabn=InplaceAbn,
+    rmsnorm=RmsNormAct,
+    rmsnorm2d=RmsNormAct2d,
 )
 _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
 # has act_layer arg to define act type
 _NORM_ACT_REQUIRES_ARG = {
-    BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
+    BatchNormAct2d,
+    GroupNormAct,
+    LayerNormAct,
+    LayerNormAct2d,
+    FilterResponseNormAct2d,
+    InplaceAbn,
+    RmsNormAct,
+    RmsNormAct2d,
+}
 
 
 def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
@@ -83,6 +93,10 @@ def get_norm_act_layer(norm_layer, act_layer=None):
             norm_act_layer = LayerNormAct2d
         elif type_name.startswith('layernorm'):
            norm_act_layer = LayerNormAct
+        elif type_name.startswith('rmsnorm2d'):
+            norm_act_layer = RmsNormAct2d
+        elif type_name.startswith('rmsnorm'):
+            norm_act_layer = RmsNormAct
         else:
             assert False, f"No equivalent norm_act layer for {type_name}"

diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py
index e7cbbb9495..ec8c66814e 100644
--- a/timm/layers/fast_norm.py
+++ b/timm/layers/fast_norm.py
@@ -148,7 +148,7 @@ def fast_rms_norm(
         return fused_rms_norm_affine(x, weight, normalized_shape, eps)
 
     if is_autocast_enabled(x.device.type):
-        # normally native AMP casts LN inputs to float32
+        # normally native AMP casts LN inputs to float32 and leaves the output as float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
         x, weight = x.to(dt), weight.to(dt)
@@ -162,6 +162,52 @@
     return x
 
 
+def rms_norm2d(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+):
+    assert len(normalized_shape) == 1
+    v = x.pow(2)
+    v = torch.mean(v, dim=1, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight.reshape(1, -1, 1, 1)
+    return x
+
+
+def fast_rms_norm2d(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # this must be by itself, cannot merge with has_apex_rmsnorm
+        return rms_norm2d(x, normalized_shape, weight, eps)
+
+    if has_apex_rmsnorm:
+        x = x.permute(0, 2, 3, 1)
+        if weight is None:
+            x = fused_rms_norm(x, normalized_shape, eps)
+        else:
+            x = fused_rms_norm_affine(x, weight, normalized_shape, eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
+
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts norm inputs to float32 and leaves the output as float32
+        # apex does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        x = rms_norm2d(x, normalized_shape, weight, eps)
+
+    return x
+
+
 def simple_norm(
     x: torch.Tensor,
     normalized_shape: List[int],
diff --git a/timm/layers/norm.py b/timm/layers/norm.py
index f718750dc5..3a6e87524e 100644
--- a/timm/layers/norm.py
+++ b/timm/layers/norm.py
@@ -11,7 +11,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
+from .fast_norm import (
+    is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d,
+    fast_simple_norm, simple_norm
+)
 
 try:
     from torch.nn.functional import rms_norm
@@ -173,7 +176,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class RmsNorm2d(nn.Module):
-    """ RmsNorm w/ fast (apex) norm if available
+    """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available
+
+    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
+    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
+    like https://github.com/pytorch/pytorch/pull/150576 lands.
     """
     __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
@@ -205,14 +212,12 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x.permute(0, 2, 3, 1)
+        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
+        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
         if self._fast_norm:
-            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+            x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
         else:
-            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
-        x = x.permute(0, 3, 1, 2)
+            x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
         return x

diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py
index f211743770..b1e1ba98b5 100644
--- a/timm/layers/norm_act.py
+++ b/timm/layers/norm_act.py
@@ -20,9 +20,15 @@
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 from .create_act import create_act_layer
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d
+from .norm import RmsNorm, RmsNorm2d
 from .trace_utils import _assert
 
+try:
+    from torch.nn.functional import rms_norm
+except ImportError:
+    from .fast_norm import rms_norm
+
 
 def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
     act_kwargs = act_kwargs or {}
@@ -460,3 +466,69 @@ def forward(self, x):
         x = self.drop(x)
         x = self.act(x)
         return x
+
+
+class RmsNormAct(RmsNorm):
+    """ RMSNorm + Activation, normalizes over the last (channel) dimension
+
+    NOTE: Unlike RmsNormAct2d below, this variant reduces over the trailing dim, so it can use
+    PyTorch F.rms_norm (or the fused apex path) directly without any permute; the eager 2d
+    kernel notes on RmsNormAct2d apply only to the NCHW case.
+    """
+    def __init__(
+            self,
+            num_channels,
+            eps=1e-5,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            act_kwargs=None,
+            inplace=True,
+            drop_layer=None,
+    ):
+        super().__init__(channels=num_channels, eps=eps, affine=affine)
+        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+        self._fast_norm = is_fast_norm()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        x = self.drop(x)
+        x = self.act(x)
+        return x
+
+
+class RmsNormAct2d(RmsNorm2d):
+    """ RMSNorm + Activation for '2D' NCHW tensors
+
+    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
+    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
+    like https://github.com/pytorch/pytorch/pull/150576 lands.
+ """ + def __init__( + self, + num_channels, + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + act_kwargs=None, + inplace=True, + drop_layer=None, + ): + super().__init__(channels=num_channels, eps=eps, affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) + self._fast_norm = is_fast_norm() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._fast_norm: + x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps) + else: + x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps) + x = self.drop(x) + x = self.act(x) + return x diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 884df366a2..dc8cfa0359 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -40,6 +40,7 @@ from .metaformer import * from .mlp_mixer import * from .mobilenetv3 import * +from .mobilenetv5 import * from .mobilevit import * from .mvitv2 import * from .naflexvit import * @@ -129,6 +130,7 @@ load_model_config_from_hf as load_model_config_from_hf, load_state_dict_from_hf as load_state_dict_from_hf, push_to_hf_hub as push_to_hf_hub, + save_for_hf as save_for_hf, ) from ._manipulate import ( model_parameters as model_parameters, diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index 6ac2f8cd6d..ab5864a455 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -517,12 +517,13 @@ def __init__( value_dim=value_dim, query_strides=query_strides, kv_stride=kv_stride, + dw_kernel_size=dw_kernel_size, dilation=dilation, padding=pad_type, - dw_kernel_size=dw_kernel_size, attn_drop=attn_drop, proj_drop=proj_drop, - #bias=use_bias, # why not here if used w/ mhsa? + norm_layer=norm_layer, + # use_bias=use_bias, # why not here if used w/ mhsa? ) else: self.attn = Attention2d( diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py new file mode 100644 index 0000000000..ca5e094034 --- /dev/null +++ b/timm/models/mobilenetv5.py @@ -0,0 +1,811 @@ +from functools import partial +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import ( + SelectAdaptivePool2d, Linear, LayerType, PadType, RmsNorm2d, ConvNormAct, create_conv2d, get_norm_act_layer, + to_2tuple +) +from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._efficientnet_blocks import SqueezeExcite, UniversalInvertedResidual +from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from ._features import FeatureInfo, FeatureHooks, feature_take_indices +from ._manipulate import checkpoint_seq, checkpoint +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['MobileNetV5', 'MobileNetV5Encoder'] + + +class MobileNetV5MultiScaleFusionAdapter(nn.Module): + """Multi-layer fusion token adapter. + + Args: + in_chs: List of input channel counts for each feature scale. + out_chs: The number of output channels. + output_resolution: The output resolution. + expansion_ratio: The FFN expansion ratio. + interpolation_mode: The upsampling interpolation mode. 
+        layer_scale_init_value: The initial value of the layer scale, no layer scale if None.
+    """
+
+    def __init__(
+            self,
+            in_chs: Union[int, List[int]],
+            out_chs: int,
+            output_resolution: int,
+            expansion_ratio: float = 2.0,
+            interpolation_mode: str = "nearest",
+            layer_scale_init_value: Optional[float] = None,
+            noskip: bool = True,
+            act_layer: Optional[LayerType] = None,
+            norm_layer: Optional[LayerType] = None,
+    ):
+        super().__init__()
+        self.in_channels = sum(in_chs) if isinstance(in_chs, Sequence) else in_chs
+        self.out_channels = out_chs
+        self.output_resolution = to_2tuple(output_resolution)
+        self.expansion_ratio = expansion_ratio
+        self.interpolation_mode = interpolation_mode
+        self.layer_scale_init_value = layer_scale_init_value
+        self.noskip = noskip
+
+        act_layer = act_layer or nn.GELU
+        norm_layer = norm_layer or RmsNorm2d
+        self.ffn = UniversalInvertedResidual(
+            in_chs=self.in_channels,
+            out_chs=self.out_channels,
+            dw_kernel_size_mid=0,
+            exp_ratio=self.expansion_ratio,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            noskip=self.noskip,
+            layer_scale_init_value=self.layer_scale_init_value,
+        )
+
+        self.norm = norm_layer(self.out_channels)
+
+    def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor:
+        # Inputs list of [B, C, H, W] tensors
+        high_resolution = inputs[0].shape[-2:]  # Assuming the first input is the highest resolution.
+        resized_inputs = []
+        for _, img in enumerate(inputs):
+            if any([r < hr for r, hr in zip(img.shape[-2:], high_resolution)]):
+                img = F.interpolate(img, size=high_resolution, mode=self.interpolation_mode)
+            resized_inputs.append(img)
+
+        channel_cat_imgs = torch.cat(resized_inputs, dim=1)  # Cat on channel dim, must equal self.in_channels
+        img = self.ffn(channel_cat_imgs)
+
+        if any([ro != rh for ro, rh in zip(high_resolution, self.output_resolution)]):
+            # Interpolate / pool to target output_resolution if highest feature resolution differs
+            if (
+                    high_resolution[0] % self.output_resolution[0] != 0 or
+                    high_resolution[1] % self.output_resolution[1] != 0
+            ):
+                img = F.interpolate(img, size=self.output_resolution, mode="bilinear")
+            else:
+                h_strides = high_resolution[0] // self.output_resolution[0]
+                w_strides = high_resolution[1] // self.output_resolution[1]
+                img = F.avg_pool2d(
+                    img,
+                    kernel_size=(h_strides, w_strides),
+                    stride=(h_strides, w_strides),
+                )
+
+        img = self.norm(img)
+
+        return img
+
+
+class MobileNetV5(nn.Module):
+    """ MobileNet-V5
+    """
+
+    def __init__(
+            self,
+            block_args: BlockArgs,
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            stem_size: int = 16,
+            fix_stem: bool = False,
+            num_features: int = 2048,
+            pad_type: str = '',
+            use_msfa: bool = True,
+            msfa_indices: Sequence[int] = (-3, -2, -1),
+            msfa_output_resolution: int = 16,
+            act_layer: Optional[LayerType] = None,
+            norm_layer: Optional[LayerType] = None,
+            aa_layer: Optional[LayerType] = None,
+            se_layer: Optional[LayerType] = None,
+            se_from_exp: bool = True,
+            round_chs_fn: Callable = round_channels,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            layer_scale_init_value: Optional[float] = None,
+            global_pool: str = 'avg',
+    ):
+        """
+        Args:
+            block_args: Arguments for blocks of the network.
+            num_classes: Number of classes for classification head.
+            in_chans: Number of input image channels.
+            stem_size: Number of output channels of the initial stem convolution.
+            fix_stem: If True, don't scale stem by round_chs_fn.
+            num_features: Number of output channels of the conv head layer.
+            use_msfa: If True, aggregate intermediate features with the multi-scale fusion adapter (MSFA).
+            pad_type: Type of padding to use for convolution layers.
+            act_layer: Type of activation layer.
+            norm_layer: Type of normalization layer.
+            aa_layer: Type of anti-aliasing layer.
+            se_layer: Type of Squeeze-and-Excite layer.
+            se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
+            round_chs_fn: Callable to round number of filters based on depth multiplier.
+            drop_rate: Dropout rate.
+            drop_path_rate: Stochastic depth rate.
+            layer_scale_init_value: Enable layer scale on compatible blocks if not None.
+            global_pool: Type of pooling to use for global pooling features of the FC head.
+        """
+        super().__init__()
+        act_layer = act_layer or nn.GELU
+        norm_layer = norm_layer or RmsNorm2d
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+        se_layer = se_layer or SqueezeExcite
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        self.grad_checkpointing = False
+        self.msfa_indices = msfa_indices
+        self.msfa_output_resolution = msfa_output_resolution
+
+        # Stem
+        if not fix_stem:
+            stem_size = round_chs_fn(stem_size)
+        self.conv_stem = ConvNormAct(
+            in_chans,
+            stem_size,
+            kernel_size=3,
+            stride=2,
+            padding=pad_type,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+        )
+
+        # Middle stages (IR/ER/DS Blocks)
+        builder = EfficientNetBuilder(
+            output_stride=32,
+            pad_type=pad_type,
+            round_chs_fn=round_chs_fn,
+            se_from_exp=se_from_exp,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            aa_layer=aa_layer,
+            se_layer=se_layer,
+            drop_path_rate=drop_path_rate,
+            layer_scale_init_value=layer_scale_init_value,
+        )
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
+        self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
+        self.num_features = builder.in_chs  # features of last stage, output of forward_features()
+
+        # Neck (aggregation) + Head + Pooling
+        if use_msfa:
+            self.num_features = self.head_hidden_size = num_features  # output of msfa is output of forward_features()
+            # Map msfa indices to feature info and calculate sum of feature channels
+            self.msfa_indices = feature_take_indices(len(self.feature_info), self.msfa_indices)[0]
+            self.msfa_in_chs = sum([self.feature_info[mi]['num_chs'] for mi in self.msfa_indices])
+
+            self.msfa = MobileNetV5MultiScaleFusionAdapter(
+                in_chs=self.msfa_in_chs,
+                out_chs=num_features,
+                output_resolution=self.msfa_output_resolution,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+            )
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.conv_head = None
+            self.norm_head = None
+        else:
+            self.num_features = builder.in_chs  # features of last stage, output of forward_features()
+            self.head_hidden_size = num_features
+            self.msfa = None
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            num_pooled_chs = self.num_features * self.global_pool.feat_mult()
+            # mobilenet-v4 style post-pooling PW conv is followed by a norm+act layer
+            self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type)
+            self.norm_head = norm_act_layer(self.head_hidden_size)
+
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
+        self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+
+        efficientnet_init_weights(self)
+
+    def as_sequential(self):
+        layers = [self.conv_stem]
+        layers.extend(self.blocks)
+        layers.append(self.global_pool)
+        if self.conv_head is not None:
+            layers.append(self.conv_head)
+        if self.norm_head is not None:
+            layers.append(self.norm_head)
+        layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
+        return nn.Sequential(*layers)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False):
+        return dict(
+            stem=r'^conv_stem',
+            blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self) -> nn.Module:
+        return self.classifier
+
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
+        self.num_classes = num_classes
+        # NOTE: cannot meaningfully change pooling of efficient head after creation
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
+        self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+
+        # FIXME MSFA and forward_intermediates overlap, they both take indices from specific features
+        # When a user wants to grab specific feature maps for a downstream task AND have the msfa output
+        # what should we do? Accumulate two intermediates? One for msfa and one for take_indices?
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        # FIXME see note above
+        # self.msfa(msfa_intermediates)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        if max_index < len(self.blocks):
+            self.conv_head = None
+            self.norm_head = None
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.conv_head = None
+            self.norm_head = None
+            self.reset_classifier(0, '')
+        return take_indices
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        if self.msfa is not None:
+            # When the MSFA aggregation layer is present, we gather intermediates as in forward_intermediates
+            feat_idx = 0  # offset by one from blocks index due to stem feature
+            intermediates = []
+            x = self.conv_stem(x)
+            if feat_idx in self.msfa_indices:
+                intermediates.append(x)
+            for blk in self.blocks:
+                feat_idx += 1
+                # FIXME fix grad checkpointing
+                x = blk(x)
+                if feat_idx in self.msfa_indices:
+                    intermediates.append(x)
+            x = self.msfa(intermediates)
+        else:
+            x = self.conv_stem(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint_seq(self.blocks, x, flatten=True)
+            else:
+                x = self.blocks(x)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.global_pool(x)
+        if self.conv_head is not None:
+            x = self.conv_head(x)
+        if self.norm_head is not None:
+            x = self.norm_head(x)
+        x = self.flatten(x)
+        if self.drop_rate > 0.:
+            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        if pre_logits:
+            return x
+        return self.classifier(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+class MobileNetV5Encoder(nn.Module):
+    """MobileNetV5 Vision Encoder"""
+
+    def __init__(
+            self,
+            block_args: BlockArgs,
+            in_chans: int = 3,
+            stem_size: int = 64,
+            fix_stem: bool = False,
+            pad_type: str = '',
+            msfa_indices: Sequence[int] = (-2, -1),
+            msfa_output_resolution: int = 16,
+            act_layer: Optional[LayerType] = None,
+            norm_layer: Optional[LayerType] = None,
+            aa_layer: Optional[LayerType] = None,
+            se_layer: Optional[LayerType] = None,
+            se_from_exp: bool = True,
+            round_chs_fn: Callable = round_channels,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            layer_scale_init_value: Optional[float] = None,
+    ):
+        super().__init__()
+        act_layer = act_layer or nn.GELU
+        norm_layer = norm_layer or RmsNorm2d
+        se_layer = se_layer or SqueezeExcite
+        self.num_classes = 0  # Exists to satisfy ._hub module APIs.
+        self.drop_rate = drop_rate
+        self.grad_checkpointing = False
+
+        # Stem
+        if not fix_stem:
+            stem_size = round_chs_fn(stem_size)
+        self.conv_stem = ConvNormAct(
+            in_chans,
+            stem_size,
+            kernel_size=3,
+            stride=2,
+            padding=pad_type,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+        )
+
+        builder = EfficientNetBuilder(
+            output_stride=32,
+            pad_type=pad_type,
+            round_chs_fn=round_chs_fn,
+            se_from_exp=se_from_exp,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            aa_layer=aa_layer,
+            se_layer=se_layer,
+            drop_path_rate=drop_path_rate,
+            layer_scale_init_value=layer_scale_init_value,
+        )
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
+        self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
+
+        self.num_features = self.head_hidden_size = 2048  # output of msfa is output of forward_features()
+        # Map msfa indices to feature info and calculate sum of feature channels
+        self.msfa_indices = feature_take_indices(len(self.feature_info), msfa_indices)[0]
+        self.msfa_in_chs = sum([self.feature_info[mi]['num_chs'] for mi in self.msfa_indices])
+        self.msfa_output_resolution = msfa_output_resolution
+
+        self.msfa = MobileNetV5MultiScaleFusionAdapter(
+            in_chs=self.msfa_in_chs,
+            out_chs=self.num_features,
+            output_resolution=self.msfa_output_resolution,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+        )
+
+        efficientnet_init_weights(self)
+
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: (Unused) Applies norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+
+        """
+        del norm
+
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+
+        # MobileNet v5's MultiScaleFusionAdapter takes intermediates from specific feature indices and uses them in
+        # its computation. These MSFA indices are not guaranteed to be captured by the `indices` parameter passed to
+        # this function, so we accumulate two sets of indices, one that aligns with the `indices` parameter and one
+        # that is required by the MSFA block.
+        intermediates = []
+        msfa_intermediates = []
+
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+        if feat_idx in self.msfa_indices:
+            msfa_intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+            if feat_idx in self.msfa_indices:
+                msfa_intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return self.msfa(msfa_intermediates), intermediates
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        feat_idx = 0  # offset by one from blocks index due to stem feature
+        intermediates = []
+
+        x = self.conv_stem(x)
+        if feat_idx in self.msfa_indices:
+            intermediates.append(x)
+
+        for blk in self.blocks:
+            feat_idx += 1
+            # FIXME fix grad checkpointing
+            x = blk(x)
+            if feat_idx in self.msfa_indices:
+                intermediates.append(x)
+
+        return self.msfa(intermediates)
+
+    def forward_head(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError("MobileNetV5Encoder does not support classification use cases.")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_features(x)
+
+
+def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
+    kwargs_filter = (
+        'num_classes',
+        'num_features',
+        'head_conv',
+        'head_bias',
+        'head_norm',
+        'global_pool',
+    )
+    model = build_model_with_cfg(
+        MobileNetV5Encoder,
+        variant,
+        pretrained,
+        pretrained_strict=False,
+        kwargs_filter=kwargs_filter,
+        **kwargs,
+    )
+    return model
+
+
+def _create_mnv5(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5:
+    model = build_model_with_cfg(
+        MobileNetV5,
+        variant,
+        pretrained,
+        pretrained_strict=False,
+        **kwargs,
+    )
+    return model
+
+
+def _gen_mobilenet_v5(
+        variant: str,
+        channel_multiplier: float = 1.0,
+        group_size=None,
+        pretrained: bool = False,
+        encoder: bool = False,
+        **kwargs,
+) -> Union[MobileNetV5, MobileNetV5Encoder]:
+    if 'mobilenetv5_base' in variant:
+        arch_def: list[list[str]] = [
+            # Stage 0: 128x128 in
+            [
+                'er_r1_k3_s2_e4_c128',
+                'er_r1_k3_s1_e4_c128',
+                'er_r1_k3_s1_e4_c128',
+            ],
+            # Stage 1: 256x256 in
+            [
+                'uir_r1_a3_k5_s2_e6_c256',
+                'uir_r1_a5_k0_s1_e4_c256',
+                'uir_r1_a3_k0_s1_e4_c256',
+                'uir_r1_a5_k0_s1_e4_c256',
+                'uir_r1_a3_k0_s1_e4_c256',
+            ],
+            # Stage 2: 640x640 in
+            [
+                "uir_r1_a5_k5_s2_e6_c512",
+                "uir_r1_a5_k0_s1_e4_c512",
+                "uir_r1_a5_k0_s1_e4_c512",
+                "uir_r1_a0_k0_s1_e1_c512",
+                'mqa_r1_k3_h8_s2_d64_c512',
+                "uir_r1_a0_k0_s1_e2_c512",
+                'mqa_r1_k3_h8_s2_d64_c512',
+                "uir_r1_a0_k0_s1_e2_c512",
+                'mqa_r1_k3_h8_s2_d64_c512',
+                "uir_r1_a0_k0_s1_e2_c512",
+                'mqa_r1_k3_h8_s2_d64_c512',
+                "uir_r1_a0_k0_s1_e2_c512",
+                'mqa_r1_k3_h8_s2_d64_c512',
+                "uir_r1_a0_k0_s1_e2_c512",
+                'mqa_r1_k3_h8_s2_d64_c512',
+                "uir_r1_a0_k0_s1_e2_c512",
+            ],
+            # Stage 3: 1280x1280 in
+            [
+                "uir_r1_a5_k5_s2_e6_c1024",
+                'mqa_r1_k3_h16_s1_d64_c1024',
+                "uir_r1_a0_k0_s1_e2_c1024",
+                'mqa_r1_k3_h16_s1_d64_c1024',
+                "uir_r1_a0_k0_s1_e2_c1024",
+                'mqa_r1_k3_h16_s1_d64_c1024',
+
"uir_r1_a0_k0_s1_e2_c1024", + 'mqa_r1_k3_h16_s1_d64_c1024', + "uir_r1_a0_k0_s1_e2_c1024", + 'mqa_r1_k3_h16_s1_d64_c1024', + "uir_r1_a0_k0_s1_e2_c1024", + 'mqa_r1_k3_h16_s1_d64_c1024', + "uir_r1_a0_k0_s1_e2_c1024", + 'mqa_r1_k3_h16_s1_d64_c1024', + "uir_r1_a0_k0_s1_e2_c1024", + ], + ] + else: + arch_def: list[list[str]] = [ + # Stage 0: 128x128 in + [ + 'er_r1_k3_s2_e4_c128', + 'er_r1_k3_s1_e4_c128', + 'er_r1_k3_s1_e4_c128', + ], + # Stage 1: 256x256 in + [ + 'uir_r1_a3_k5_s2_e6_c256', + 'uir_r1_a5_k0_s1_e4_c256', + 'uir_r1_a3_k0_s1_e4_c256', + 'uir_r1_a5_k0_s1_e4_c256', + 'uir_r1_a3_k0_s1_e4_c256', + ], + # Stage 2: 640x640 in + [ + "uir_r1_a5_k5_s2_e6_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a0_k0_s1_e1_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + ], + # Stage 3: 1280x1280 in + [ + "uir_r1_a5_k5_s2_e6_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + ], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, group_size=group_size), + stem_size=64, + fix_stem=channel_multiplier < 1.0, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=RmsNorm2d, + act_layer=nn.GELU, + layer_scale_init_value=1e-5, + ) + model_kwargs = dict(model_kwargs, **kwargs) + if encoder: + model = _create_mnv5_encoder(variant, pretrained, **model_kwargs) + else: + model = _create_mnv5(variant, pretrained, **model_kwargs) + return model + + +def 
_cfg(url: str = '', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+        'crop_pct': 1.0, 'interpolation': 'bicubic',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'conv_stem.conv', 'classifier': 'classifier',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # encoder-only configs
+    'mobilenetv5_300m_enc': _cfg(
+        #hf_hub_id='timm/',
+        input_size=(3, 768, 768), pool_size=(24, 24),
+        num_classes=0),
+
+    # classification configs for testing / fine-tune (WIP)
+    'mobilenetv5_300m': _cfg(
+        # hf_hub_id='timm/',
+        input_size=(3, 768, 768), pool_size=(24, 24),
+        num_classes=0),
+    'mobilenetv5_base.untrained': _cfg(
+        # hf_hub_id='timm/',
+        num_classes=1000,
+        input_size=(3, 224, 224), pool_size=(7, 7)),
+})
+
+
+@register_model
+def mobilenetv5_300m_enc(pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
+    """MobileNet V5 Vision Encoder"""
+    model = _gen_mobilenet_v5('mobilenetv5_300m_enc', pretrained=pretrained, encoder=True, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv5_300m(pretrained: bool = False, **kwargs) -> MobileNetV5:
+    model = _gen_mobilenet_v5('mobilenetv5_300m', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv5_base(pretrained: bool = False, **kwargs) -> MobileNetV5:
+    model = _gen_mobilenet_v5('mobilenetv5_base', pretrained=pretrained, **kwargs)
+    return model

From 1690574a0a4c9ed194fd446adee892c3fc3f9010 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 25 Jun 2025 21:31:19 -0700
Subject: [PATCH 2/9] Fix torchscript compat of MobileNetV5 MSFA

---
 timm/models/mobilenetv5.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py
index ca5e094034..853dd5b479 100644
--- a/timm/models/mobilenetv5.py
+++ b/timm/models/mobilenetv5.py
@@ -69,7 +69,7 @@ def __init__(
 
         self.norm = norm_layer(self.out_channels)
 
-    def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         # Inputs list of [B, C, H, W] tensors
         high_resolution = inputs[0].shape[-2:]  # Assuming the first input is the highest resolution.
         resized_inputs = []
@@ -81,7 +81,7 @@ def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor:
         channel_cat_imgs = torch.cat(resized_inputs, dim=1)  # Cat on channel dim, must equal self.in_channels
         img = self.ffn(channel_cat_imgs)
 
-        if any([ro != rh for ro, rh in zip(high_resolution, self.output_resolution)]):
+        if high_resolution[0] != self.output_resolution[0] or high_resolution[1] != self.output_resolution[1]:
             # Interpolate / pool to target output_resolution if highest feature resolution differs
             if (
                     high_resolution[0] % self.output_resolution[0] != 0 or

From 739b46cc65cd77568d78e19d836c577291f81e35 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 25 Jun 2025 21:46:31 -0700
Subject: [PATCH 3/9] Fixed pool size (16,16) because of MSFA.
---
 timm/models/mobilenetv5.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py
index 853dd5b479..355176d07d 100644
--- a/timm/models/mobilenetv5.py
+++ b/timm/models/mobilenetv5.py
@@ -765,7 +765,7 @@ def _gen_mobilenet_v5(
 def _cfg(url: str = '', **kwargs):
     return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (16, 16),
         'crop_pct': 1.0, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'conv_stem.conv', 'classifier': 'classifier',
@@ -777,18 +777,17 @@ def _cfg(url: str = '', **kwargs):
     # encoder-only configs
     'mobilenetv5_300m_enc': _cfg(
         #hf_hub_id='timm/',
-        input_size=(3, 768, 768), pool_size=(24, 24),
+        input_size=(3, 768, 768),
         num_classes=0),
 
-    # classification configs for testing / fine-tune (WIP)
+    # WIP classification configs for testing
     'mobilenetv5_300m': _cfg(
         # hf_hub_id='timm/',
-        input_size=(3, 768, 768), pool_size=(24, 24),
+        input_size=(3, 768, 768),
         num_classes=0),
     'mobilenetv5_base.untrained': _cfg(
         # hf_hub_id='timm/',
-        num_classes=1000,
-        input_size=(3, 224, 224), pool_size=(7, 7)),
+        num_classes=1000)
 })

From e0cb66913666f6c8d06fa61748a169fbd0a7a6fd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 25 Jun 2025 22:05:00 -0700
Subject: [PATCH 4/9] Make features_only=True work with mnv5 & enc, uses forward_intermediates()

---
 timm/models/mobilenetv5.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py
index 355176d07d..40e81243dd 100644
--- a/timm/models/mobilenetv5.py
+++ b/timm/models/mobilenetv5.py
@@ -554,6 +554,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
+    feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
     kwargs_filter = (
         'num_classes',
         'num_features',
@@ -567,6 +569,7 @@ def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> Mo
         variant,
         pretrained,
         pretrained_strict=False,
+        feature_cfg=feature_cfg,
         kwargs_filter=kwargs_filter,
         **kwargs,
     )
@@ -574,11 +577,14 @@ def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> Mo
 
 def _create_mnv5(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5:
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
+    feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
     model = build_model_with_cfg(
         MobileNetV5,
         variant,
         pretrained,
         pretrained_strict=False,
+        feature_cfg=feature_cfg,
         **kwargs,
     )
     return model

From 857727ded8ff27d8b27738a4d56431229c979316 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 25 Jun 2025 22:34:06 -0700
Subject: [PATCH 5/9] Simplify resolution check for improved script/trace compat

---
 timm/models/mobilenetv5.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py
index 40e81243dd..ccaa7f6820 100644
--- a/timm/models/mobilenetv5.py
+++ b/timm/models/mobilenetv5.py
@@ -74,7 +74,8 @@ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         high_resolution = inputs[0].shape[-2:]  # Assuming the first input is the highest resolution.
resized_inputs = [] for _, img in enumerate(inputs): - if any([r < hr for r, hr in zip(img.shape[-2:], high_resolution)]): + feat_size = img.shape[-2:] + if feat_size[0] < high_resolution[0] or feat_size[1] < high_resolution[1]: img = F.interpolate(img, size=high_resolution, mode=self.interpolation_mode) resized_inputs.append(img) From 4cc7fdbd8843c7aac1b07f0b2ea264906ab8451d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 26 Jun 2025 07:36:56 -0700 Subject: [PATCH 6/9] Cleanup imports, mark MSFA as notrace --- timm/models/mobilenetv5.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py index ccaa7f6820..8e19648e98 100644 --- a/timm/models/mobilenetv5.py +++ b/timm/models/mobilenetv5.py @@ -5,22 +5,24 @@ import torch.nn as nn import torch.nn.functional as F -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import ( SelectAdaptivePool2d, Linear, LayerType, PadType, RmsNorm2d, ConvNormAct, create_conv2d, get_norm_act_layer, to_2tuple ) -from ._builder import build_model_with_cfg, pretrained_cfg_for_features +from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, UniversalInvertedResidual from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ - round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from ._features import FeatureInfo, FeatureHooks, feature_take_indices + round_channels, resolve_act_layer +from ._features import feature_take_indices +from ._features_fx import register_notrace_module from ._manipulate import checkpoint_seq, checkpoint -from ._registry import generate_default_cfgs, register_model, register_model_deprecations +from ._registry import generate_default_cfgs, register_model __all__ = ['MobileNetV5', 'MobileNetV5Encoder'] +@register_notrace_module class MobileNetV5MultiScaleFusionAdapter(nn.Module): """Multi-layer fusion token adapter. From ddd3f99a7855f3c583d3a303c6c46c5a6d34f8b2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 26 Jun 2025 09:03:44 -0700 Subject: [PATCH 7/9] Update test, encoder_only mode for backward test --- tests/test_models.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 58dad77e27..a7e9ccf00d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -210,6 +210,7 @@ def test_model_backward(model_name, batch_size): pytest.skip("Fixed input size model > limit.") model = create_model(model_name, pretrained=False, num_classes=42) + encoder_only = model.num_classes == 0 # FIXME better approach? 
     num_params = sum([x.numel() for x in model.parameters()])
     model.train()
@@ -224,7 +225,12 @@ def test_model_backward(model_name, batch_size):
         assert x.grad is not None, f'No gradient for {n}'
     num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
 
-    assert outputs.shape[-1] == 42
+    if encoder_only:
+        output_fmt = getattr(model, 'output_fmt', 'NCHW')
+        feat_axis = get_channel_dim(output_fmt)
+        assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
+    else:
+        assert outputs.shape[-1] == 42
     assert num_params == num_grad, 'Some parameters are missing gradients'
     assert not torch.isnan(outputs).any(), 'Output included NaNs'

From 136440d9d410c1088978fb89b1fd04262f561980 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 26 Jun 2025 09:32:22 -0700
Subject: [PATCH 8/9] Switch to 'same' padding emulation for the enc model as it should be a closer match for the original weights.

---
 timm/models/mobilenetv5.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py
index 8e19648e98..062ef8ffb7 100644
--- a/timm/models/mobilenetv5.py
+++ b/timm/models/mobilenetv5.py
@@ -803,7 +803,14 @@ def _cfg(url: str = '', **kwargs):
 @register_model
 def mobilenetv5_300m_enc(pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
     """MobileNet V5 Vision Encoder"""
-    model = _gen_mobilenet_v5('mobilenetv5_300m_enc', pretrained=pretrained, encoder=True, **kwargs)
+    pad_type = kwargs.pop('pad_type', 'same')
+    model = _gen_mobilenet_v5(
+        'mobilenetv5_300m_enc',
+        pretrained=pretrained,
+        encoder=True,
+        pad_type=pad_type,
+        **kwargs,
+    )
     return model

From 38286760112fa8ee6b6074c3c41e1e43ac033089 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 26 Jun 2025 09:39:07 -0700
Subject: [PATCH 9/9] Make RmsNormAct default eps of 1e-6 consistent with RmsNorm

---
 timm/layers/norm_act.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py
index b1e1ba98b5..dd9413105e 100644
--- a/timm/layers/norm_act.py
+++ b/timm/layers/norm_act.py
@@ -478,7 +478,7 @@ class RmsNormAct(RmsNorm):
     def __init__(
             self,
             num_channels,
-            eps=1e-5,
+            eps=1e-6,
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
@@ -511,7 +511,7 @@ class RmsNormAct2d(RmsNorm2d):
     def __init__(
             self,
             num_channels,
-            eps=1e-5,
+            eps=1e-6,
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
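
A quick smoke test of the encoder entrypoint added by this series (a minimal sketch, assuming
the patches are applied to a timm checkout; no pretrained weights are published here, and the
shapes follow the default cfg values from PATCH 1/9 and PATCH 3/9):

    import torch
    import timm

    # Encoder-only variant; forward() returns the MSFA output, not logits.
    model = timm.create_model('mobilenetv5_300m_enc', pretrained=False).eval()

    x = torch.randn(1, 3, 768, 768)  # default cfg input_size for the 300m encoder
    with torch.no_grad():
        feats = model(x)
    print(feats.shape)  # (1, 2048, 16, 16): num_features=2048, msfa_output_resolution=16

    # forward_intermediates() returns the MSFA output plus selected per-stage features.
    with torch.no_grad():
        final, intermediates = model.forward_intermediates(x)
    print([tuple(f.shape) for f in intermediates])

And an equivalence check for the eager NCHW RMSNorm kernel introduced in PATCH 1/9 against the
permute + rms_norm path it replaces (assumes a PyTorch build that provides
torch.nn.functional.rms_norm):

    import torch
    import torch.nn.functional as F

    def rms_norm2d_ref(x, weight=None, eps=1e-5):
        # Eager NCHW RMSNorm as in the patch: reduce over the channel dim (dim=1), no permutes.
        v = x.pow(2).mean(dim=1, keepdim=True)
        x = x * torch.rsqrt(v + eps)
        if weight is not None:
            x = x * weight.reshape(1, -1, 1, 1)
        return x

    x = torch.randn(2, 8, 4, 4)
    w = torch.randn(8)
    # Reference: permute to NHWC, rms_norm over the trailing channel dim, permute back.
    y_ref = F.rms_norm(x.permute(0, 2, 3, 1), (8,), w, 1e-5).permute(0, 3, 1, 2)
    torch.testing.assert_close(rms_norm2d_ref(x, w), y_ref)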