From 4471cad24e5879537f41a94f2050b68608064724 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 Jun 2025 14:16:01 -0700 Subject: [PATCH 1/6] Refactor patch resampling based on feedback from https://github.com/stas-sl --- timm/layers/patch_embed.py | 38 ++++++++++++-------------------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index f87ce9693a..28d5067b82 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -270,7 +270,6 @@ def _compute_resize_matrix( eye_matrix = torch.eye(old_total, device=device, dtype=dtype) basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w) - resized_basis_vectors_batch = F.interpolate( basis_vectors_batch, size=new_size, @@ -278,17 +277,10 @@ def _compute_resize_matrix( antialias=antialias, align_corners=False ) # Output shape: (old_total, 1, new_h, new_w) - - resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T + resize_matrix = resized_basis_vectors_batch.squeeze(1).permute(1, 2, 0).reshape(new_total, old_total) return resize_matrix # Shape: (new_total, old_total) -def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor: - """Calculates the pseudoinverse matrix used for the resampling operation.""" - pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total) - return pinv_matrix - - def _apply_resampling( patch_embed: torch.Tensor, pinv_matrix: torch.Tensor, @@ -296,21 +288,15 @@ def _apply_resampling( orig_dtype: torch.dtype, intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE ) -> torch.Tensor: - """Applies the precomputed pinv_matrix to resample the patch_embed tensor.""" - try: - from torch import vmap - except ImportError: - from functorch import vmap - - def resample_kernel(kernel: torch.Tensor) -> torch.Tensor: - kernel_flat = kernel.reshape(-1).to(intermediate_dtype) - resampled_kernel_flat = pinv_matrix @ kernel_flat - return resampled_kernel_flat.reshape(new_size_tuple) - - resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0) - patch_embed_float = patch_embed.to(intermediate_dtype) - resampled_patch_embed = resample_kernel_vmap(patch_embed_float) - return resampled_patch_embed.to(orig_dtype) + """ Simplified resampling w/o vmap use. 
+ As proposed by https://github.com/stas-sl + """ + c_out, c_in, *_ = patch_embed.shape + patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype) + pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype) + resampled_patch_embed = patch_embed @ pinv_matrix # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new) + resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype) + return resampled_patch_embed def resample_patch_embed( @@ -336,7 +322,7 @@ def resample_patch_embed( resize_mat = _compute_resize_matrix( old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE ) - pinv_matrix = _compute_pinv_for_resampling(resize_mat) + pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling resampled_patch_embed = _apply_resampling( patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE ) @@ -388,7 +374,7 @@ def _get_or_create_pinv_matrix( resize_mat = _compute_resize_matrix( self.orig_size, new_size, self.interpolation, self.antialias, device, dtype ) - pinv_matrix = _compute_pinv_for_resampling(resize_mat) + pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling # Cache using register_buffer buffer_name = f"pinv_{new_size[0]}x{new_size[1]}" From 98c6a4af56070d24421ecd75d35522d74d2f70e5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 17 Jun 2025 14:31:32 -0700 Subject: [PATCH 2/6] Add grid_sample pos_embed interpolation option --- timm/models/naflexvit.py | 75 ++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 9684b397ec..1ed24c1751 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import ( @@ -89,6 +90,7 @@ class NaFlexVitCfg: pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16) # Grid size for position embedding initialization pos_embed_interp_mode: str = 'bicubic' # Interpolation mode for position embedding resizing pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation + pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation # Image processing dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution @@ -221,6 +223,7 @@ def __init__( pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14), pos_embed_interp_mode: str = 'bicubic', pos_embed_ar_preserving: bool = False, + pos_embed_use_grid_sample: bool = False, input_norm_layer: Optional[Type[nn.Module]] = None, proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None, norm_layer: Optional[Type[nn.Module]] = None, @@ -256,6 +259,7 @@ def __init__( self.num_reg_tokens = reg_tokens self.pos_embed_interp_mode = pos_embed_interp_mode self.pos_embed_ar_preserving = pos_embed_ar_preserving + self.pos_embed_use_grid_sample = pos_embed_use_grid_sample self.patch_size = to_2tuple(patch_size) self.in_chans = in_chans self.embed_dim = embed_dim @@ -438,18 +442,6 @@ def _interp2d(size): )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2) return pos_embed_flat.to(dtype=x.dtype) - # FIXME leaving 
alternative code commented here for now for comparisons - # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {} - # for i, s in enumerate(naflex_grid_sizes): - # if s in pos_embed_cache: - # pos_embed_flat = pos_embed_cache[s] - # else: - # pos_embed_flat = _interp(s) - # pos_embed_cache[s] = pos_embed_flat - # - # seq_len = min(x.shape[1], pos_embed_flat.shape[1]) - # x[i, :seq_len] += pos_embed_flat[0, :seq_len] - # Determine unique grid sizes to avoid duplicate interpolation size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): @@ -467,6 +459,57 @@ def _interp2d(size): pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1) ) + def _apply_learned_naflex_pos_embed_grid_sample( + self, + x: torch.Tensor, + naflex_grid_sizes: List[Tuple[int, int]], + ): + """ NaFlex 2D position embedding interpolation using F.grid_sample. + + Based on proposal by https://github.com/stas-sl + """ + device = x.device + B, C = x.shape[0:2] + + def _make_coords(h, w): + _y, _x = torch.meshgrid( + torch.arange(h, device=device), + torch.arange(w, device=device), + indexing='ij', + ) + coord = torch.stack([_y.flatten(), _x.flatten()], dim=1) + return coord + + coords = pad_sequence( + [_make_coords(h, w) for h, w in naflex_grid_sizes], + batch_first=True, + ) + shapes = coords.amax(1) + 1 + theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device) + if self.pos_embed_ar_preserving: + shape_max = shapes.amax() + grid_size = (shape_max, shape_max) + L = shapes.amax(1) + theta[:, 0, 0] = grid_size[1] / L # scale x + theta[:, 1, 1] = grid_size[0] / L # scale y + else: + grid_size = shapes.amax(0) + theta[:, 0, 0] = grid_size[1] / shapes[:, 1] # scale x + theta[:, 1, 1] = grid_size[0] / shapes[:, 0] # scale y + theta[:, 0, 2] = theta[:, 0, 0] - 1 # translate x + theta[:, 1, 2] = theta[:, 1, 1] - 1 # translate y + grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False) + pos_embed = F.grid_sample( + self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(), + grid, + mode=self.pos_embed_interp_mode, + align_corners=False, + padding_mode='border', + ).to(dtype=x.dtype) + bi = torch.arange(B, device=device).unsqueeze(1).expand(-1, coords.shape[1]) + # NOTE leave as '+=', do not change to .add_(...) 
+ x += pos_embed[bi, :, coords[..., 0], coords[..., 1]] + def _apply_learned_pos_embed( self, x: torch.Tensor, @@ -516,7 +559,7 @@ def _apply_factorized_naflex_pos_embed( # Handle each batch element separately with its own grid size orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] - # bucket samples that share the same (H,W) so we build each grid once + # bucket samples that share the same (H, W) so we build each grid once size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): size_to_indices.setdefault(k, []).append(bi) @@ -630,7 +673,10 @@ def forward( if self.pos_embed_type == 'learned': if naflex_grid_sizes is not None: - self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) + if self.pos_embed_use_grid_sample: + self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes) + else: + self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) else: assert grid_size is not None self._apply_learned_pos_embed(x, grid_size=grid_size) @@ -874,6 +920,7 @@ def __init__( pos_embed_grid_size=cfg.pos_embed_grid_size, pos_embed_interp_mode=cfg.pos_embed_interp_mode, pos_embed_ar_preserving=cfg.pos_embed_ar_preserving, + pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample, proj_norm_layer=embed_norm_layer, pos_drop_rate=cfg.pos_drop_rate, patch_drop_rate=cfg.patch_drop_rate, From c2ba04ca345fbbee720f001ba47cea646cefdde7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 17 Jun 2025 20:21:31 -0700 Subject: [PATCH 3/6] Slight tweak --- timm/models/naflexvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 1ed24c1751..fcc5ae2cb4 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -487,9 +487,9 @@ def _make_coords(h, w): shapes = coords.amax(1) + 1 theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device) if self.pos_embed_ar_preserving: - shape_max = shapes.amax() - grid_size = (shape_max, shape_max) L = shapes.amax(1) + grid_max = L.amax() + grid_size = (grid_max, grid_max) theta[:, 0, 0] = grid_size[1] / L # scale x theta[:, 1, 1] = grid_size[0] / L # scale y else: From 4e3cba847653069c1f8c50f6101a195bb4208a2c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 17 Jun 2025 21:36:47 -0700 Subject: [PATCH 4/6] Fix silly shape bug, and fix issue with pad_sequence when none of the shapes in batch use full seq len. --- timm/models/naflexvit.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index fcc5ae2cb4..eaee2b9808 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -469,7 +469,7 @@ def _apply_learned_naflex_pos_embed_grid_sample( Based on proposal by https://github.com/stas-sl """ device = x.device - B, C = x.shape[0:2] + B, N, C = x.shape def _make_coords(h, w): _y, _x = torch.meshgrid( @@ -480,10 +480,12 @@ def _make_coords(h, w): coord = torch.stack([_y.flatten(), _x.flatten()], dim=1) return coord - coords = pad_sequence( - [_make_coords(h, w) for h, w in naflex_grid_sizes], - batch_first=True, - ) + coords = torch.zeros(B, N, 2, dtype=torch.long, device=device) + for i, (h, w) in enumerate(naflex_grid_sizes): + coords_i = _make_coords(h, w) # (h*w, 2) + coords[i, :coords_i.shape[0]] = coords_i # pad with zeros past h*w + # FIXME should we be masking? 
+ shapes = coords.amax(1) + 1 theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device) if self.pos_embed_ar_preserving: @@ -506,7 +508,7 @@ def _make_coords(h, w): align_corners=False, padding_mode='border', ).to(dtype=x.dtype) - bi = torch.arange(B, device=device).unsqueeze(1).expand(-1, coords.shape[1]) + bi = torch.arange(B, device=device).unsqueeze(1) # NOTE leave as '+=', do not change to .add_(...) x += pos_embed[bi, :, coords[..., 0], coords[..., 1]] From ab0c06ce502ebe17bc37840428fa4db6e7542b16 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Jun 2025 11:59:38 -0700 Subject: [PATCH 5/6] Fix up grid_sample, did not make sense to rebuild patch coords, duh --- timm/models/naflexvit.py | 68 +++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index eaee2b9808..5ae80becca 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -462,7 +462,8 @@ def _interp2d(size): def _apply_learned_naflex_pos_embed_grid_sample( self, x: torch.Tensor, - naflex_grid_sizes: List[Tuple[int, int]], + patch_coord: torch.Tensor, + patch_valid: Optional[torch.Tensor] = None, ): """ NaFlex 2D position embedding interpolation using F.grid_sample. @@ -470,36 +471,24 @@ def _apply_learned_naflex_pos_embed_grid_sample( """ device = x.device B, N, C = x.shape + shapes = patch_coord.max(dim=1).values + 1 # (B, 2) containing [h_i, w_i] - def _make_coords(h, w): - _y, _x = torch.meshgrid( - torch.arange(h, device=device), - torch.arange(w, device=device), - indexing='ij', - ) - coord = torch.stack([_y.flatten(), _x.flatten()], dim=1) - return coord - - coords = torch.zeros(B, N, 2, dtype=torch.long, device=device) - for i, (h, w) in enumerate(naflex_grid_sizes): - coords_i = _make_coords(h, w) # (h*w, 2) - coords[i, :coords_i.shape[0]] = coords_i # pad with zeros past h*w - # FIXME should we be masking? 
- - shapes = coords.amax(1) + 1 - theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device) if self.pos_embed_ar_preserving: - L = shapes.amax(1) - grid_max = L.amax() - grid_size = (grid_max, grid_max) - theta[:, 0, 0] = grid_size[1] / L # scale x - theta[:, 1, 1] = grid_size[0] / L # scale y + L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) + L_global = L_i.amax() + grid_size = (L_global, L_global) + s_x = s_y = L_global / L_i # uniform zoom (B,) else: - grid_size = shapes.amax(0) - theta[:, 0, 0] = grid_size[1] / shapes[:, 1] # scale x - theta[:, 1, 1] = grid_size[0] / shapes[:, 0] # scale y + grid_size = shapes.amax(dim=0) + s_x = grid_size[1] / shapes[:, 1] # horizontal zoom (B,) + s_y = grid_size[0] / shapes[:, 0] # vertical zoom (B,) + + theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32) + theta[:, 0, 0] = s_x # scale x + theta[:, 1, 1] = s_y # scale y theta[:, 0, 2] = theta[:, 0, 0] - 1 # translate x theta[:, 1, 2] = theta[:, 1, 1] - 1 # translate y + grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False) pos_embed = F.grid_sample( self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(), @@ -507,10 +496,20 @@ def _make_coords(h, w): mode=self.pos_embed_interp_mode, align_corners=False, padding_mode='border', - ).to(dtype=x.dtype) + ).to(dtype=x.dtype) # (B, C, H_out, W_out) + + # NOTE if we bring in patch_valid, can explicitly mask padding tokens + # more experimentation at train time needed + # lin_idx = patch_coord[..., 0] * grid_size[1] + patch_coord[..., 1] # (B, N) + # pos_flat = pos_embed.flatten(2).transpose(1, 2) + # pos_flat = pos_flat.gather(1, lin_idx.unsqueeze(2).expand(-1, -1, C)) # (B, N, C) + # if patch_valid is not None: + # pos_flat.mul_(patch_valid.unsqueeze(2)) + # idx_vec = torch.arange(N, device=device) # (N,) + # x.index_add_(1, idx_vec, pos_flat) + bi = torch.arange(B, device=device).unsqueeze(1) - # NOTE leave as '+=', do not change to .add_(...) - x += pos_embed[bi, :, coords[..., 0], coords[..., 1]] + x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] # NOTE leave as '+=' def _apply_learned_pos_embed( self, @@ -605,6 +604,7 @@ def forward( self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None, + patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass for patch embedding with position encoding. 
@@ -676,7 +676,11 @@ def forward( if self.pos_embed_type == 'learned': if naflex_grid_sizes is not None: if self.pos_embed_use_grid_sample: - self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes) + self._apply_learned_naflex_pos_embed_grid_sample( + x, + patch_coord=patch_coord, + patch_valid=patch_valid, + ) else: self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) else: @@ -1146,7 +1150,7 @@ def forward_intermediates( mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) # Forward pass through embedding - x = self.embeds(patches, patch_coord=patch_coord) + x = self.embeds(patches, patch_coord=patch_coord, patch_valid=patch_valid) x = self.norm_pre(x) # Forward pass through blocks @@ -1219,7 +1223,7 @@ def forward_features( ) # Pass through embedding module with patch coordinate/type support - x = self.embeds(x, patch_coord=patch_coord) + x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid) x = self.norm_pre(x) # Apply transformer blocks with masked attention if mask provided if attn_mask is not None: From 41058e962d26f7d5a91bd4c3526bec6799314fd9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 19 Jun 2025 10:47:37 -0700 Subject: [PATCH 6/6] Refactor and cleanup NaFlex pos embed methods. Make interface more consistent, add factorized grid_sample and fixed grid size methods. --- timm/models/naflexvit.py | 221 +++++++++++++++++++++++++++------------ validate.py | 11 +- 2 files changed, 162 insertions(+), 70 deletions(-) diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 5ae80becca..539543f5e7 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -165,6 +165,13 @@ def batch_patchify( return patches, (nh, nw) +def calculate_naflex_grid_sizes(_coord: torch.Tensor): + # Calculate the appropriate grid size from coords + max_y = _coord[:, :, 0].amax(dim=1) + 1 + max_x = _coord[:, :, 1].amax(dim=1) + 1 + return [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] + + @register_notrace_module class NaFlexEmbeds(nn.Module): """NaFlex Embedding module for Vision Transformers. @@ -407,18 +414,19 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: def _apply_learned_naflex_pos_embed( self, x: torch.Tensor, - naflex_grid_sizes: List[Tuple[int, int]], + patch_coord: torch.Tensor, ) -> None: """Apply learned position embeddings to NaFlex batch in-place. - Interpolates learned position embeddings for each sample in the batch + Interpolates learned 2D position embeddings for each sample in the batch based on their individual grid sizes. Args: - x: Input tensor to add position embeddings to - naflex_grid_sizes: List of (height, width) grid sizes for each batch element + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ - # Handle each batch element separately with its own grid size + # Calculate grid sizes from patch coordinates + naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) orig_h, orig_w = self.pos_embed.shape[1:3] pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W @@ -463,33 +471,37 @@ def _apply_learned_naflex_pos_embed_grid_sample( self, x: torch.Tensor, patch_coord: torch.Tensor, - patch_valid: Optional[torch.Tensor] = None, - ): - """ NaFlex 2D position embedding interpolation using F.grid_sample. + ) -> None: + """Apply learned position embeddings to NaFlex batch using grid_sample. 
+ + Uses F.grid_sample for efficient interpolation of learned 2D position embeddings + based on patch coordinates. Based on proposal by https://github.com/stas-sl - Based on proposal by https://github.com/stas-sl + Args: + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ device = x.device B, N, C = x.shape shapes = patch_coord.max(dim=1).values + 1 # (B, 2) containing [h_i, w_i] if self.pos_embed_ar_preserving: - L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) + L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) L_global = L_i.amax() - grid_size = (L_global, L_global) - s_x = s_y = L_global / L_i # uniform zoom (B,) + grid_size_y = grid_size_x = L_global + scale_x = scale_y = L_global / L_i # uniform zoom (B,) else: - grid_size = shapes.amax(dim=0) - s_x = grid_size[1] / shapes[:, 1] # horizontal zoom (B,) - s_y = grid_size[0] / shapes[:, 0] # vertical zoom (B,) + grid_size_y, grid_size_x = shapes.amax(dim=0) # (2,) + scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,) + scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,) theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32) - theta[:, 0, 0] = s_x # scale x - theta[:, 1, 1] = s_y # scale y - theta[:, 0, 2] = theta[:, 0, 0] - 1 # translate x - theta[:, 1, 2] = theta[:, 1, 1] - 1 # translate y + theta[:, 0, 0] = scale_x + theta[:, 1, 1] = scale_y + theta[:, 0, 2] = scale_x - 1 # translate x + theta[:, 1, 2] = scale_y - 1 # translate y - grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False) + grid = F.affine_grid(theta, (B, C, grid_size_y, grid_size_x), align_corners=False) pos_embed = F.grid_sample( self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(), grid, @@ -498,16 +510,6 @@ def _apply_learned_naflex_pos_embed_grid_sample( padding_mode='border', ).to(dtype=x.dtype) # (B, C, H_out, W_out) - # NOTE if we bring in patch_valid, can explicitly mask padding tokens - # more experimentation at train time needed - # lin_idx = patch_coord[..., 0] * grid_size[1] + patch_coord[..., 1] # (B, N) - # pos_flat = pos_embed.flatten(2).transpose(1, 2) - # pos_flat = pos_flat.gather(1, lin_idx.unsqueeze(2).expand(-1, -1, C)) # (B, N, C) - # if patch_valid is not None: - # pos_flat.mul_(patch_valid.unsqueeze(2)) - # idx_vec = torch.arange(N, device=device) # (N,) - # x.index_add_(1, idx_vec, pos_flat) - bi = torch.arange(B, device=device).unsqueeze(1) x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] # NOTE leave as '+=' @@ -516,27 +518,28 @@ def _apply_learned_pos_embed( x: torch.Tensor, grid_size: List[int], ) -> None: - """Apply learned position embeddings to standard batch in-place. + """Apply learned position embeddings to standard 2D batch in-place. - Interpolates learned position embeddings to match the specified grid size. + Interpolates learned 2D position embeddings to match the specified grid size. 
Args: - x: Input tensor to add position embeddings to + x: Input tensor to add position embeddings to [B, H*W, C] grid_size: Target grid size as [height, width] """ orig_h, orig_w = self.pos_embed.shape[1:3] - if grid_size[0] == orig_h or grid_size[1] == orig_w: + if grid_size[0] == orig_h and grid_size[1] == orig_w: # No resize needed, just flatten pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) else: # Resize if needed - directly using F.interpolate + _interp_size = to_2tuple(max(grid_size)) if self.pos_embed_ar_preserving else grid_size pos_embed_flat = F.interpolate( self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W - size=grid_size, + size=_interp_size, mode=self.pos_embed_interp_mode, align_corners=False, antialias=True, - ).flatten(2).transpose(1, 2) + )[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2) pos_embed_flat = pos_embed_flat.to(dtype=x.dtype) x.add_(pos_embed_flat) @@ -544,7 +547,7 @@ def _apply_learned_pos_embed( def _apply_factorized_naflex_pos_embed( self, x: torch.Tensor, - naflex_grid_sizes: List[Tuple[int, int]], + patch_coord: torch.Tensor, ) -> None: """Apply factorized position embeddings to NaFlex batch in-place. @@ -552,9 +555,11 @@ def _apply_factorized_naflex_pos_embed( and combined for each sample's grid size. Args: - x: Input tensor to add position embeddings to - naflex_grid_sizes: List of (height, width) grid sizes for each batch element + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ + # Calculate grid sizes from patch coordinates + naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) assert len(naflex_grid_sizes) == x.size(0) # one (H,W) per sample # Handle each batch element separately with its own grid size @@ -600,11 +605,99 @@ def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.T pos[:, :seq_len].expand(len(batch_indices), -1, -1) ) + def _apply_factorized_naflex_pos_embed_grid_sample( + self, + x: torch.Tensor, + patch_coord: torch.Tensor, + ) -> None: + """Apply factorized position embeddings to NaFlex batch using grid_sample. + + Uses F.grid_sample for efficient interpolation of separate Y and X position + embedding tables based on patch coordinates. 
Based on proposal by https://github.com/stas-sl + + Args: + x: Input tensor to add position embeddings to [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values + """ + device = x.device + B, _, C = x.shape + shapes = patch_coord.amax(dim=1) + 1 + + if self.pos_embed_ar_preserving: + # Aspect ratio preserving mode: use square grid with uniform scaling + L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) + L_global = L_i.amax() + grid_size_y = grid_size_x = L_global + scale_x = scale_y = L_global / L_i # uniform zoom (B,) + else: + # Standard mode: different scaling for x and y + grid_size_y, grid_size_x = shapes.amax(0) + scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,) + scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,) + + def _interp1d(table: torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor) -> torch.Tensor: + pe = table.permute(0, 2, 1).unsqueeze(2).expand(B, -1, -1, -1).float() # (1, L, C) -> (B, C, 1, L) + theta = torch.zeros(B, 2, 3, device=x.device) + theta[:, 0, 0] = scale + theta[:, 0, 2] = scale - 1 + theta[:, 1, 1] = 1 + grid = F.affine_grid(theta, (B, C, 1, out_length), align_corners=False) + pe = F.grid_sample(pe, grid, mode='bilinear', align_corners=False, padding_mode='border') + return pe.to(x.dtype) + + # Interpolate along each axis + pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x) + pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y) + + bi = torch.arange(B, device=device).unsqueeze(1) + x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]] + + def _apply_factorized_pos_embed( + self, + x: torch.Tensor, + grid_size: List[int], + ) -> None: + """Apply factorized position embeddings to standard 2D batch in-place. + + Uses separate Y and X position embedding tables that are interpolated + and combined for the specified grid size. + + Args: + x: Input tensor to add position embeddings to [B, H*W, C] + grid_size: Target grid size as [height, width] + """ + orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] + target_h, target_w = grid_size + + if self.pos_embed_ar_preserving: + len_y = len_x = max(target_h, target_w) + else: + len_y, len_x = target_h, target_w + + def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor: + if new_length == orig_length: + return table.to(dtype=x.dtype) + return F.interpolate( + table.permute(0, 2, 1).float(), # (1,L,C) -> (1,C,L) + size=new_length, + mode='linear', + align_corners=False, + ).permute(0, 2, 1).to(dtype=x.dtype) # (1,L,C) + + # Interpolate embeddings + pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C) + pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C) + + # Broadcast, add and flatten to sequence layout (row major) + pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1, H, W, C) + pos_embed_flat = pos_embed.flatten(1, 2) # (1, H*W, C) + + x.add_(pos_embed_flat) + def forward( self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None, - patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass for patch embedding with position encoding. @@ -619,24 +712,18 @@ def forward( Embedded tensor with position encoding and class/register tokens. 
Shape: [B, num_prefix_tokens + N, embed_dim] """ - # Apply patch embedding - naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None grid_size: Optional[List[int]] = None - B = x.shape[0] if self.is_linear: # Linear embedding path, works with NaFlex mode or standard 2D mode - if patch_coord is not None: + if patch_coord is None: + # Standard 2D (B, C, H, W) mode + _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') + x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) + else: # Pre-patchified NaFlex mode # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C] _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.') - # Calculate the appropriate grid size from coords - max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1 - max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1 - naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] - else: - _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') - x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) # Handle variable patch size projection if self.enable_patch_interpolator and x.ndim == 5: @@ -674,21 +761,25 @@ def forward( x = self.norm(x) if self.pos_embed_type == 'learned': - if naflex_grid_sizes is not None: + if grid_size is not None: + # Standard 2D mode + self._apply_learned_pos_embed(x, grid_size=grid_size) + else: + # NaFlex mode if self.pos_embed_use_grid_sample: - self._apply_learned_naflex_pos_embed_grid_sample( - x, - patch_coord=patch_coord, - patch_valid=patch_valid, - ) + self._apply_learned_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord) else: - self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) - else: - assert grid_size is not None - self._apply_learned_pos_embed(x, grid_size=grid_size) + self._apply_learned_naflex_pos_embed(x, patch_coord=patch_coord) elif self.pos_embed_type == 'factorized': - if naflex_grid_sizes is not None: - self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) + if grid_size is not None: + # Standard 2D mode + self._apply_factorized_pos_embed(x, grid_size=grid_size) + else: + # NaFlex mode + if self.pos_embed_use_grid_sample: + self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord) + else: + self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord) elif self.pos_embed_type == 'rope': assert False, "ROPE not yet implemented" @@ -1150,7 +1241,7 @@ def forward_intermediates( mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) # Forward pass through embedding - x = self.embeds(patches, patch_coord=patch_coord, patch_valid=patch_valid) + x = self.embeds(patches, patch_coord=patch_coord) x = self.norm_pre(x) # Forward pass through blocks @@ -1223,7 +1314,7 @@ def forward_features( ) # Pass through embedding module with patch coordinate/type support - x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid) + x = self.embeds(x, patch_coord=patch_coord) x = self.norm_pre(x) # Apply transformer blocks with masked attention if mask provided if attn_mask is not None: diff --git a/validate.py b/validate.py index 59e78a91fd..25781f18f5 100755 --- a/validate.py +++ b/validate.py @@ -345,11 +345,12 @@ def validate(args): model.eval() with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non - input = torch.randn((args.batch_size,) + 
tuple(data_config['input_size'])).to(device=device, dtype=model_dtype) - if args.channels_last: - input = input.contiguous(memory_format=torch.channels_last) - with amp_autocast(): - model(input) + if not args.naflex_loader: + input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype) + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + with amp_autocast(): + model(input) end = time.time() for batch_idx, (input, target) in enumerate(loader):
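
Note (illustrative, not part of the patches): below is a minimal standalone sketch of the grid_sample-based NaFlex position embedding interpolation that patches 2-6 converge on. Per-sample affine thetas resize a single learned (1, H, W, C) table to each sample's grid in one batched F.grid_sample call, and the integer patch (y, x) coordinates then gather the per-token embeddings. Function and argument names here (interpolate_pos_embed_grid_sample, pos_embed, patch_coord) are placeholders, not the timm API; the in-tree method is NaFlexEmbeds._apply_learned_naflex_pos_embed_grid_sample, which additionally handles the aspect-ratio-preserving mode.

# Sketch only: batched resize of a learned 2D pos-embed table via affine_grid +
# grid_sample, then per-patch gather by (y, x) coordinates. Assumes zero-padded
# coords past each sample's valid length, as in the patches above.
import torch
import torch.nn.functional as F


def interpolate_pos_embed_grid_sample(
        x: torch.Tensor,            # (B, N, C) patch tokens, padded to a common N
        pos_embed: torch.Tensor,    # (1, H, W, C) learned position embedding table
        patch_coord: torch.Tensor,  # (B, N, 2) integer (y, x) grid coords per patch
        mode: str = 'bicubic',
) -> torch.Tensor:
    B, N, C = x.shape
    device = x.device
    shapes = patch_coord.amax(dim=1) + 1     # (B, 2) per-sample (h_i, w_i)
    grid_h, grid_w = shapes.amax(dim=0)      # shared output canvas size
    scale_y = grid_h / shapes[:, 0]          # per-sample vertical zoom (B,)
    scale_x = grid_w / shapes[:, 1]          # per-sample horizontal zoom (B,)

    # The affine theta maps the first h_i x w_i output cells onto the full table;
    # cells beyond that land outside [-1, 1] and are clamped by padding_mode='border'.
    theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
    theta[:, 0, 0] = scale_x
    theta[:, 1, 1] = scale_y
    theta[:, 0, 2] = scale_x - 1
    theta[:, 1, 2] = scale_y - 1

    grid = F.affine_grid(theta, (B, C, int(grid_h), int(grid_w)), align_corners=False)
    pe = F.grid_sample(
        pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),  # (B, C, H, W)
        grid,
        mode=mode,
        align_corners=False,
        padding_mode='border',
    ).to(x.dtype)                            # (B, C, grid_h, grid_w)

    # Gather the embedding at each patch's (y, x) location and add to the tokens.
    bi = torch.arange(B, device=device).unsqueeze(1)
    return x + pe[bi, :, patch_coord[..., 0], patch_coord[..., 1]]


if __name__ == '__main__':
    B, C, H, W = 2, 64, 16, 16
    pos_embed = torch.randn(1, H, W, C)
    # Two samples with different grids (12x8 and 6x14), padded to the same N.
    sizes = [(12, 8), (6, 14)]
    N = max(h * w for h, w in sizes)
    coords = torch.zeros(B, N, 2, dtype=torch.long)
    for i, (h, w) in enumerate(sizes):
        yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
        coords[i, :h * w] = torch.stack([yy.flatten(), xx.flatten()], dim=1)
    tokens = torch.zeros(B, N, C)
    out = interpolate_pos_embed_grid_sample(tokens, pos_embed, coords)
    print(out.shape)  # torch.Size([2, 96, 64])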