diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index f87ce9693a..28d5067b82 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -270,7 +270,6 @@ def _compute_resize_matrix( eye_matrix = torch.eye(old_total, device=device, dtype=dtype) basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w) - resized_basis_vectors_batch = F.interpolate( basis_vectors_batch, size=new_size, @@ -278,17 +277,10 @@ def _compute_resize_matrix( antialias=antialias, align_corners=False ) # Output shape: (old_total, 1, new_h, new_w) - - resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T + resize_matrix = resized_basis_vectors_batch.squeeze(1).permute(1, 2, 0).reshape(new_total, old_total) return resize_matrix # Shape: (new_total, old_total) -def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor: - """Calculates the pseudoinverse matrix used for the resampling operation.""" - pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total) - return pinv_matrix - - def _apply_resampling( patch_embed: torch.Tensor, pinv_matrix: torch.Tensor, @@ -296,21 +288,15 @@ def _apply_resampling( orig_dtype: torch.dtype, intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE ) -> torch.Tensor: - """Applies the precomputed pinv_matrix to resample the patch_embed tensor.""" - try: - from torch import vmap - except ImportError: - from functorch import vmap - - def resample_kernel(kernel: torch.Tensor) -> torch.Tensor: - kernel_flat = kernel.reshape(-1).to(intermediate_dtype) - resampled_kernel_flat = pinv_matrix @ kernel_flat - return resampled_kernel_flat.reshape(new_size_tuple) - - resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0) - patch_embed_float = patch_embed.to(intermediate_dtype) - resampled_patch_embed = resample_kernel_vmap(patch_embed_float) - return resampled_patch_embed.to(orig_dtype) + """ Simplified resampling w/o vmap use. + As proposed by https://github.com/stas-sl + """ + c_out, c_in, *_ = patch_embed.shape + patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype) + pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype) + resampled_patch_embed = patch_embed @ pinv_matrix # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new) + resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype) + return resampled_patch_embed def resample_patch_embed( @@ -336,7 +322,7 @@ def resample_patch_embed( resize_mat = _compute_resize_matrix( old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE ) - pinv_matrix = _compute_pinv_for_resampling(resize_mat) + pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling resampled_patch_embed = _apply_resampling( patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE ) @@ -388,7 +374,7 @@ def _get_or_create_pinv_matrix( resize_mat = _compute_resize_matrix( self.orig_size, new_size, self.interpolation, self.antialias, device, dtype ) - pinv_matrix = _compute_pinv_for_resampling(resize_mat) + pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling # Cache using register_buffer buffer_name = f"pinv_{new_size[0]}x{new_size[1]}" diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 9684b397ec..539543f5e7 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import ( @@ -89,6 +90,7 @@ class NaFlexVitCfg: pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16) # Grid size for position embedding initialization pos_embed_interp_mode: str = 'bicubic' # Interpolation mode for position embedding resizing pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation + pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation # Image processing dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution @@ -163,6 +165,13 @@ def batch_patchify( return patches, (nh, nw) +def calculate_naflex_grid_sizes(_coord: torch.Tensor): + # Calculate the appropriate grid size from coords + max_y = _coord[:, :, 0].amax(dim=1) + 1 + max_x = _coord[:, :, 1].amax(dim=1) + 1 + return [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] + + @register_notrace_module class NaFlexEmbeds(nn.Module): """NaFlex Embedding module for Vision Transformers. @@ -221,6 +230,7 @@ def __init__( pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14), pos_embed_interp_mode: str = 'bicubic', pos_embed_ar_preserving: bool = False, + pos_embed_use_grid_sample: bool = False, input_norm_layer: Optional[Type[nn.Module]] = None, proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None, norm_layer: Optional[Type[nn.Module]] = None, @@ -256,6 +266,7 @@ def __init__( self.num_reg_tokens = reg_tokens self.pos_embed_interp_mode = pos_embed_interp_mode self.pos_embed_ar_preserving = pos_embed_ar_preserving + self.pos_embed_use_grid_sample = pos_embed_use_grid_sample self.patch_size = to_2tuple(patch_size) self.in_chans = in_chans self.embed_dim = embed_dim @@ -403,18 +414,19 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: def _apply_learned_naflex_pos_embed( self, x: torch.Tensor, - naflex_grid_sizes: List[Tuple[int, int]], + patch_coord: torch.Tensor, ) -> None: """Apply learned position embeddings to NaFlex batch in-place. - Interpolates learned position embeddings for each sample in the batch + Interpolates learned 2D position embeddings for each sample in the batch based on their individual grid sizes. Args: - x: Input tensor to add position embeddings to - naflex_grid_sizes: List of (height, width) grid sizes for each batch element + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ - # Handle each batch element separately with its own grid size + # Calculate grid sizes from patch coordinates + naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) orig_h, orig_w = self.pos_embed.shape[1:3] pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W @@ -438,18 +450,6 @@ def _interp2d(size): )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2) return pos_embed_flat.to(dtype=x.dtype) - # FIXME leaving alternative code commented here for now for comparisons - # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {} - # for i, s in enumerate(naflex_grid_sizes): - # if s in pos_embed_cache: - # pos_embed_flat = pos_embed_cache[s] - # else: - # pos_embed_flat = _interp(s) - # pos_embed_cache[s] = pos_embed_flat - # - # seq_len = min(x.shape[1], pos_embed_flat.shape[1]) - # x[i, :seq_len] += pos_embed_flat[0, :seq_len] - # Determine unique grid sizes to avoid duplicate interpolation size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): @@ -467,32 +467,79 @@ def _interp2d(size): pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1) ) + def _apply_learned_naflex_pos_embed_grid_sample( + self, + x: torch.Tensor, + patch_coord: torch.Tensor, + ) -> None: + """Apply learned position embeddings to NaFlex batch using grid_sample. + + Uses F.grid_sample for efficient interpolation of learned 2D position embeddings + based on patch coordinates. Based on proposal by https://github.com/stas-sl + + Args: + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values + """ + device = x.device + B, N, C = x.shape + shapes = patch_coord.max(dim=1).values + 1 # (B, 2) containing [h_i, w_i] + + if self.pos_embed_ar_preserving: + L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) + L_global = L_i.amax() + grid_size_y = grid_size_x = L_global + scale_x = scale_y = L_global / L_i # uniform zoom (B,) + else: + grid_size_y, grid_size_x = shapes.amax(dim=0) # (2,) + scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,) + scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,) + + theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32) + theta[:, 0, 0] = scale_x + theta[:, 1, 1] = scale_y + theta[:, 0, 2] = scale_x - 1 # translate x + theta[:, 1, 2] = scale_y - 1 # translate y + + grid = F.affine_grid(theta, (B, C, grid_size_y, grid_size_x), align_corners=False) + pos_embed = F.grid_sample( + self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(), + grid, + mode=self.pos_embed_interp_mode, + align_corners=False, + padding_mode='border', + ).to(dtype=x.dtype) # (B, C, H_out, W_out) + + bi = torch.arange(B, device=device).unsqueeze(1) + x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] # NOTE leave as '+=' + def _apply_learned_pos_embed( self, x: torch.Tensor, grid_size: List[int], ) -> None: - """Apply learned position embeddings to standard batch in-place. + """Apply learned position embeddings to standard 2D batch in-place. - Interpolates learned position embeddings to match the specified grid size. + Interpolates learned 2D position embeddings to match the specified grid size. Args: - x: Input tensor to add position embeddings to + x: Input tensor to add position embeddings to [B, H*W, C] grid_size: Target grid size as [height, width] """ orig_h, orig_w = self.pos_embed.shape[1:3] - if grid_size[0] == orig_h or grid_size[1] == orig_w: + if grid_size[0] == orig_h and grid_size[1] == orig_w: # No resize needed, just flatten pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) else: # Resize if needed - directly using F.interpolate + _interp_size = to_2tuple(max(grid_size)) if self.pos_embed_ar_preserving else grid_size pos_embed_flat = F.interpolate( self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W - size=grid_size, + size=_interp_size, mode=self.pos_embed_interp_mode, align_corners=False, antialias=True, - ).flatten(2).transpose(1, 2) + )[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2) pos_embed_flat = pos_embed_flat.to(dtype=x.dtype) x.add_(pos_embed_flat) @@ -500,7 +547,7 @@ def _apply_learned_pos_embed( def _apply_factorized_naflex_pos_embed( self, x: torch.Tensor, - naflex_grid_sizes: List[Tuple[int, int]], + patch_coord: torch.Tensor, ) -> None: """Apply factorized position embeddings to NaFlex batch in-place. @@ -508,15 +555,17 @@ def _apply_factorized_naflex_pos_embed( and combined for each sample's grid size. Args: - x: Input tensor to add position embeddings to - naflex_grid_sizes: List of (height, width) grid sizes for each batch element + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ + # Calculate grid sizes from patch coordinates + naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) assert len(naflex_grid_sizes) == x.size(0) # one (H,W) per sample # Handle each batch element separately with its own grid size orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] - # bucket samples that share the same (H,W) so we build each grid once + # bucket samples that share the same (H, W) so we build each grid once size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): size_to_indices.setdefault(k, []).append(bi) @@ -556,6 +605,95 @@ def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.T pos[:, :seq_len].expand(len(batch_indices), -1, -1) ) + def _apply_factorized_naflex_pos_embed_grid_sample( + self, + x: torch.Tensor, + patch_coord: torch.Tensor, + ) -> None: + """Apply factorized position embeddings to NaFlex batch using grid_sample. + + Uses F.grid_sample for efficient interpolation of separate Y and X position + embedding tables based on patch coordinates. Based on proposal by https://github.com/stas-sl + + Args: + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values + """ + device = x.device + B, _, C = x.shape + shapes = patch_coord.amax(dim=1) + 1 + + if self.pos_embed_ar_preserving: + # Aspect ratio preserving mode: use square grid with uniform scaling + L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) + L_global = L_i.amax() + grid_size_y = grid_size_x = L_global + scale_x = scale_y = L_global / L_i # uniform zoom (B,) + else: + # Standard mode: different scaling for x and y + grid_size_y, grid_size_x = shapes.amax(0) + scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,) + scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,) + + def _interp1d(table: torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor) -> torch.Tensor: + pe = table.permute(0, 2, 1).unsqueeze(2).expand(B, -1, -1, -1).float() # (1, L, C) -> (B, C, 1, L) + theta = torch.zeros(B, 2, 3, device=x.device) + theta[:, 0, 0] = scale + theta[:, 0, 2] = scale - 1 + theta[:, 1, 1] = 1 + grid = F.affine_grid(theta, (B, C, 1, out_length), align_corners=False) + pe = F.grid_sample(pe, grid, mode='bilinear', align_corners=False, padding_mode='border') + return pe.to(x.dtype) + + # Interpolate along each axis + pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x) + pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y) + + bi = torch.arange(B, device=device).unsqueeze(1) + x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]] + + def _apply_factorized_pos_embed( + self, + x: torch.Tensor, + grid_size: List[int], + ) -> None: + """Apply factorized position embeddings to standard 2D batch in-place. + + Uses separate Y and X position embedding tables that are interpolated + and combined for the specified grid size. + + Args: + x: Input tensor to add position embeddings to [B, H*W, C] + grid_size: Target grid size as [height, width] + """ + orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] + target_h, target_w = grid_size + + if self.pos_embed_ar_preserving: + len_y = len_x = max(target_h, target_w) + else: + len_y, len_x = target_h, target_w + + def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor: + if new_length == orig_length: + return table.to(dtype=x.dtype) + return F.interpolate( + table.permute(0, 2, 1).float(), # (1,L,C) -> (1,C,L) + size=new_length, + mode='linear', + align_corners=False, + ).permute(0, 2, 1).to(dtype=x.dtype) # (1,L,C) + + # Interpolate embeddings + pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C) + pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C) + + # Broadcast, add and flatten to sequence layout (row major) + pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1, H, W, C) + pos_embed_flat = pos_embed.flatten(1, 2) # (1, H*W, C) + + x.add_(pos_embed_flat) + def forward( self, x: torch.Tensor, @@ -574,24 +712,18 @@ def forward( Embedded tensor with position encoding and class/register tokens. Shape: [B, num_prefix_tokens + N, embed_dim] """ - # Apply patch embedding - naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None grid_size: Optional[List[int]] = None - B = x.shape[0] if self.is_linear: # Linear embedding path, works with NaFlex mode or standard 2D mode - if patch_coord is not None: + if patch_coord is None: + # Standard 2D (B, C, H, W) mode + _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') + x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) + else: # Pre-patchified NaFlex mode # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C] _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.') - # Calculate the appropriate grid size from coords - max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1 - max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1 - naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] - else: - _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') - x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) # Handle variable patch size projection if self.enable_patch_interpolator and x.ndim == 5: @@ -629,14 +761,25 @@ def forward( x = self.norm(x) if self.pos_embed_type == 'learned': - if naflex_grid_sizes is not None: - self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) - else: - assert grid_size is not None + if grid_size is not None: + # Standard 2D mode self._apply_learned_pos_embed(x, grid_size=grid_size) + else: + # NaFlex mode + if self.pos_embed_use_grid_sample: + self._apply_learned_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord) + else: + self._apply_learned_naflex_pos_embed(x, patch_coord=patch_coord) elif self.pos_embed_type == 'factorized': - if naflex_grid_sizes is not None: - self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) + if grid_size is not None: + # Standard 2D mode + self._apply_factorized_pos_embed(x, grid_size=grid_size) + else: + # NaFlex mode + if self.pos_embed_use_grid_sample: + self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord) + else: + self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord) elif self.pos_embed_type == 'rope': assert False, "ROPE not yet implemented" @@ -874,6 +1017,7 @@ def __init__( pos_embed_grid_size=cfg.pos_embed_grid_size, pos_embed_interp_mode=cfg.pos_embed_interp_mode, pos_embed_ar_preserving=cfg.pos_embed_ar_preserving, + pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample, proj_norm_layer=embed_norm_layer, pos_drop_rate=cfg.pos_drop_rate, patch_drop_rate=cfg.patch_drop_rate, diff --git a/validate.py b/validate.py index 59e78a91fd..25781f18f5 100755 --- a/validate.py +++ b/validate.py @@ -345,11 +345,12 @@ def validate(args): model.eval() with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non - input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype) - if args.channels_last: - input = input.contiguous(memory_format=torch.channels_last) - with amp_autocast(): - model(input) + if not args.naflex_loader: + input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype) + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + with amp_autocast(): + model(input) end = time.time() for batch_idx, (input, target) in enumerate(loader):