diff --git a/tests/test_models.py b/tests/test_models.py index 6d3615e4f2..b4686a3efe 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -56,12 +56,12 @@ 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt', 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', - 'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', + 'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. NON_STD_FILTERS = [ - 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', + 'vit_*', 'naflexvit*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'swiftformer_*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', @@ -81,7 +81,7 @@ EXCLUDE_FILTERS = ['*enormous*'] NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*'] -EXCLUDE_JIT_FILTERS = ['hiera_*'] +EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*'] TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 4b95fbd174..304d6129bf 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -8,6 +8,18 @@ from .imagenet_info import ImageNetInfo, infer_imagenet_subset from .loader import create_loader from .mixup import Mixup, FastCollateMixup +from .naflex_dataset import NaFlexMapDatasetWrapper, calculate_naflex_batch_size +from .naflex_loader import create_naflex_loader +from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size +from .naflex_transforms import ( + ResizeToSequence, + CenterCropToSequence, + RandomCropToSequence, + RandomResizedCropToSequence, + ResizeKeepRatioToSequence, + Patchify, + patchify_image, +) from .readers import create_reader from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions from .real_labels import RealLabelsImagenet diff --git a/timm/data/loader.py b/timm/data/loader.py index 313440570c..313f33efe4 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -33,6 +33,7 @@ def fast_collate(batch): if isinstance(batch[0][0], tuple): # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position + is_np = isinstance(batch[0][0], np.ndarray) inner_tuple_size = len(batch[0][0]) flattened_batch_size = batch_size * inner_tuple_size targets = torch.zeros(flattened_batch_size, dtype=torch.int64) @@ -41,7 +42,10 @@ def fast_collate(batch): assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length for j in range(inner_tuple_size): targets[i + j * batch_size] = batch[i][1] - tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + if is_np: + tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + else: + tensor[i + j * batch_size] += batch[i][0][j] return 
tensor, targets elif isinstance(batch[0][0], np.ndarray): targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 26dc239152..e40622915b 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -229,6 +229,8 @@ def _mix_elem_collate(self, output, batch, half=False): num_elem = batch_size // 2 if half else batch_size assert len(output) == num_elem lam_batch, use_cutmix = self._params_per_elem(num_elem) + is_np = isinstance(batch[0][0], np.ndarray) + for i in range(num_elem): j = batch_size - i - 1 lam = lam_batch[i] @@ -236,15 +238,23 @@ def _mix_elem_collate(self, output, batch, half=False): if lam != 1.: if use_cutmix[i]: if not half: - mixed = mixed.copy() + mixed = mixed.copy() if is_np else mixed.clone() (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + output.shape, + lam, + ratio_minmax=self.cutmix_minmax, + correct_lam=self.correct_lam, + ) mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] lam_batch[i] = lam else: - mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.rint(mixed, out=mixed) - output[i] += torch.from_numpy(mixed.astype(np.uint8)) + if is_np: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + else: + mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam) + torch.round(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte() if half: lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) return torch.tensor(lam_batch).unsqueeze(1) @@ -252,6 +262,8 @@ def _mix_elem_collate(self, output, batch, half=False): def _mix_pair_collate(self, output, batch): batch_size = len(batch) lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + is_np = isinstance(batch[0][0], np.ndarray) + for i in range(batch_size // 2): j = batch_size - i - 1 lam = lam_batch[i] @@ -261,39 +273,60 @@ def _mix_pair_collate(self, output, batch): if lam < 1.: if use_cutmix[i]: (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - patch_i = mixed_i[:, yl:yh, xl:xh].copy() + output.shape, + lam, + ratio_minmax=self.cutmix_minmax, + correct_lam=self.correct_lam, + ) + patch_i = mixed_i[:, yl:yh, xl:xh].copy() if is_np else mixed_i[:, yl:yh, xl:xh].clone() mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] mixed_j[:, yl:yh, xl:xh] = patch_i lam_batch[i] = lam else: - mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) - mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) - mixed_i = mixed_temp - np.rint(mixed_j, out=mixed_j) - np.rint(mixed_i, out=mixed_i) - output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) - output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + if is_np: + mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) + mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) + mixed_i = mixed_temp + np.rint(mixed_j, out=mixed_j) + np.rint(mixed_i, out=mixed_i) + else: + mixed_temp = mixed_i.float() * lam + mixed_j.float() * (1 - lam) + mixed_j = mixed_j.float() * lam + mixed_i.float() * (1 - lam) + mixed_i = mixed_temp + torch.round(mixed_j, out=mixed_j) + torch.round(mixed_i, out=mixed_i) + output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) if is_np else 
mixed_i.byte() + output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) if is_np else mixed_j.byte() lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) return torch.tensor(lam_batch).unsqueeze(1) def _mix_batch_collate(self, output, batch): batch_size = len(batch) lam, use_cutmix = self._params_per_batch() + is_np = isinstance(batch[0][0], np.ndarray) + if use_cutmix: (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + output.shape, + lam, + ratio_minmax=self.cutmix_minmax, + correct_lam=self.correct_lam, + ) for i in range(batch_size): j = batch_size - i - 1 mixed = batch[i][0] if lam != 1.: if use_cutmix: - mixed = mixed.copy() # don't want to modify the original while iterating + mixed = mixed.copy() if is_np else mixed.clone() # don't want to modify the original while iterating mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] else: - mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.rint(mixed, out=mixed) - output[i] += torch.from_numpy(mixed.astype(np.uint8)) + if is_np: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + else: + mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam) + torch.round(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte() return lam def __call__(self, batch, _=None): diff --git a/timm/data/naflex_dataset.py b/timm/data/naflex_dataset.py new file mode 100644 index 0000000000..beff4d1685 --- /dev/null +++ b/timm/data/naflex_dataset.py @@ -0,0 +1,558 @@ +""" Dynamic Sequence Length Datasets for Variable Resolution Image Processing + +Implements two dataset wrappers: +1. NaFlexMapDatasetWrapper - Map-style dataset that returns batches with variable sequence lengths +TODO: 2. NaFlexIterableDatasetWrapper - Iterable dataset that yields batches with variable sequence lengths + +Both support: +- Pre-initialized transforms for efficiency +- Distributed training +- Multiple workers +- Variable batch sizes based on sequence length + +Hacked together by / Copyright 2025, Ross Wightman, Hugging Face +""" + +import math +import random +import warnings +from functools import partial +from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable + +import torch +from torch.utils.data import Dataset, IterableDataset, DataLoader +from PIL import Image + +from .naflex_transforms import Patchify +from timm.layers import to_2tuple + + +def calculate_naflex_batch_size( + tokens_per_batch: int, + seq_len: int, + max_size: Optional[int] = None, + divisor: int = 1, + rounding: str = 'floor', +) -> int: + """Calculate batch size based on sequence length with divisibility constraints. + + Args: + tokens_per_batch: Target number of tokens per batch. + seq_len: Sequence length for this batch. + max_size: Optional maximum batch size. + divisor: Ensure batch size is divisible by this value. + rounding: Rounding method ('floor', 'ceil', 'round'). + + Returns: + Calculated batch size. 
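+
+    Example (illustrative; exercises the floor rounding and divisor path described above)::
+
+        >>> calculate_naflex_batch_size(tokens_per_batch=16384, seq_len=576, divisor=8)
+        24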
+ """ + # Calculate raw batch size based on sequence length + raw_batch_size = tokens_per_batch / seq_len + + # Apply divisibility with specified rounding method + if divisor > 1: + if rounding == 'floor': + batch_size = math.floor(raw_batch_size / divisor) * divisor + elif rounding == 'ceil': + batch_size = math.ceil(raw_batch_size / divisor) * divisor + else: # 'round' is the default + batch_size = round(raw_batch_size / divisor) * divisor + else: + # If no divisor specified, just use integer division + batch_size = int(raw_batch_size) + + # Ensure batch size is valid + batch_size = max(1, batch_size) # At least 1 + + if max_size is not None: + batch_size = min(batch_size, max_size) + + return batch_size + + +class NaFlexCollator: + """Custom collator for batching NaFlex-style variable-resolution images.""" + + def __init__( + self, + max_seq_len: Optional[int] = None, + ) -> None: + """Initialize NaFlexCollator. + + Args: + max_seq_len: Maximum sequence length for batching. + """ + self.max_seq_len = max_seq_len or 576 # Default ViT-B/16 sequence length (577 = 24*24) + + def __call__(self, batch: List[Tuple[Dict[str, torch.Tensor], Union[int, torch.Tensor]]]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + """Collate batch of NaFlex samples. + + Args: + batch: List of tuples (patch_dict, target). + + Returns: + A tuple of (input_dict, targets) where input_dict contains: + - patches: Padded tensor of patches + - patch_coord: Coordinates for each patch (y, x) + - patch_valid: Valid indicators + """ + assert isinstance(batch[0], tuple) + batch_size = len(batch) + + # Extract targets + targets = [item[1] for item in batch] + if isinstance(targets[0], torch.Tensor): + targets = torch.stack(targets) + else: + targets = torch.tensor(targets, dtype=torch.int64) + + # Get patch dictionaries + patch_dicts = [item[0] for item in batch] + + # If we have a maximum sequence length constraint, ensure we don't exceed it + if self.max_seq_len is not None: + max_patches = self.max_seq_len + else: + # Find the maximum number of patches in this batch + max_patches = max(item['patches'].shape[0] for item in patch_dicts) + + # Check if patches are flattened or unflattened + patches_tensor = patch_dicts[0]['patches'] + is_unflattened = patches_tensor.ndim == 4 # [N, Ph, Pw, C] + + if is_unflattened: + # Patches are [N, Ph, Pw, C] - variable patch size mode + _, ph, pw, c = patches_tensor.shape + patches = torch.zeros((batch_size, max_patches, ph, pw, c), dtype=torch.float32) + else: + # Patches are [N, P*P*C] - normal mode + patch_dim = patches_tensor.shape[1] + patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32) + + # Prepare other tensors + patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x) + patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool) + + # Fill in the tensors + for i, patch_dict in enumerate(patch_dicts): + num_patches = min(patch_dict['patches'].shape[0], max_patches) + + patches[i, :num_patches] = patch_dict['patches'][:num_patches] + patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches] + patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches] + + result = { + 'patches': patches, + 'patch_coord': patch_coord, + 'patch_valid': patch_valid, + 'seq_len': max_patches, + } + + return result, targets + + +def _resolve_patch_cfg( + patch_size: Optional[Union[int, Tuple[int, int]]], + patch_size_choices: Optional[List[int]], + patch_size_choice_probs: Optional[List[float]], 
+) -> Tuple[List[Tuple[int, int]], List[float], bool]: + """Resolve patch size configuration. + + Args: + patch_size: Single patch size to use. + patch_size_choices: List of patch sizes to choose from. + patch_size_choice_probs: Probabilities for each patch size choice. + + Returns: + Tuple of (sizes, probs, variable) where sizes is list of patch size tuples, + probs is list of probabilities, and variable indicates if patch size varies. + """ + # If both are None, default to patch_size=16 + if patch_size is None and patch_size_choices is None: + patch_size = 16 + + if (patch_size is None) == (patch_size_choices is None): + raise ValueError( + "Specify exactly one of `patch_size` or `patch_size_choices`." + ) + + if patch_size is not None: + sizes = [to_2tuple(patch_size)] + probs = [1.0] + variable = False + else: + sizes = [to_2tuple(p) for p in patch_size_choices] + if patch_size_choice_probs is None: + probs = [1.0 / len(sizes)] * len(sizes) + else: + if len(patch_size_choice_probs) != len(sizes): + raise ValueError("`patch_size_choice_probs` length mismatch.") + s = float(sum(patch_size_choice_probs)) + if s <= 0: + raise ValueError("`patch_size_choice_probs` sum to zero.") + probs = [p / s for p in patch_size_choice_probs] + variable = True + return sizes, probs, variable + + +class NaFlexMapDatasetWrapper(IterableDataset): + """ + IterableDataset wrapper for a map-style base dataset. + + Yields batches with variable sequence lengths. It calculates a canonical + batch schedule (sequence length, batch size pairs) once based on the + total dataset size (padded for distribution). Each epoch, it shuffles + the order of this canonical schedule and the dataset indices. + This ensures a consistent number of batches and samples per epoch + across all ranks. Handles distributed training and multiple workers. + """ + + def __init__( + self, + base_dataset: Dataset, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + patch_size_choices: Optional[List[int]] = None, + patch_size_choice_probs: Optional[List[float]] = None, + seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024), + max_tokens_per_batch: int = 4096 * 4, + transform_factory: Optional[Callable] = None, + mixup_fn: Optional[Callable] = None, + seed: int = 42, + shuffle: bool = True, + distributed: bool = False, + rank: int = 0, + world_size: int = 1, + epoch: int = 0, + batch_divisor: int = 8, + ) -> None: + """Initialize NaFlexMapDatasetWrapper. + + Args: + base_dataset: Map-style dataset to wrap. + patch_size: Single patch size to use. + patch_size_choices: List of patch sizes to randomly select from. + patch_size_choice_probs: Probabilities for each patch size. + seq_lens: Sequence lengths to use for batching. + max_tokens_per_batch: Target tokens per batch. + transform_factory: Factory function for creating transforms. + mixup_fn: Optional mixup function. + seed: Random seed. + shuffle: Whether to shuffle data. + distributed: Whether using distributed training. + rank: Process rank for distributed training. + world_size: Total number of processes. + epoch: Starting epoch. + batch_divisor: Ensure batch size is divisible by this. 
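+
+        Example (sketch; ``imagenet_train`` and ``make_transform`` are placeholder names)::
+
+            >>> wrapper = NaFlexMapDatasetWrapper(
+            ...     imagenet_train,
+            ...     patch_size=16,
+            ...     seq_lens=(256, 576, 1024),
+            ...     max_tokens_per_batch=576 * 64,
+            ...     transform_factory=make_transform,  # called as make_transform(max_seq_len=..., patch_size=...)
+            ... )
+            >>> for input_dict, targets in wrapper:  # typically wrapped in DataLoader(..., batch_size=None)
+            ...     patches = input_dict['patches']  # [B, seq_len, P*P*C]; B varies with seq_len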
+ """ + super().__init__() + if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'): + raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)") + + self.base_dataset = base_dataset + self.seq_lens = sorted(list(set(seq_lens))) # Ensure unique and sorted + self.max_tokens_per_batch = max_tokens_per_batch + self.seed = seed + self.shuffle = shuffle + self.distributed = distributed + self.rank = rank if distributed else 0 + self.world_size = world_size if distributed else 1 + self.epoch = epoch + self.batch_divisor = batch_divisor + + # Resolve patch size configuration + self.patch_sizes, self.patch_size_probs, self.variable_patch_size = _resolve_patch_cfg( + patch_size, + patch_size_choices, + patch_size_choice_probs + ) + + # Pre-initialize transforms and collate fns for each (seq_len, patch_idx) combination + self.transforms: Dict[Tuple[int, int], Optional[Callable]] = {} + self.collate_fns: Dict[int, Callable] = {} + self.patchifiers: List[Callable] = [] + + for seq_len in self.seq_lens: + self.collate_fns[seq_len] = NaFlexCollator(seq_len) + + for patch_idx, patch_size_tuple in enumerate(self.patch_sizes): + # Pre-initialize patchifiers for each patch size (indexed by patch_idx) + self.patchifiers.append(Patchify( + patch_size=patch_size_tuple, + flatten_patches=not self.variable_patch_size + )) + + # Create transforms for each (seq_len, patch_idx) combination + for seq_len in self.seq_lens: + key = (seq_len, patch_idx) + if transform_factory: + self.transforms[key] = transform_factory(max_seq_len=seq_len, patch_size=patch_size_tuple) + else: + self.transforms[key] = None # No transform + + self.mixup_fn = mixup_fn + + # Canonical Schedule Calculation (Done Once) + self._canonical_batch_schedule: List[Tuple[int, int]] = [] + self._num_batches_per_rank: int = 0 + self._padded_samples_per_rank: int = 0 + self._create_canonical_schedule() # Calculate schedule based on padded size + + # Per-Epoch State + # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank + self._epoch_batches: List[Tuple[int, List[int]]] = [] + self._prepare_epoch_batches(self.epoch) # setup for initial epoch + + def _create_canonical_schedule(self): + """ + Calculates the canonical batch schedule (seq_len, batch_size pairs) + based on the dataset size, padded for distributed training. + This schedule is the *same* for all ranks and ensures consistent + epoch length. It is calculated once during initialization. 
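+
+        For example (values assumed): with ``seq_lens=(128, 256)``, ``max_tokens_per_batch=4096``
+        and ``batch_divisor=8``, a draw of seq_len 256 adds a (256, 16) entry and a draw of 128
+        adds (128, 32), repeating until the padded per-rank sample count is covered.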
+ """ + total_len = len(self.base_dataset) + padded_total_len = total_len + num_samples_per_rank = total_len + + if self.distributed and self.world_size > 1: + # Calculate padding needed for even distribution + if total_len % self.world_size != 0: + pad_size = self.world_size - (total_len % self.world_size) + padded_total_len += pad_size + print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).") + else: + pad_size = 0 + + if padded_total_len % self.world_size != 0: + # This should not happen with the padding logic, but safeguard + raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}") + + num_samples_per_rank = padded_total_len // self.world_size + elif self.distributed and self.world_size <= 1: + # Distributed flag set but world_size is 1, treat as non-distributed + pass # num_samples_per_rank remains total_len + + self._padded_samples_per_rank = num_samples_per_rank + + if num_samples_per_rank == 0: + self._canonical_batch_schedule = [] + self._num_batches_per_rank = 0 + return + + # Use a fixed seed for generating the canonical schedule structure + g = torch.Generator() + g.manual_seed(self.seed) # Use base seed, NOT epoch seed + + current_schedule: List[Tuple[int, int]] = [] + remaining_samples = num_samples_per_rank + total_scheduled_samples = 0 + + while remaining_samples > 0: + # Sample sequence length deterministically based on base seed + seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item() + seq_len = self.seq_lens[seq_idx] + + # Calculate batch size + batch_size = calculate_naflex_batch_size( + tokens_per_batch=self.max_tokens_per_batch, + seq_len=seq_len, + # max_size should be remaining_samples to avoid overshooting + max_size=remaining_samples, + divisor=self.batch_divisor, + rounding='floor', + ) + # Ensure batch size is positive and doesn't exceed remaining samples + batch_size = max(1, batch_size) + batch_size = min(batch_size, remaining_samples) + + if batch_size <= 0: + warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.") + break # Avoid infinite loop if something goes wrong + + current_schedule.append((seq_len, batch_size)) + remaining_samples -= batch_size + total_scheduled_samples += batch_size + + # Sanity check: Ensure the schedule covers all samples for the rank + if total_scheduled_samples != num_samples_per_rank: + warnings.warn( + f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, " + f"but expected {num_samples_per_rank} samples per rank. " + f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. " + f"Check parameters. Remaining samples: {remaining_samples}" + ) + # Adjust if needed? Could add a final small batch, but might violate constraints. + # Current behavior: some samples might be dropped if schedule logic fails. + + self._canonical_batch_schedule = current_schedule + self._num_batches_per_rank = len(current_schedule) + print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.") + + + def _prepare_epoch_batches(self, epoch: int): + """ + Prepares the batches for the current epoch by: + 1. Shuffling the full dataset indices (using epoch seed). + 2. Applying padding if in distributed mode. + 3. Selecting indices for the current rank. + 4. 
Shuffling the *order* of the canonical batch schedule (using epoch seed). + 5. Assigning the rank's indices to the shuffled batches. + """ + g = torch.Generator() + g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling + + # 1. Get shuffled global indices + total_len = len(self.base_dataset) + if self.shuffle: + all_indices_shuffled = torch.randperm(total_len, generator=g).tolist() + else: + all_indices_shuffled = list(range(total_len)) + + # 2. Apply padding for distributed mode + indices_for_ranks = all_indices_shuffled + if self.distributed and self.world_size > 1: + padded_total_len = self._padded_samples_per_rank * self.world_size + if padded_total_len > total_len: + pad_size = padded_total_len - total_len + # Repeat initial elements from the *shuffled* list for padding + indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size] + # Ensure length matches expectation + if len(indices_for_ranks) != padded_total_len: + raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}") + + # 3. Select indices for the current rank + if self.distributed and self.world_size > 1: + indices_this_rank = indices_for_ranks[self.rank::self.world_size] + else: # Non-distributed or world_size=1 + indices_this_rank = indices_for_ranks + + # Sanity check length + if len(indices_this_rank) != self._padded_samples_per_rank: + # This might happen if canonical schedule generation had warnings/issues + warnings.warn( + f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) " + f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). " + f"Epoch generation might be inconsistent." + ) + # Adjust expected samples? Or truncate/pad indices? Let's proceed but warn. + # Using min() prevents IndexError later if indices are fewer than expected. + effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank) + indices_this_rank = indices_this_rank[:effective_samples_this_rank] + + else: + effective_samples_this_rank = self._padded_samples_per_rank + + # 4. Shuffle the order of the canonical batch schedule for this epoch + if self.shuffle: + schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist() + shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm] + else: + shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order + + # 5. Assign indices to the shuffled batches + self._epoch_batches = [] + idx_pos = 0 + scheduled_samples_count = 0 + for seq_len, bs in shuffled_schedule: + # Ensure we don't try to grab more indices than available for the rank + actual_bs = min(bs, effective_samples_this_rank - idx_pos) + if actual_bs <= 0: + if scheduled_samples_count < effective_samples_this_rank: + # This indicates mismatch between schedule total and actual samples + warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. 
Check schedule generation.") + break # Stop if no more indices or batch size is zero + + batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs] + self._epoch_batches.append((seq_len, batch_indices)) + idx_pos += actual_bs + scheduled_samples_count += actual_bs + + # Final check + if scheduled_samples_count != effective_samples_this_rank: + warnings.warn( + f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, " + f"but expected {effective_samples_this_rank} effective samples this epoch. " + f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}." + ) + + def set_epoch(self, epoch: int) -> None: + """Updates the epoch, regenerating the epoch-specific batches. + + Args: + epoch: New epoch number. + """ + # Only regenerate if the epoch actually changes + if epoch != self.epoch: + self.epoch = epoch + self._prepare_epoch_batches(epoch) + + def __len__(self) -> int: + """Returns the number of batches per worker for the current epoch. + + Returns: + Number of batches this worker will process. + """ + return self._num_batches_per_rank + + def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]: + """Iterates through pre-calculated batches for the current epoch. + + Yields: + Tuple of (input_dict, targets) for each batch. + """ + worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info else 1 + worker_id = worker_info.id if worker_info else 0 + + # Distribute pre-calculated batches among workers for this rank + # Each worker processes a slice of the batches prepared in _prepare_epoch_batches + batches_for_worker = self._epoch_batches[worker_id::num_workers] + for seq_len, indices in batches_for_worker: + if not indices: # Skip if a batch ended up with no indices (shouldn't happen often) + continue + + # Select patch size for this batch + patch_idx = 0 + if self.variable_patch_size: + # Use torch multinomial for weighted random choice + patch_idx = torch.multinomial(torch.tensor(self.patch_size_probs), 1).item() + + # Get the pre-initialized transform and patchifier using patch_idx + transform_key = (seq_len, patch_idx) + transform = self.transforms.get(transform_key) + batch_patchifier = self.patchifiers[patch_idx] + + batch_imgs = [] + batch_targets = [] + for idx in indices: + try: + # Get original image and label from map-style dataset + img, label = self.base_dataset[idx] + + # Apply transform if available + # Handle cases where transform might return None or fail + processed_img = transform(img) if transform else img + if processed_img is None: + warnings.warn(f"Transform returned None for index {idx}. Skipping sample.") + continue + + batch_imgs.append(processed_img) + batch_targets.append(label) + + except IndexError: + warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.") + continue + except Exception as e: + # Log other potential errors during data loading/processing + warnings.warn(f"Error processing sample index {idx}. Error: {e}. 
Skipping sample.") + continue # Skip problematic sample + + if self.mixup_fn is not None: + batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets) + + batch_imgs = [batch_patchifier(img) for img in batch_imgs] + batch_samples = list(zip(batch_imgs, batch_targets)) + if batch_samples: # Only yield if we successfully processed samples + # Collate the processed samples into a batch + yield self.collate_fns[seq_len](batch_samples) + + # If batch_samples is empty after processing 'indices', an empty batch is skipped. diff --git a/timm/data/naflex_loader.py b/timm/data/naflex_loader.py new file mode 100644 index 0000000000..d615bd63f9 --- /dev/null +++ b/timm/data/naflex_loader.py @@ -0,0 +1,414 @@ +"""NaFlex data loader for dynamic sequence length training. + +This module provides a specialized data loader for Vision Transformer models that supports: +- Dynamic sequence length sampling during training for improved efficiency +- Variable patch size training with probabilistic selection +- Patch-level random erasing augmentation +- Efficient GPU prefetching with normalization + +Hacked together by / Copyright 2025, Ross Wightman, Hugging Face +""" + +import math +from contextlib import suppress +from functools import partial +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union + + +import torch + +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .loader import _worker_init, adapt_to_chs +from .naflex_dataset import NaFlexMapDatasetWrapper, NaFlexCollator +from .naflex_random_erasing import PatchRandomErasing +from .transforms_factory import create_transform + + +class NaFlexPrefetchLoader: + """Data prefetcher for NaFlex format which normalizes patches.""" + + def __init__( + self, + loader: torch.utils.data.DataLoader, + mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, + std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, + channels: int = 3, + device: torch.device = torch.device('cuda'), + img_dtype: Optional[torch.dtype] = None, + re_prob: float = 0., + re_mode: str = 'const', + re_count: int = 1, + re_num_splits: int = 0, + ) -> None: + """Initialize NaFlexPrefetchLoader. + + Args: + loader: DataLoader to prefetch from. + mean: Mean values for normalization. + std: Standard deviation values for normalization. + channels: Number of image channels. + device: Device to move tensors to. + img_dtype: Data type for image tensors. + re_prob: Random erasing probability. + re_mode: Random erasing mode. + re_count: Maximum number of erasing rectangles. + re_num_splits: Number of augmentation splits. 
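+
+        Example (sketch; ``base_loader`` is an assumed DataLoader yielding NaFlex collated batches)::
+
+            >>> prefetcher = NaFlexPrefetchLoader(base_loader, device=torch.device('cuda'), re_prob=0.25)
+            >>> for input_dict, targets in prefetcher:
+            ...     patches = input_dict['patches']      # on device, normalized, cast to img_dtype
+            ...     coords = input_dict['patch_coord']   # (y, x) patch grid coordinates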
+ """ + self.loader = loader + self.device = device + self.img_dtype = img_dtype or torch.float32 + + # Create mean/std tensors for normalization (will be applied to patches) + mean = adapt_to_chs(mean, channels) + std = adapt_to_chs(std, channels) + normalization_shape = (1, 1, channels) + self.channels = channels + self.mean = torch.tensor( + [x * 255 for x in mean], device=device, dtype=self.img_dtype).view(normalization_shape) + self.std = torch.tensor( + [x * 255 for x in std], device=device, dtype=self.img_dtype).view(normalization_shape) + + if re_prob > 0.: + self.random_erasing = PatchRandomErasing( + erase_prob=re_prob, + mode=re_mode, + max_count=re_count, + num_splits=re_num_splits, + device=device, + ) + else: + self.random_erasing = None + + # Check for CUDA/NPU availability + self.is_cuda = device.type == 'cuda' and torch.cuda.is_available() + self.is_npu = device.type == 'npu' and torch.npu.is_available() + + def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]: + """Iterate through the loader with prefetching and normalization. + + Yields: + Tuple of (input_dict, targets) with normalized patches. + """ + first = True + if self.is_cuda: + stream = torch.cuda.Stream() + stream_context = partial(torch.cuda.stream, stream=stream) + elif self.is_npu: + stream = torch.npu.Stream() + stream_context = partial(torch.npu.stream, stream=stream) + else: + stream = None + stream_context = suppress + + for next_input_dict, next_target in self.loader: + with stream_context(): + # Move all tensors in input_dict to device + for k, v in next_input_dict.items(): + if isinstance(v, torch.Tensor): + dtype = self.img_dtype if k == 'patches' else None + next_input_dict[k] = next_input_dict[k].to( + device=self.device, + non_blocking=True, + dtype=dtype, + ) + + next_target = next_target.to(device=self.device, non_blocking=True) + + # Normalize patch values - handle both [B, N, P*P*C] and [B, N, Ph, Pw, C] formats + patches_tensor = next_input_dict['patches'] + original_shape = patches_tensor.shape + + if patches_tensor.ndim == 3: + # Format: [B, N, P*P*C] - flattened patches + batch_size, num_patches, patch_pixels = original_shape + # To [B*N, P*P, C] for normalization and erasing + patches = patches_tensor.view(batch_size, num_patches, -1, self.channels) + elif patches_tensor.ndim == 5: + # Format: [B, N, Ph, Pw, C] - unflattened patches (variable patch size mode) + batch_size, num_patches, patch_h, patch_w, channels = original_shape + assert channels == self.channels, f"Expected {self.channels} channels, got {channels}" + # To [B*N, Ph*Pw, C] for normalization and erasing + patches = patches_tensor.view(batch_size, num_patches, -1, self.channels) + else: + raise ValueError(f"Unexpected patches tensor dimensions: {patches_tensor.ndim}. 
Expected 3 or 5.") + + # Apply normalization + patches = patches.sub(self.mean).div(self.std) + + if self.random_erasing is not None: + patches = self.random_erasing( + patches, + patch_coord=next_input_dict['patch_coord'], + patch_valid=next_input_dict.get('patch_valid', None), + ) + + # Reshape back to original format + next_input_dict['patches'] = patches.view(original_shape) + + if not first: + yield input_dict, target + else: + first = False + + if stream is not None: + if self.is_cuda: + torch.cuda.current_stream().wait_stream(stream) + elif self.is_npu: + torch.npu.current_stream().wait_stream(stream) + + input_dict = next_input_dict + target = next_target + + yield input_dict, target + + def __len__(self) -> int: + """Get length of underlying loader. + + Returns: + Number of batches in the loader. + """ + return len(self.loader) + + @property + def sampler(self): + """Get sampler from underlying loader. + + Returns: + Sampler from the underlying DataLoader. + """ + return self.loader.sampler + + @property + def dataset(self): + """Get dataset from underlying loader. + + Returns: + Dataset from the underlying DataLoader. + """ + return self.loader.dataset + + +def create_naflex_loader( + dataset, + patch_size: Optional[Union[Tuple[int, int], int]] = None, + patch_size_choices: Optional[List[int]] = None, + patch_size_choice_probs: Optional[List[float]] = None, + train_seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024), + max_seq_len: int = 576, + batch_size: int = 32, + is_training: bool = False, + mixup_fn: Optional[Callable] = None, + + no_aug: bool = False, + re_prob: float = 0., + re_mode: str = 'const', + re_count: int = 1, + re_split: bool = False, + train_crop_mode: Optional[str] = None, + scale: Optional[Tuple[float, float]] = None, + ratio: Optional[Tuple[float, float]] = None, + hflip: float = 0.5, + vflip: float = 0., + color_jitter: float = 0.4, + color_jitter_prob: Optional[float] = None, + grayscale_prob: float = 0., + gaussian_blur_prob: float = 0., + auto_augment: Optional[str] = None, + num_aug_repeats: int = 0, + num_aug_splits: int = 0, + interpolation: str = 'bilinear', + mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, + std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, + crop_pct: Optional[float] = None, + crop_mode: Optional[str] = None, + crop_border_pixels: Optional[int] = None, + + num_workers: int = 4, + distributed: bool = False, + rank: int = 0, + world_size: int = 1, + seed: int = 42, + epoch: int = 0, + use_prefetcher: bool = True, + pin_memory: bool = True, + img_dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = torch.device('cuda'), + persistent_workers: bool = True, + worker_seeding: str = 'all', + ) -> Union[torch.utils.data.DataLoader, NaFlexPrefetchLoader]: + """Create a data loader with dynamic sequence length sampling for training. + + Args: + dataset: Dataset to load from. + patch_size: Single patch size to use. + patch_size_choices: List of patch sizes for variable patch size training. + patch_size_choice_probs: Probabilities for each patch size choice. + train_seq_lens: Training sequence lengths for dynamic batching. + max_seq_len: Fixed sequence length for validation. + batch_size: Batch size for validation and max training sequence length. + is_training: Whether this is for training (enables dynamic batching). + mixup_fn: Optional mixup function. + no_aug: Disable augmentation. + re_prob: Random erasing probability. + re_mode: Random erasing mode. + re_count: Maximum number of erasing rectangles. 
+ re_split: Random erasing split flag. + train_crop_mode: Training crop mode. + scale: Scale range for random resize crop. + ratio: Aspect ratio range for random resize crop. + hflip: Horizontal flip probability. + vflip: Vertical flip probability. + color_jitter: Color jitter factor. + color_jitter_prob: Color jitter probability. + grayscale_prob: Grayscale conversion probability. + gaussian_blur_prob: Gaussian blur probability. + auto_augment: AutoAugment policy. + num_aug_repeats: Number of augmentation repeats. + num_aug_splits: Number of augmentation splits. + interpolation: Interpolation method. + mean: Normalization mean values. + std: Normalization standard deviation values. + crop_pct: Crop percentage for validation. + crop_mode: Crop mode. + crop_border_pixels: Crop border pixels. + num_workers: Number of data loading workers. + distributed: Whether using distributed training. + rank: Process rank for distributed training. + world_size: Total number of processes. + seed: Random seed. + epoch: Starting epoch. + use_prefetcher: Whether to use prefetching. + pin_memory: Whether to pin memory. + img_dtype: Image data type. + device: Device to move tensors to. + persistent_workers: Whether to use persistent workers. + worker_seeding: Worker seeding mode. + + Returns: + DataLoader or NaFlexPrefetchLoader instance. + """ + + if is_training: + # For training, use the dynamic sequence length mechanism + assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader' + + transform_factory = partial( + create_transform, + is_training=True, + no_aug=no_aug, + train_crop_mode=train_crop_mode, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, + color_jitter=color_jitter, + color_jitter_prob=color_jitter_prob, + grayscale_prob=grayscale_prob, + gaussian_blur_prob=gaussian_blur_prob, + auto_augment=auto_augment, + interpolation=interpolation, + mean=mean, + std=std, + crop_pct=crop_pct, + crop_mode=crop_mode, + crop_border_pixels=crop_border_pixels, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + use_prefetcher=use_prefetcher, + naflex=True, + ) + + max_train_seq_len = max(train_seq_lens) + max_tokens_per_batch = batch_size * max_train_seq_len + + if isinstance(dataset, torch.utils.data.IterableDataset): + assert False, "IterableDataset Wrapper is a WIP" + + naflex_dataset = NaFlexMapDatasetWrapper( + dataset, + transform_factory=transform_factory, + patch_size=patch_size, + patch_size_choices=patch_size_choices, + patch_size_choice_probs=patch_size_choice_probs, + seq_lens=train_seq_lens, + max_tokens_per_batch=max_tokens_per_batch, + mixup_fn=mixup_fn, + seed=seed, + distributed=distributed, + rank=rank, + world_size=world_size, + shuffle=True, + epoch=epoch, + ) + + # NOTE: Collation is handled by the dataset wrapper for training + loader = torch.utils.data.DataLoader( + naflex_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers, + sampler=None, + pin_memory=pin_memory, + worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding), + persistent_workers=persistent_workers + ) + + if use_prefetcher: + loader = NaFlexPrefetchLoader( + loader, + mean=mean, + std=std, + img_dtype=img_dtype, + device=device, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + ) + + else: + # For validation, use fixed sequence length (unchanged) + dataset.transform = create_transform( + is_training=False, + interpolation=interpolation, + mean=mean, + std=std, + # FIXME add crop args when sequence transforms support crop modes + 
use_prefetcher=use_prefetcher, + naflex=True, + patch_size=patch_size, + max_seq_len=max_seq_len, + patchify=True, + ) + + # Create the collator + collate_fn = NaFlexCollator(max_seq_len=max_seq_len) + + # Handle distributed training + sampler = None + if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): + # For validation, use OrderedDistributedSampler + from timm.data.distributed_sampler import OrderedDistributedSampler + sampler = OrderedDistributedSampler(dataset) + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=False, + ) + + if use_prefetcher: + loader = NaFlexPrefetchLoader( + loader, + mean=mean, + std=std, + img_dtype=img_dtype, + device=device, + ) + + return loader diff --git a/timm/data/naflex_mixup.py b/timm/data/naflex_mixup.py new file mode 100644 index 0000000000..40185cb88f --- /dev/null +++ b/timm/data/naflex_mixup.py @@ -0,0 +1,250 @@ +"""Variable‑size Mixup / CutMix utilities for NaFlex data loaders. + +This module provides: + +* `mix_batch_variable_size` – pixel‑level Mixup/CutMix that operates on a + list of images whose spatial sizes differ, mixing only their central overlap + so no resizing is required. +* `pairwise_mixup_target` – builds soft‑label targets that exactly match the + per‑sample pixel provenance produced by the mixer. +* `NaFlexMixup` – a callable functor that wraps the two helpers and stores + all augmentation hyper‑parameters in one place, making it easy to plug into + different dataset wrappers. + +Hacked together by / Copyright 2025, Ross Wightman, Hugging Face +""" +import math +import random +from typing import Dict, List, Tuple, Union + +import torch + + +def mix_batch_variable_size( + imgs: List[torch.Tensor], + *, + mixup_alpha: float = 0.8, + cutmix_alpha: float = 1.0, + switch_prob: float = 0.5, + local_shuffle: int = 4, +) -> Tuple[List[torch.Tensor], List[float], Dict[int, int]]: + """Apply Mixup or CutMix on a batch of variable-sized images. + + Sorts images by aspect ratio and pairs neighboring samples. Only the mutual + central overlap region of each pair is mixed. + + Args: + imgs: List of transformed images shaped (C, H, W). + mixup_alpha: Beta distribution alpha for Mixup. Set to 0 to disable. + cutmix_alpha: Beta distribution alpha for CutMix. Set to 0 to disable. + switch_prob: Probability of using CutMix when both modes are enabled. + local_shuffle: Size of local windows for shuffling after aspect sorting. 
+ + Returns: + Tuple of (mixed_imgs, lam_list, pair_to) where: + - mixed_imgs: List of mixed images + - lam_list: Per-sample lambda values representing mixing degree + - pair_to: Mapping i -> j of which sample was mixed with which + """ + if len(imgs) < 2: + raise ValueError("Need at least two images to perform Mixup/CutMix.") + + # Decide augmentation mode and raw λ + if mixup_alpha > 0.0 and cutmix_alpha > 0.0: + use_cutmix = torch.rand(()).item() < switch_prob + alpha = cutmix_alpha if use_cutmix else mixup_alpha + elif mixup_alpha > 0.0: + use_cutmix = False + alpha = mixup_alpha + elif cutmix_alpha > 0.0: + use_cutmix = True + alpha = cutmix_alpha + else: + raise ValueError("Both mixup_alpha and cutmix_alpha are zero – nothing to do.") + + lam_raw = torch.distributions.Beta(alpha, alpha).sample().item() + lam_raw = max(0.0, min(1.0, lam_raw)) # numerical safety + + # Pair images by nearest aspect ratio + order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1]) + if local_shuffle > 1: + for start in range(0, len(order), local_shuffle): + random.shuffle(order[start:start + local_shuffle]) + + pair_to: Dict[int, int] = {} + for a, b in zip(order[::2], order[1::2]): + pair_to[a] = b + pair_to[b] = a + + odd_one = order[-1] if len(imgs) % 2 else None + + mixed_imgs: List[torch.Tensor] = [None] * len(imgs) + lam_list: List[float] = [1.0] * len(imgs) + + for i in range(len(imgs)): + if i == odd_one: + mixed_imgs[i] = imgs[i] + continue + + j = pair_to[i] + xi, xj = imgs[i], imgs[j] + _, hi, wi = xi.shape + _, hj, wj = xj.shape + dest_area = hi * wi + + # Central overlap common to both images + oh, ow = min(hi, hj), min(wi, wj) + overlap_area = oh * ow + top_i, left_i = (hi - oh) // 2, (wi - ow) // 2 + top_j, left_j = (hj - oh) // 2, (wj - ow) // 2 + + xi = xi.clone() + if use_cutmix: + # CutMix: random rectangle inside the overlap + cut_ratio = math.sqrt(1.0 - lam_raw) + ch, cw = int(oh * cut_ratio), int(ow * cut_ratio) + cut_area = ch * cw + y_off = random.randint(0, oh - ch) + x_off = random.randint(0, ow - cw) + + yl_i, xl_i = top_i + y_off, left_i + x_off + yl_j, xl_j = top_j + y_off, left_j + x_off + xi[:, yl_i: yl_i + ch, xl_i: xl_i + cw] = xj[:, yl_j: yl_j + ch, xl_j: xl_j + cw] + mixed_imgs[i] = xi + + corrected_lam = 1.0 - cut_area / float(dest_area) + lam_list[i] = corrected_lam + else: + # Mixup: blend the entire overlap region + patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow] + patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow] + + blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw) + xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended + mixed_imgs[i] = xi + + corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area + lam_list[i] = corrected_lam + + return mixed_imgs, lam_list, pair_to + + +def smoothed_sparse_target( + targets: torch.Tensor, + *, + num_classes: int, + smoothing: float = 0.0, +) -> torch.Tensor: + off_val = smoothing / num_classes + on_val = 1.0 - smoothing + off_val + + y_onehot = torch.full( + (targets.size(0), num_classes), + off_val, + dtype=torch.float32, + device=targets.device + ) + y_onehot.scatter_(1, targets.unsqueeze(1), on_val) + return y_onehot + + +def pairwise_mixup_target( + targets: torch.Tensor, + pair_to: Dict[int, int], + lam_list: List[float], + *, + num_classes: int, + smoothing: float = 0.0, +) -> torch.Tensor: + """Create soft targets that match the pixel‑level mixing performed. + + Args: + targets: (B,) tensor of integer class indices. 
+ pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size(). + lam_list: Per‑sample fractions of own pixels, also from the mixer. + num_classes: Total number of classes in the dataset. + smoothing: Label‑smoothing value in the range [0, 1). + + Returns: + Tensor of shape (B, num_classes) whose rows sum to 1. + """ + y_onehot = smoothed_sparse_target(targets, num_classes=num_classes, smoothing=smoothing) + targets = y_onehot.clone() + for i, j in pair_to.items(): + lam = lam_list[i] + targets[i].mul_(lam).add_(y_onehot[j], alpha=1.0 - lam) + + return targets + + +class NaFlexMixup: + """Callable wrapper that combines mixing and target generation.""" + + def __init__( + self, + *, + num_classes: int, + mixup_alpha: float = 0.8, + cutmix_alpha: float = 1.0, + switch_prob: float = 0.5, + prob: float = 1.0, + local_shuffle: int = 4, + label_smoothing: float = 0.0, + ) -> None: + """Configure the augmentation. + + Args: + num_classes: Total number of classes. + mixup_alpha: Beta α for Mixup. 0 disables Mixup. + cutmix_alpha: Beta α for CutMix. 0 disables CutMix. + switch_prob: Probability of selecting CutMix when both modes are enabled. + prob: Probability of applying any mixing per batch. + local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs. + smoothing: Label‑smoothing value. 0 disables smoothing. + """ + self.num_classes = num_classes + self.mixup_alpha = mixup_alpha + self.cutmix_alpha = cutmix_alpha + self.switch_prob = switch_prob + self.prob = prob + self.local_shuffle = local_shuffle + self.smoothing = label_smoothing + + def __call__( + self, + imgs: List[torch.Tensor], + targets: torch.Tensor, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Apply the augmentation and generate matching targets. + + Args: + imgs: List of already transformed images shaped (C, H, W). + targets: Hard labels with shape (B,). + + Returns: + mixed_imgs: List of mixed images in the same order and shapes as the input. + targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets. + """ + if not isinstance(targets, torch.Tensor): + targets = torch.tensor(targets) + + if random.random() > self.prob: + targets = smoothed_sparse_target(targets, num_classes=self.num_classes, smoothing=self.smoothing) + return imgs, targets.unbind(0) + + mixed_imgs, lam_list, pair_to = mix_batch_variable_size( + imgs, + mixup_alpha=self.mixup_alpha, + cutmix_alpha=self.cutmix_alpha, + switch_prob=self.switch_prob, + local_shuffle=self.local_shuffle, + ) + + targets = pairwise_mixup_target( + targets, + pair_to, + lam_list, + num_classes=self.num_classes, + smoothing=self.smoothing, + ) + return mixed_imgs, targets.unbind(0) diff --git a/timm/data/naflex_random_erasing.py b/timm/data/naflex_random_erasing.py new file mode 100644 index 0000000000..f1cbeb8ac7 --- /dev/null +++ b/timm/data/naflex_random_erasing.py @@ -0,0 +1,354 @@ +"""Patch-level random erasing augmentation for NaFlex Vision Transformers. + +This module implements random erasing specifically designed for patchified images, +operating at the patch granularity rather than pixel level. It supports two modes: +- 'patch': Randomly erases individual patches (speckle-like noise) +- 'region': Erases contiguous rectangular regions of patches (similar to original RandomErasing) + +The implementation is coordinate-aware, respecting valid patch boundaries and supporting +variable patch sizes in NaFlex training. 
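+
+Example (sketch; tensor shapes follow the NaFlex collator output)::
+
+    >>> erase = PatchRandomErasing(erase_prob=0.5, mode='pixel', spatial_mode='region', device='cpu')
+    >>> patches = erase(patches, patch_coord, patch_valid)  # modifies selected patches in place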
+ +Hacked together by / Copyright 2025, Ross Wightman, Hugging Face +""" + +import random +import math +from typing import Optional, Union, Tuple + +import torch + + +class PatchRandomErasing: + """Random erasing for patchified images in NaFlex format. + + Supports two modes: + 1. 'patch': Simple mode that erases randomly selected valid patches + 2. 'region': Erases rectangular regions at patch granularity + """ + + def __init__( + self, + erase_prob: float = 0.5, + patch_drop_prob: float = 0.0, + min_count: int = 1, + max_count: Optional[int] = None, + min_area: float = 0.02, + max_area: float = 1 / 3, + min_aspect: float = 0.3, + max_aspect: Optional[float] = None, + mode: str = 'const', + value: float = 0., + spatial_mode: str = 'region', + num_splits: int = 0, + device: Union[str, torch.device] = 'cuda', + ) -> None: + """Initialize PatchRandomErasing. + + Args: + erase_prob: Probability that the Random Erasing operation will be performed. + patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing. + min_count: Minimum number of erasing operations. + max_count: Maximum number of erasing operations. + min_area: Minimum percentage of valid patches/area to erase. + max_area: Maximum percentage of valid patches/area to erase. + min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode). + max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode). + mode: Patch content mode, one of 'const', 'rand', or 'pixel'. + value: Constant value for 'const' mode. + spatial_mode: Erasing strategy, one of 'patch' or 'region'. + num_splits: Number of splits to apply erasing to (0 for all). + device: Computation device. + """ + self.erase_prob = erase_prob + self.patch_drop_prob = patch_drop_prob + self.min_count = min_count + self.max_count = max_count or min_count + self.min_area = min_area + self.max_area = max_area + + # Aspect ratio params (for region mode) + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + # Number of splits + self.num_splits = num_splits + self.device = device + + # Strategy mode + self.spatial_mode = spatial_mode + assert self.spatial_mode in ('patch', 'region') + + # Value generation mode flags + self.erase_mode = mode.lower() + assert self.erase_mode in ('rand', 'pixel', 'const') + self.const_value = value + self.unique_noise_per_patch = True + + def _get_values( + self, + shape: Union[Tuple[int, ...], torch.Size], + value: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + device: Optional[Union[str, torch.device]] = None + ) -> torch.Tensor: + """Generate values for erased patches based on the specified mode. + + Args: + shape: Shape of patches to erase. + value: Value to use in const (or rand) mode. + dtype: Data type to use. + device: Device to use. + + Returns: + Tensor with values for erasing patches. 
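+
+        For example, 'pixel' mode returns per-pixel normal noise matching ``shape`` exactly, while
+        'const' and 'rand' modes return a single (1, 1, C) or (1, C) value broadcast over the
+        erased patches.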
+ """ + device = device or self.device + if self.erase_mode == 'pixel': + # only mode with erase shape that includes pixels + return torch.empty(shape, dtype=dtype, device=device).normal_() + else: + shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1]) + if self.erase_mode == 'const' or value is not None: + erase_value = value or self.const_value + if isinstance(erase_value, (int, float)): + values = torch.full(shape, erase_value, dtype=dtype, device=device) + else: + erase_value = torch.tensor(erase_value, dtype=dtype, device=device) + values = torch.expand_copy(erase_value, shape) + else: + values = torch.empty(shape, dtype=dtype, device=device).normal_() + return values + + def _drop_patches( + self, + patches: torch.Tensor, + patch_coord: torch.Tensor, + patch_valid: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Patch Dropout. + + Fully drops patches from datastream. Only mode that saves compute BUT requires support + for non-contiguous patches and associated patch coordinate and valid handling. + + Args: + patches: Tensor of patches. + patch_coord: Tensor of patch coordinates. + patch_valid: Tensor indicating which patches are valid. + + Returns: + Tuple of (patches, patch_coord, patch_valid) with some patches dropped. + """ + # FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches + if random.random() > self.erase_prob: + return + + # Get indices of valid patches + valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist() + + # Skip if no valid patches + if not valid_indices: + return patches, patch_coord, patch_valid + + num_valid = len(valid_indices) + if self.patch_drop_prob: + # patch dropout mode, completely remove dropped patches (FIXME needs downstream support in model) + num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob))) + keep_indices = torch.argsort(torch.randn(1, num_valid, device=self.device), dim=-1)[:, :num_keep] + # maintain patch order, possibly useful for debug / visualization + keep_indices = keep_indices.sort(dim=-1)[0] + patches = patches.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + patches.shape[2:])) + + return patches, patch_coord, patch_valid + + def _erase_patches( + self, + patches: torch.Tensor, + patch_coord: torch.Tensor, + patch_valid: torch.Tensor, + patch_shape: torch.Size, + dtype: torch.dtype = torch.float32, + ) -> None: + """Apply erasing by selecting individual patches randomly. + + The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles' + noise augmentation at patch size. + + Args: + patches: Tensor of patches to modify in-place. + patch_coord: Tensor of patch coordinates. + patch_valid: Tensor indicating which patches are valid. + patch_shape: Shape of individual patches. + dtype: Data type for generated values. 
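+
+        For example (defaults assumed: count=1, min_area=0.02, max_area=1/3), a sample with 100
+        valid patches has between 2 and 33 of them erased when the erase probability roll succeeds.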
+ """ + if random.random() > self.erase_prob: + return + + # Get indices of valid patches + valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0] + num_valid = len(valid_indices) + if num_valid == 0: + return + + count = random.randint(self.min_count, self.max_count) + # Determine how many valid patches to erase from RE min/max count and area args + max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area))) + min_erase = max(1, int(num_valid * count * self.min_area)) + num_erase = random.randint(min_erase, max_erase) + + # Randomly select valid patches to erase + erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]] + + if self.unique_noise_per_patch and self.erase_mode == 'pixel': + # generate unique noise for the whole selection of patches + fill_shape = (num_erase,) + patch_shape + else: + fill_shape = patch_shape + + patches[erase_idx] = self._get_values(fill_shape, dtype=dtype) + + def _erase_region( + self, + patches: torch.Tensor, + patch_coord: torch.Tensor, + patch_valid: torch.Tensor, + patch_shape: torch.Size, + dtype: torch.dtype = torch.float32, + ) -> None: + """Apply erasing by selecting rectangular regions of patches randomly. + + Closer to the original RandomErasing implementation. Erases + spatially contiguous rectangular regions of patches (aligned with patches). + + Args: + patches: Tensor of patches to modify in-place. + patch_coord: Tensor of patch coordinates. + patch_valid: Tensor indicating which patches are valid. + patch_shape: Shape of individual patches. + dtype: Data type for generated values. + """ + if random.random() > self.erase_prob: + return + + # Determine grid dimensions from coordinates + valid_coord = patch_coord[patch_valid] + if len(valid_coord) == 0: + return # No valid patches + max_y = valid_coord[:, 0].max().item() + 1 + max_x = valid_coord[:, 1].max().item() + 1 + grid_h, grid_w = max_y, max_x + total_area = grid_h * grid_w + ys, xs = patch_coord[:, 0], patch_coord[:, 1] + + count = random.randint(self.min_count, self.max_count) + for _ in range(count): + # Try to select a valid region to erase (multiple attempts) + for attempt in range(10): + # Sample random area and aspect ratio + target_area = random.uniform(self.min_area, self.max_area) * total_area + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + + # Calculate region height and width + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + + if h > grid_h or w > grid_w: + continue # try again + + # Calculate region patch bounds + top = random.randint(0, grid_h - h) + left = random.randint(0, grid_w - w) + bottom, right = top + h, left + w + + # Region test + region_mask = ( + (ys >= top) & (ys < bottom) & + (xs >= left) & (xs < right) & + patch_valid + ) + num_selected = int(region_mask.sum().item()) + if not num_selected: + continue # no patch actually falls inside – try again + + if self.unique_noise_per_patch and self.erase_mode == 'pixel': + # generate unique noise for the whole region + fill_shape = (num_selected,) + patch_shape + else: + fill_shape = patch_shape + + patches[region_mask] = self._get_values(fill_shape, dtype=dtype) + # Successfully applied erasing, exit the loop + break + + def __call__( + self, + patches: torch.Tensor, + patch_coord: torch.Tensor, + patch_valid: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply random patch erasing. + + Args: + patches: Tensor of shape [B, N, P*P, C] or [B, N, Ph, Pw, C]. 
+ patch_coord: Tensor of shape [B, N, 2] with (y, x) coordinates. + patch_valid: Boolean tensor of shape [B, N] indicating which patches are valid. + + Returns: + Erased patches tensor of same shape as input. + """ + if patches.ndim == 4: + batch_size, num_patches, patch_dim, channels = patches.shape + elif patches.ndim == 5: + batch_size, num_patches, patch_h, patch_w, channels = patches.shape + else: + assert False + patch_shape = patches.shape[2:] + # patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c) + + # Create default valid mask if not provided + if patch_valid is None: + patch_valid = torch.ones((batch_size, num_patches), dtype=torch.bool, device=patches.device) + + # Skip the first part of the batch if num_splits is set + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + + # Apply erasing to each batch element + for i in range(batch_start, batch_size): + if self.patch_drop_prob: + assert False, "WIP, not completed" + self._drop_patches( + patches[i], + patch_coord[i], + patch_valid[i], + ) + elif self.spatial_mode == 'patch': + # FIXME we could vectorize patch mode across batch, worth the effort? + self._erase_patches( + patches[i], + patch_coord[i], + patch_valid[i], + patch_shape, + patches.dtype + ) + elif self.spatial_mode == 'region': + self._erase_region( + patches[i], + patch_coord[i], + patch_valid[i], + patch_shape, + patches.dtype + ) + else: + assert False + + return patches + + def __repr__(self) -> str: + """Return string representation of PatchRandomErasing. + + Returns: + String representation of the object. + """ + fs = self.__class__.__name__ + f'(p={self.erase_prob}, mode={self.erase_mode}' + fs += f', spatial={self.spatial_mode}, area=({self.min_area}, {self.max_area}))' + fs += f', count=({self.min_count}, {self.max_count}))' + return fs \ No newline at end of file diff --git a/timm/data/naflex_transforms.py b/timm/data/naflex_transforms.py new file mode 100644 index 0000000000..23308f4420 --- /dev/null +++ b/timm/data/naflex_transforms.py @@ -0,0 +1,821 @@ +""" NaFlex (NaViT + FlexiViT) Transforms and Collation + +Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers: +- NaViT: https://arxiv.org/abs/2307.14995 +- FlexiViT: https://arxiv.org/abs/2212.08013 + +Enables variable resolution/aspect ratio image handling with efficient patching. + +Hacked together by / Copyright 2025, Ross Wightman, Hugging Face +""" + +import math +import random +import warnings +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +from PIL import Image +from torchvision import transforms +from torchvision.transforms import functional as F +from torchvision.transforms.functional import InterpolationMode + +from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad + + +def get_image_size_for_seq( + image_hw: Tuple[int, int], + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 1024, + divisible_by_patch: bool = True, + max_ratio: Optional[float] = None, + eps: float = 1e-5, +) -> Tuple[float, Tuple[int, int]]: + """Determine scaling ratio and image size for sequence length constraint. + + Calculates the scaling ratio needed so that when image_hw is scaled, + the total number of resulting patches does not exceed max_seq_len. + + Args: + image_hw: Original image dimensions (height, width). + patch_size: Patch dimensions. If int, patches are square. + max_seq_len: Maximum allowed sequence length. 
+ divisible_by_patch: Whether resulting dimensions must be divisible by patch_size. + max_ratio: Optional cap on scaling ratio to prevent excessive upsampling. + eps: Convergence threshold for binary search. + + Returns: + Tuple of (ratio, target_hw) where ratio is the scaling factor and + target_hw is the resulting (height, width) after scaling. + """ + + # Handle patch size input, extract patch_h, patch_w + if isinstance(patch_size, int): + patch_h, patch_w = patch_size, patch_size + else: + # Assume it's a tuple/list: (patch_h, patch_w) + if len(patch_size) != 2: + raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).") + patch_h, patch_w = patch_size + + # Safety checks + if patch_h <= 0 or patch_w <= 0: + raise ValueError("patch_size dimensions must be positive.") + + def prepare_target_hw(ratio): + """Scale image_hw by ratio and optionally round dimensions to multiples of patch_h, patch_w.""" + scaled_h = image_hw[0] * ratio + scaled_w = image_hw[1] * ratio + + # If we need the result to be divisible by patch_size + if divisible_by_patch: + scaled_h = patch_h * math.ceil(scaled_h / patch_h) + scaled_w = patch_w * math.ceil(scaled_w / patch_w) + + # Ensure at least one patch in each dimension + scaled_h = int(max(scaled_h, patch_h)) + scaled_w = int(max(scaled_w, patch_w)) + + return scaled_h, scaled_w + + def is_feasible(ratio): + """Check if scaling by 'ratio' keeps patch count within max_seq_len.""" + t_h, t_w = prepare_target_hw(ratio) + + # Each dimension is already a multiple of patch_h, patch_w if divisible_by_patch=True. + # Use integer division to count patches. + num_patches_h = t_h // patch_h + num_patches_w = t_w // patch_w + seq_len = num_patches_h * num_patches_w + + return seq_len <= max_seq_len + + # Binary search boundaries + lb = eps / 10.0 + rb = 100.0 + + # Standard binary search loop + while (rb - lb) >= eps: + mid = (lb + rb) / 2.0 + if is_feasible(mid): + lb = mid + else: + rb = mid + + # The final ratio from the binary search + ratio = lb + + # If max_ratio is provided, clamp it to prevent upsampling beyond that threshold + if max_ratio is not None: + ratio = min(ratio, max_ratio) + + # Final checks + if ratio <= eps: + raise ValueError("Binary search failed - image might be too large?") + if ratio >= 100.0: + raise ValueError("Binary search failed - image might be too small?") + + # Prepare the final target dimensions with the possibly clamped ratio + target_hw = prepare_target_hw(ratio) + return ratio, target_hw + + +_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic')) + + +class ResizeToSequence(torch.nn.Module): + """Resize image to fit within a maximum sequence length constraint when patchified. + + This maintains aspect ratio while ensuring the resulting image, when divided into patches, + will not exceed the specified maximum sequence length. + """ + def __init__( + self, + patch_size: int, + max_seq_len: int = 1024, + divisible_by_patch: bool = True, + max_ratio: Optional[float] = None, + interpolation: Union[str, InterpolationMode, Tuple[InterpolationMode, ...]] = 'bicubic', + ) -> None: + """Initialize ResizeToSequence transform. + + Args: + patch_size: Size of patches. + max_seq_len: Maximum sequence length constraint. + divisible_by_patch: Whether dimensions must be divisible by patch_size. + max_ratio: Optional cap on scaling ratio. + interpolation: Interpolation method or methods. 
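+
+        Example (illustrative):
+            With patch_size=16 and max_seq_len=256, a 512x512 input is resized
+            to 256x256 (a 16x16 grid of patches), the largest aspect-preserving
+            size whose patch count stays within the sequence limit.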
+ """ + super().__init__() + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.divisible_by_patch = divisible_by_patch + self.max_ratio = max_ratio + if isinstance(interpolation, str): + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = str_to_interp_mode(interpolation) + else: + self.interpolation = interpolation + + + def forward(self, img: torch.Tensor) -> torch.Tensor: + """Resize image to maintain aspect ratio and fit sequence constraint. + + Args: + img: Input image tensor. + + Returns: + Resized image tensor. + """ + _, h, w = transforms.functional.get_dimensions(img) + + _, target_hw = get_image_size_for_seq( + (h, w), + self.patch_size, + self.max_seq_len, + divisible_by_patch=self.divisible_by_patch, + max_ratio=self.max_ratio, + ) + + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + + resized_img = transforms.functional.resize(img, target_hw, interpolation=interpolation, antialias=True) + + return resized_img + + +class ResizeKeepRatioToSequence(torch.nn.Module): + """ + Resize and Keep Aspect Ratio, adapted to fit sequence length constraints. + """ + + def __init__( + self, + patch_size=16, + max_sequence_len=1024, + divisible_by_patch=True, + longest=0., + interpolation='bilinear', + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_scale_area=False, + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11), + max_ratio=None, + ): + """ + Args: + patch_size: Size of patches (int or tuple of (patch_h, patch_w)) + max_sequence_len: Maximum allowed sequence length for the resulting image + divisible_by_patch: If True, ensure dimensions are divisible by patch_size + longest: Float between 0-1 where 0=shortest side, 1=longest side determines scale + interpolation: Interpolation method for resizing + random_scale_prob: Probability of applying random scaling + random_scale_range: Range for random scaling factor (min, max) + random_scale_area: If True, scale factors affect area (√ factor) + random_aspect_prob: Probability of applying random aspect ratio jittering + random_aspect_range: Range for random aspect ratio (min, max) + max_ratio: Maximum allowed scaling ratio + """ + super().__init__() + self.patch_size = patch_size + self.max_sequence_len = max_sequence_len + self.divisible_by_patch = divisible_by_patch + self.longest = float(longest) + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = str_to_interp_mode(interpolation) + + self.random_scale_prob = random_scale_prob + self.random_scale_range = random_scale_range + self.random_scale_area = random_scale_area + self.random_aspect_prob = random_aspect_prob + self.random_aspect_range = random_aspect_range + self.max_ratio = max_ratio + + @staticmethod + def get_params( + img, + patch_size, + max_sequence_len, + divisible_by_patch, + longest, + random_scale_prob=0., + random_scale_range=(1.0, 1.33), + random_scale_area=False, + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11), + max_ratio=None, + ): + """Get parameters for resizing.""" + # Get image dimensions + img_h, img_w = F.get_dimensions(img)[1:] + + # Step 1: Get the maximum allowed dimensions from sequence length constraint + _, target_hw = get_image_size_for_seq( + (img_h, img_w), + patch_size, + max_sequence_len, + divisible_by_patch, + max_ratio, + ) + target_h, target_w = target_hw + + # Calculate ratio based on 
sequence constraint + ratio_h = target_h / img_h + ratio_w = target_w / img_w + # Apply longest blending + ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) + + # Apply random scaling + if random_scale_prob > 0 and random.random() < random_scale_prob: + ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) + if random_scale_area: + # Make ratio factor equivalent to area change + ratio_factor = 1. / math.sqrt(ratio_factor) + ratio_factor = (ratio_factor, ratio_factor) + else: + ratio_factor = (1., 1.) + + # Apply random aspect + if random_aspect_prob > 0 and random.random() < random_aspect_prob: + log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1])) + aspect_factor = math.exp(random.uniform(*log_aspect)) + aspect_factor = math.sqrt(aspect_factor) + # Apply aspect ratio jittering + ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) + + # Calculate final dimensions + size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)] + + # Ensure dimensions satisfy sequence constraint and are divisible by patch size + if isinstance(patch_size, int): + ph, pw = patch_size, patch_size + else: + ph, pw = patch_size + + # Ensure dimensions are at least one patch + size[0] = max(size[0], ph) + size[1] = max(size[1], pw) + + # Make divisible by patch size if needed + if divisible_by_patch: + size[0] = ph * math.ceil(size[0] / ph) + size[1] = pw * math.ceil(size[1] / pw) + + # Verify we haven't exceeded sequence length + num_patches_h = size[0] // ph + num_patches_w = size[1] // pw + seq_len = num_patches_h * num_patches_w + + if seq_len > max_sequence_len: + # Scale back down to fit sequence constraint + scale_back = math.sqrt(max_sequence_len / seq_len) + size[0] = int(size[0] * scale_back) + size[1] = int(size[1] * scale_back) + + # Ensure divisible by patch size after scaling back + if divisible_by_patch: + size[0] = ph * math.ceil(size[0] / ph) + size[1] = pw * math.ceil(size[1] / pw) + + return size + + def forward(self, img): + """ + Resize the image with aspect ratio preservation and sequence length constraints. 
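+
+        Note: with longest=0. the smaller of the two per-side scale ratios is
+        used (shortest-edge style resize), with longest=1. the larger is used;
+        intermediate values blend the two.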
+ """ + size = self.get_params( + img, + self.patch_size, + self.max_sequence_len, + self.divisible_by_patch, + self.longest, + self.random_scale_prob, + self.random_scale_range, + self.random_scale_area, + self.random_aspect_prob, + self.random_aspect_range, + self.max_ratio, + ) + + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + + return F.resize(img, size, interpolation) + + def __repr__(self): + interpolate_str = "random" if isinstance(self.interpolation, (tuple, list)) else str(self.interpolation) + return (f"{self.__class__.__name__}(patch_size={self.patch_size}, " + f"max_sequence_len={self.max_sequence_len}, " + f"longest={self.longest:.3f}, " + f"random_scale_prob={self.random_scale_prob:.3f}, " + f"random_aspect_prob={self.random_aspect_prob:.3f})") + + +class CenterCropToSequence(torch.nn.Module): + """Center crop the image such that the resulting patch sequence length meets constraints.""" + def __init__( + self, + patch_size: int, + max_seq_len: int, + divisible_by_patch: bool = True, + fill: Union[int, Tuple[int, int, int]] = 0, + padding_mode: str = 'constant' + ): + super().__init__() + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.divisible_by_patch = divisible_by_patch + self.fill = fill + self.padding_mode = padding_mode + + + def forward(self, img): + """Center crop the image to maintain aspect ratio and fit sequence constraint.""" + _, h, w = transforms.functional.get_dimensions(img) + _, target_hw = get_image_size_for_seq( + (h, w), + self.patch_size, + self.max_seq_len, + self.divisible_by_patch + ) + + # Use center crop + return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode) + + +class RandomCropToSequence(torch.nn.Module): + """Randomly crop and/or pad the image to fit sequence length constraints. + + This maintains aspect ratio while ensuring the resulting image, when divided into patches, + will not exceed the specified maximum sequence length. Similar to CentralCropToSequence + but with randomized positioning. 
+ """ + + def __init__( + self, + patch_size: int, + max_sequence_len: int, + divisible_by_patch: bool = True, + fill: Union[int, Tuple[int, int, int]] = 0, + padding_mode: str = 'constant' + ): + """ + Args: + patch_size: Size of patches (int or tuple of (patch_h, patch_w)) + max_sequence_len: Maximum allowed sequence length for the resulting image + divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size + fill: Fill value for padding + padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric') + """ + super().__init__() + self.patch_size = patch_size + self.max_sequence_len = max_sequence_len + self.divisible_by_patch = divisible_by_patch + self.fill = fill + self.padding_mode = padding_mode + + @staticmethod + def get_params(img, target_size): + """Get random position for crop/pad.""" + _, image_height, image_width = transforms.functional.get_dimensions(img) + delta_height = image_height - target_size[0] + delta_width = image_width - target_size[1] + + # Handle both positive (crop) and negative (pad) deltas + if delta_height == 0: + top = 0 + else: + top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height)) + + if delta_width == 0: + left = 0 + else: + left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width)) + + return top, left + + def forward(self, img): + """Randomly crop or pad the image to maintain aspect ratio and fit sequence constraint.""" + # Get current dimensions + _, img_h, img_w = transforms.functional.get_dimensions(img) + + # Calculate target dimensions that satisfy sequence length + # We use max_ratio=1.0 to prevent upscaling - we only want to crop or maintain current size + _, target_hw = get_image_size_for_seq( + (img_h, img_w), + self.patch_size, + self.max_sequence_len, + self.divisible_by_patch, + max_ratio=1.0 # Prevent upscaling + ) + + # Get random position for crop/pad + top, left = self.get_params(img, target_hw) + + # Apply crop or pad + return crop_or_pad( + img, + top=top, + left=left, + height=target_hw[0], + width=target_hw[1], + fill=self.fill, + padding_mode=self.padding_mode, + ) + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(patch_size={self.patch_size}, " + f"max_sequence_len={self.max_sequence_len}, " + f"divisible_by_patch={self.divisible_by_patch})") + + +def _validate_range(value, name, length=2): + # Validate type and length + if not isinstance(value, Sequence) or len(value) != length: + raise ValueError(f"{name} should be a sequence of length {length}.") + + # Validate order + if value[0] > value[1]: + warnings.warn(f"{name.capitalize()} range reversed. Swapping.") + return value[1], value[0] + + return value + + +class RandomResizedCropToSequence(torch.nn.Module): + """ + Randomly crop the input image to a subregion with varying area and aspect ratio + (relative to the original), then resize that crop to a target size. The target size + is determined such that patchifying the resized image (with `patch_size`) + does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop. + + This combines aspects of torchvision's RandomResizedCrop with sequence length constraints. + + Args: + patch_size (int or tuple[int, int]): + Patch dimensions (patch_h, patch_w) for sequence length calculation. + max_seq_len (int): + Maximum number of patches allowed in the final image. + scale (tuple[float, float]): + Range (min, max) of area fraction of the original image to crop. 
+ ratio (tuple[float, float]): + Range (min, max) of aspect ratio *multipliers* for the crop, relative + to the original image's aspect ratio. E.g., (0.75, 1.333) means the + crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar. + Uses log-uniform sampling. + interpolation (str or InterpolationMode): + Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest', + or 'random' (chooses between bilinear and bicubic). + Defaults to 'bicubic'. + divisible_by_patch (bool): + If True, the final image height and width will be multiples of the + respective patch dimensions. Defaults to True. + max_ratio (float, optional): + An optional upper limit on the scaling ratio applied during resizing. + Prevents excessive upsampling of the initial crop. `max_ratio=1.0` + prevents any upsampling beyond the cropped size. Defaults to None (no limit). + final_scale_range (tuple[float, float], optional): + If provided, applies an *additional* random scaling factor to the + final target size. The factor is sampled uniformly from this range, + and multiplied by the size determined by `get_image_size_for_seq`. + E.g., (0.8, 1.0) means the final size will be between 80% and 100% + of the maximum feasible size. Defaults to None (use maximum feasible size). + attempts (int): + Number of attempts to sample a valid crop geometry before falling back + to a center crop strategy. Defaults to 10. + """ + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 1024, + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (.8, 1.25), + interpolation: Union[str, InterpolationMode] = 'bicubic', + divisible_by_patch: bool = True, + max_ratio: Optional[float] = None, + final_scale_range: Optional[Tuple[float, float]] = None, + attempts: int = 10, + ): + super().__init__() + if isinstance(patch_size, int): + self.patch_h, self.patch_w = patch_size, patch_size + else: + # Assume it's a tuple/list: (patch_h, patch_w) + if len(patch_size) != 2: + raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).") + self.patch_h, self.patch_w = patch_size + self.max_seq_len = max_seq_len + self.scale = scale + self.ratio = ratio + self.divisible_by_patch = divisible_by_patch + self.max_ratio = max_ratio + self.final_scale_range = final_scale_range + self.attempts = attempts + if isinstance(interpolation, str): + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = str_to_interp_mode(interpolation) + else: + self.interpolation = interpolation + + # Validate scale and ratio + self.scale = _validate_range(self.scale, "scale") + self.ratio = _validate_range(self.ratio, "ratio") + + # Validate final_scale_range if provided + if self.final_scale_range is not None: + self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range") + + # Additional validation for final_scale_range values + if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0): + warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.") + + @staticmethod + def get_params( + img: torch.Tensor, + scale: Tuple[float, float], + ratio: Tuple[float, float], + crop_attempts: int = 10, + patch_h: int = 16, + patch_w: int = 16, + max_seq_len: int = 1024, + divisible_by_patch: bool = True, + max_ratio: Optional[float] = None, + final_scale_range: Optional[Tuple[float, float]] = None, + interpolation: Union[List[InterpolationMode], InterpolationMode] = 
_RANDOM_INTERPOLATION, + ) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]: + """ Get parameters for a random sized crop relative to image aspect ratio. + """ + _, height, width = F.get_dimensions(img) + if height <= 0 or width <= 0: + raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}") + + area = height * width + orig_aspect = width / height + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + + for _ in range(crop_attempts): + target_area = area * random.uniform(scale[0], scale[1]) + aspect_ratio_factor = math.exp(random.uniform(log_ratio[0], log_ratio[1])) + aspect_ratio = orig_aspect * aspect_ratio_factor + + # Calculate target dimensions for the crop + # target_area = crop_w * crop_h, aspect_ratio = crop_w / crop_h + # => crop_h = sqrt(target_area / aspect_ratio) + # => crop_w = sqrt(target_area * aspect_ratio) + crop_h = int(round(math.sqrt(target_area / aspect_ratio))) + crop_w = int(round(math.sqrt(target_area * aspect_ratio))) + + if 0 < crop_w <= width and 0 < crop_h <= height: + top = random.randint(0, height - crop_h) + left = random.randint(0, width - crop_w) + break + else: + # Fallback strategy, use center crop trying to respect ratio range + min_aspect_ratio = orig_aspect * ratio[0] + max_aspect_ratio = orig_aspect * ratio[1] + + if orig_aspect < min_aspect_ratio: + # Original is narrower than target min, clamp width + crop_w = width + crop_h = min(int(round(crop_w / min_aspect_ratio)), height) + elif orig_aspect > max_aspect_ratio: + # Original is wider than target max, clamp height + crop_h = height + crop_w = min(int(round(crop_h * max_aspect_ratio)), width) + else: + # Aspect ratio is within range, take the largest possible crop (full image) + crop_w = width + crop_h = height + + # Ensure valid dimensions after fallback calculation + crop_h = max(1, crop_h) + crop_w = max(1, crop_w) + + top = (height - crop_h) // 2 + left = (width - crop_w) // 2 + + # Determine max feasible size for scaling of the *cropped* region + feasible_ratio, feasible_size = get_image_size_for_seq( + (crop_h, crop_w), + patch_size=(patch_h, patch_w), # Pass as tuple + max_seq_len=max_seq_len, + divisible_by_patch=divisible_by_patch, + max_ratio=max_ratio, + ) + + # Optionally apply final scale randomization + final_size = feasible_size + if final_scale_range is not None: + min_sc, max_sc = final_scale_range + scale_factor = random.uniform(min_sc, max_sc) + scale_factor = min(max(scale_factor, 0.0), 1.0) # Clamp factor just in case + + # Calculate raw scaled size + # Note: feasible_ratio already accounts for max_ratio clamp if any + raw_h = crop_h * feasible_ratio * scale_factor + raw_w = crop_w * feasible_ratio * scale_factor + + # Re-apply divisibility constraint if needed + if divisible_by_patch: + # Use ceil to avoid going under minimum patch size + target_h = patch_h * math.ceil(raw_h / patch_h) + target_w = patch_w * math.ceil(raw_w / patch_w) + else: + target_h = int(round(raw_h)) + target_w = int(round(raw_w)) + + # Ensure final size is at least one patch dimension + target_h = max(target_h, patch_h) + target_w = max(target_w, patch_w) + final_size = (target_h, target_w) + + # Final check: Ensure this randomized size still fits max_seq_len + # (It should, as we scaled down, but rounding might theoretically push it over) + num_patches_h = final_size[0] // patch_h + num_patches_w = final_size[1] // patch_w + if (num_patches_h * num_patches_w) > max_seq_len: + # If it exceeds, revert to the original feasible_size (safest) + 
final_size = feasible_size + warnings.warn(f"Final scale randomization ({scale_factor:.2f}) resulted in size {final_size} exceeding max_seq_len={max_seq_len} after rounding. Reverting to feasible size {feasible_size}.") + + # Select interpolation mode + if isinstance(interpolation, (tuple, list)): + interpolation = random.choice(interpolation) + else: + interpolation = interpolation + + return (top, left, crop_h, crop_w), final_size, interpolation + + def forward(self, img: torch.Tensor) -> torch.Tensor: + # Sample crop, resize, and interpolation parameters + crop_params, final_size, interpolation = self.get_params( + img, + scale=self.scale, + ratio=self.ratio, + crop_attempts=self.attempts, + patch_h=self.patch_h, + patch_w=self.patch_w, + divisible_by_patch=self.divisible_by_patch, + max_seq_len=self.max_seq_len, + final_scale_range=self.final_scale_range, + interpolation=self.interpolation, + ) + top, left, crop_h, crop_w = crop_params + + output = F.resized_crop( + img, + top=top, + left=left, + height=crop_h, + width=crop_w, + size=final_size, + interpolation=interpolation, + antialias=True, + ) + + return output + + def __repr__(self) -> str: + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation) + else: + interpolate_str = str(self.interpolation) + format_string = self.__class__.__name__ + '(' + format_string += f"patch_size=({self.patch_h}, {self.patch_w})" + format_string += f", max_seq_len={self.max_seq_len}" + format_string += f", scale={self.scale}" + format_string += f", ratio={self.ratio}" + format_string += f", interpolation=[{interpolate_str}]" + format_string += f", divisible_by_patch={self.divisible_by_patch}" + format_string += f", max_ratio={self.max_ratio}" + format_string += f", final_scale_range={self.final_scale_range}" + format_string += f", attempts={self.attempts}" + format_string += ')' + return format_string + + +def patchify_image( + img: torch.Tensor, + patch_size: Tuple[int, int], + pad: bool = True, + include_info: bool = True, + flatten_patches: bool = True, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + c, h, w = img.shape + ph, pw = patch_size + + # Ensure the image is divisible by patch size + if pad and (h % ph != 0 or w % pw != 0): + pad_h = (ph - h % ph) % ph # amount to add on bottom + pad_w = (pw - w % pw) % pw # amount to add on right + img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h)) + c, h, w = img.shape + + # Calculate number of patches in each dimension + nh, nw = h // ph, w // pw + # Reshape image to patches + patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0) + # [nh, nw, ph, pw, c] -> [nh * nw, ph * pw * c] or [nh * nw, ph, pw, c] + patches = patches.reshape(-1, ph * pw * c) if flatten_patches else patches.reshape(-1, ph, pw, c) + + if include_info: + # Create coordinate indices + y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij') + # Stack into a single coords tensor [N, 2] with (y, x) order + coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1) + # Create type indicators (all 1s for regular patches) + valid = torch.ones(nh * nw, dtype=torch.bool) + return patches, coord, valid + + return patches + + +class Patchify(torch.nn.Module): + """Transform an image into patches with corresponding coordinates and type indicators.""" + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]], + flatten_patches: bool = True + ): + super().__init__() + self.patch_size = 
patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size) + self.flatten_patches = flatten_patches + + def forward(self, img): + """ + Args: + img: A PIL Image or tensor of shape [C, H, W] + + Returns: + A dictionary containing: + - patches: Tensor of shape [N, P*P*C] if flatten_patches=True, + or [N, Ph, Pw, C] if flatten_patches=False + - patch_coord: Tensor of shape [N, 2] with (y, x) coordinates + - patch_valid: Valid indicator (all 1s for non-padding patches) + """ + if isinstance(img, Image.Image): + # Convert PIL Image to tensor [C, H, W] + img = transforms.functional.to_tensor(img) + + patches, coord, valid = patchify_image(img, self.patch_size, flatten_patches=self.flatten_patches) + + return { + 'patches': patches, + 'patch_coord': coord, + 'patch_valid': valid, + } diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 9be0e3bf3c..904017ee8c 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -12,7 +12,8 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \ - ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor + ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, MaybeToTensor, MaybePILToTensor +from timm.data.naflex_transforms import RandomResizedCropToSequence, ResizeToSequence, Patchify from timm.data.random_erasing import RandomErasing @@ -46,7 +47,7 @@ def transforms_noaug_train( ] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] + tfl += [MaybePILToTensor()] elif not normalize: # when normalize disabled, converted to tensor without scaling, keep original dtype tfl += [MaybePILToTensor()] @@ -84,6 +85,10 @@ def transforms_imagenet_train( use_prefetcher: bool = False, normalize: bool = True, separate: bool = False, + naflex: bool = False, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 576, # 24x24 for 16x16 patch + patchify: bool = False, ): """ ImageNet-oriented image transforms for training. @@ -111,6 +116,9 @@ def transforms_imagenet_train( use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). separate: Output transforms in 3-stage tuple. + naflex: Enable NaFlex mode, sequence constrained patch output + patch_size: Patch size for NaFlex mode. + max_seq_len: Max sequence length for NaFlex mode. 
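+        patchify: Patchify the output instead of relying on prefetcher.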
Returns: If separate==True, the transforms are returned as a tuple of 3 separate transforms @@ -121,35 +129,49 @@ def transforms_imagenet_train( """ train_crop_mode = train_crop_mode or 'rrc' assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'} - if train_crop_mode in ('rkrc', 'rkrr'): - # FIXME integration of RKR is a WIP - scale = tuple(scale or (0.8, 1.00)) - ratio = tuple(ratio or (0.9, 1/.9)) - primary_tfl = [ - ResizeKeepRatio( - img_size, - interpolation=interpolation, - random_scale_prob=0.5, - random_scale_range=scale, - random_scale_area=True, # scale compatible with RRC - random_aspect_prob=0.5, - random_aspect_range=ratio, - ), - CenterCropOrPad(img_size, padding_mode='reflect') - if train_crop_mode == 'rkrc' else - RandomCropOrPad(img_size, padding_mode='reflect') - ] - else: + + primary_tfl = [] + if naflex: scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range - primary_tfl = [ - RandomResizedCropAndInterpolation( - img_size, - scale=scale, - ratio=ratio, - interpolation=interpolation, - ) - ] + primary_tfl += [RandomResizedCropToSequence( + patch_size=patch_size, + max_seq_len=max_seq_len, + scale=scale, + ratio=ratio, + interpolation=interpolation + )] + else: + if train_crop_mode in ('rkrc', 'rkrr'): + # FIXME integration of RKR is a WIP + scale = tuple(scale or (0.8, 1.00)) + ratio = tuple(ratio or (0.9, 1/.9)) + primary_tfl += [ + ResizeKeepRatio( + img_size, + interpolation=interpolation, + random_scale_prob=0.5, + random_scale_range=scale, + random_scale_area=True, # scale compatible with RRC + random_aspect_prob=0.5, + random_aspect_range=ratio, + ), + CenterCropOrPad(img_size, padding_mode='reflect') + if train_crop_mode == 'rkrc' else + RandomCropOrPad(img_size, padding_mode='reflect') + ] + else: + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range + primary_tfl += [ + RandomResizedCropAndInterpolation( + img_size, + scale=scale, + ratio=ratio, + interpolation=interpolation, + ) + ] + if hflip > 0.: primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] if vflip > 0.: @@ -215,7 +237,7 @@ def transforms_imagenet_train( final_tfl = [] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm - final_tfl += [ToNumpy()] + final_tfl += [MaybePILToTensor()] elif not normalize: # when normalize disable, converted to tensor without scaling, keeps original dtype final_tfl += [MaybePILToTensor()] @@ -238,6 +260,9 @@ def transforms_imagenet_train( ) ] + if patchify: + final_tfl += [Patchify(patch_size=patch_size)] + if separate: return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) else: @@ -254,6 +279,10 @@ def transforms_imagenet_eval( std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, use_prefetcher: bool = False, normalize: bool = True, + naflex: bool = False, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 576, # 24x24 for 16x16 patch + patchify: bool = False, ): """ ImageNet-oriented image transform for evaluation and inference. @@ -267,6 +296,10 @@ def transforms_imagenet_eval( std: Image normalization standard deviation. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). + naflex: Enable NaFlex mode, sequence constrained patch output + patch_size: Patch size for NaFlex mode. 
+ max_seq_len: Max sequence length for NaFlex mode. + patchify: Patchify the output instead of relying on prefetcher Returns: Composed transform pipeline @@ -285,37 +318,44 @@ def transforms_imagenet_eval( if crop_border_pixels: tfl += [TrimBorder(crop_border_pixels)] - if crop_mode == 'squash': - # squash mode scales each edge to 1/pct of target, then crops - # aspect ratio is not preserved, no img lost if crop_pct == 1.0 - tfl += [ - transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)), - transforms.CenterCrop(img_size), - ] - elif crop_mode == 'border': - # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop - # no image lost if crop_pct == 1.0 - fill = [round(255 * v) for v in mean] - tfl += [ - ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0), - CenterCropOrPad(img_size, fill=fill), - ] + if naflex: + tfl += [ResizeToSequence( + patch_size=patch_size, + max_seq_len=max_seq_len, + interpolation=interpolation, + )] else: - # default crop model is center - # aspect ratio is preserved, crops center within image, no borders are added, image is lost - if scale_size[0] == scale_size[1]: - # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + if crop_mode == 'squash': + # squash mode scales each edge to 1/pct of target, then crops + # aspect ratio is not preserved, no img lost if crop_pct == 1.0 tfl += [ - transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation)) + transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)), + transforms.CenterCrop(img_size), + ] + elif crop_mode == 'border': + # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop + # no image lost if crop_pct == 1.0 + fill = [round(255 * v) for v in mean] + tfl += [ + ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0), + CenterCropOrPad(img_size, fill=fill), ] else: - # resize the shortest edge to matching target dim for non-square target - tfl += [ResizeKeepRatio(scale_size)] - tfl += [transforms.CenterCrop(img_size)] + # default crop model is center + # aspect ratio is preserved, crops center within image, no borders are added, image is lost + if scale_size[0] == scale_size[1]: + # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + tfl += [ + transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation)) + ] + else: + # resize the shortest edge to matching target dim for non-square target + tfl += [ResizeKeepRatio(scale_size)] + tfl += [transforms.CenterCrop(img_size)] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] + tfl += [MaybePILToTensor()] elif not normalize: # when normalize disabled, converted to tensor without scaling, keeps original dtype tfl += [MaybePILToTensor()] @@ -328,6 +368,9 @@ def transforms_imagenet_eval( ), ] + if patchify: + tfl += [Patchify(patch_size=patch_size)] + return transforms.Compose(tfl) @@ -359,6 +402,10 @@ def create_transform( use_prefetcher: bool = False, normalize: bool = True, separate: bool = False, + naflex: bool = False, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 576, # 24x24 for 16x16 patch + patchify: bool = False ): """ @@ -442,6 +489,10 @@ def create_transform( use_prefetcher=use_prefetcher, normalize=normalize, separate=separate, + naflex=naflex, + patch_size=patch_size, + max_seq_len=max_seq_len, + patchify=patchify, ) else: assert 
not separate, "Separate transforms not supported for validation preprocessing" @@ -455,6 +506,10 @@ def create_transform( crop_border_pixels=crop_border_pixels, use_prefetcher=use_prefetcher, normalize=normalize, + naflex=naflex, + patch_size=patch_size, + max_seq_len=max_seq_len, + patchify=patchify, ) return transform diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 23c6f908c9..ac24140ea5 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -1,7 +1,7 @@ from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .attention import Attention, AttentionRope +from .attention import Attention, AttentionRope, maybe_add_mask from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding @@ -41,7 +41,7 @@ SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d from .padding import get_padding, get_same_padding, pad_same from .patch_dropout import PatchDropout -from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed +from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed from .pool1d import global_pool_nlc from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc diff --git a/timm/layers/attention.py b/timm/layers/attention.py index 8e95a00209..01b2ecb263 100644 --- a/timm/layers/attention.py +++ b/timm/layers/attention.py @@ -8,6 +8,10 @@ from .pos_embed_sincos import apply_rot_embed_cat +def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return scores if attn_mask is None else scores + attn_mask + + class Attention(nn.Module): """Standard Multi-head Self Attention module with QKV projection. @@ -24,10 +28,11 @@ def __init__( num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, + scale_norm: bool = False, proj_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., - norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer: Optional[Type[nn.Module]] = None, ) -> None: """Initialize the Attention module. 
@@ -43,6 +48,8 @@ def __init__( """ super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' + if qk_norm or scale_norm: + assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 @@ -52,6 +59,7 @@ def __init__( self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) + self.norm = norm_layer(dim) if scale_norm else nn.Identity() self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) @@ -74,13 +82,13 @@ def forward( else: q = q * self.scale attn = q @ k.transpose(-2, -1) - if attn_mask is not None: - attn = attn + attn_mask + attn = maybe_add_mask(attn, attn_mask) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) + x = self.norm(x) x = self.proj(x) x = self.proj_drop(x) return x @@ -196,10 +204,7 @@ def forward( else: q = q * self.scale attn = (q @ k.transpose(-2, -1)) - - if attn_mask is not None: - attn_mask = attn_mask.to(torch.bool) - attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + attn = maybe_add_mask(attn, attn_mask) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py index c2591a3b4e..f464c8a3d1 100644 --- a/timm/layers/attention_pool.py +++ b/timm/layers/attention_pool.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F +from .attention import maybe_add_mask from .config import use_fused_attn from .mlp import Mlp from .weight_init import trunc_normal_tf_ @@ -75,7 +76,7 @@ def init_weights(self): trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5) - def forward(self, x): + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): B, N, C = x.shape if self.pos_embed is not None: @@ -91,10 +92,11 @@ def forward(self, x): q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: - x = F.scaled_dot_product_attention(q, k, v) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) else: q = q * self.scale attn = q @ k.transpose(-2, -1) + attn = maybe_add_mask(attn, attn_mask) attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, self.latent_len, C) diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index 336d16615f..f87ce9693a 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -10,7 +10,7 @@ """ import logging import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn as nn @@ -180,7 +180,8 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]: return x, feat_size -def resample_patch_embed( +# FIXME to remove, keeping for comparison for now +def resample_patch_embed_old( patch_embed, new_size: List[int], interpolation: str = 'bicubic', @@ -250,6 +251,351 @@ def resample_kernel(kernel): return patch_embed +DTYPE_INTERMEDIATE = torch.float32 + + +def _compute_resize_matrix( + old_size: Tuple[int, int], + new_size: Tuple[int, int], + interpolation: str, + antialias: bool, + device: torch.device, + dtype: torch.dtype = DTYPE_INTERMEDIATE +) -> torch.Tensor: + """Computes the resize matrix basis vectors and interpolates them to 
new_size.""" + old_h, old_w = old_size + new_h, new_w = new_size + old_total = old_h * old_w + new_total = new_h * new_w + + eye_matrix = torch.eye(old_total, device=device, dtype=dtype) + basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w) + + resized_basis_vectors_batch = F.interpolate( + basis_vectors_batch, + size=new_size, + mode=interpolation, + antialias=antialias, + align_corners=False + ) # Output shape: (old_total, 1, new_h, new_w) + + resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T + return resize_matrix # Shape: (new_total, old_total) + + +def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor: + """Calculates the pseudoinverse matrix used for the resampling operation.""" + pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total) + return pinv_matrix + + +def _apply_resampling( + patch_embed: torch.Tensor, + pinv_matrix: torch.Tensor, + new_size_tuple: Tuple[int, int], + orig_dtype: torch.dtype, + intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE +) -> torch.Tensor: + """Applies the precomputed pinv_matrix to resample the patch_embed tensor.""" + try: + from torch import vmap + except ImportError: + from functorch import vmap + + def resample_kernel(kernel: torch.Tensor) -> torch.Tensor: + kernel_flat = kernel.reshape(-1).to(intermediate_dtype) + resampled_kernel_flat = pinv_matrix @ kernel_flat + return resampled_kernel_flat.reshape(new_size_tuple) + + resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0) + patch_embed_float = patch_embed.to(intermediate_dtype) + resampled_patch_embed = resample_kernel_vmap(patch_embed_float) + return resampled_patch_embed.to(orig_dtype) + + +def resample_patch_embed( + patch_embed: torch.Tensor, + new_size: List[int], + interpolation: str = 'bicubic', + antialias: bool = True, + verbose: bool = False, +): + """ Standalone function (computes matrix on each call). """ + assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)" + assert len(new_size) == 2, "New shape should only be hw (height, width)" + + old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:]) + new_size_tuple: Tuple[int, int] = tuple(new_size) + + if old_size_tuple == new_size_tuple: + return patch_embed + + device = patch_embed.device + orig_dtype = patch_embed.dtype + + resize_mat = _compute_resize_matrix( + old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE + ) + pinv_matrix = _compute_pinv_for_resampling(resize_mat) + resampled_patch_embed = _apply_resampling( + patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE + ) + return resampled_patch_embed + + +class PatchEmbedResamplerFixedOrigSize(nn.Module): + """ + Resample patch embedding weights from a fixed original size, + caching the pseudoinverse matrix based on the target size. + """ + def __init__( + self, + orig_size: Tuple[int, int], + interpolation: str = 'bicubic', + antialias: bool = True + ): + """ + Args: + orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors. + interpolation (str): Interpolation mode. + antialias (bool): Use anti-aliasing filter in resize. 
+ """ + super().__init__() + assert isinstance(orig_size, tuple) and len(orig_size) == 2, \ + "`orig_size` must be a tuple of (height, width)" + self.orig_size = orig_size # expected original size + self.interpolation = interpolation + self.antialias = antialias + # Cache map key is the target new_size tuple + self._pinv_cache_map: Dict[Tuple[int, int], str] = {} + + def _get_or_create_pinv_matrix( + self, + new_size: Tuple[int, int], + device: torch.device, + dtype: torch.dtype = DTYPE_INTERMEDIATE + ) -> torch.Tensor: + """Retrieves the cached pinv matrix or computes and caches it for the given new_size.""" + cache_key = new_size + buffer_name = self._pinv_cache_map.get(cache_key) + + if buffer_name and hasattr(self, buffer_name): + pinv_matrix = getattr(self, buffer_name) + if pinv_matrix.device == device and pinv_matrix.dtype == dtype: + return pinv_matrix + + # Calculate the matrix if not cached or needs update + resize_mat = _compute_resize_matrix( + self.orig_size, new_size, self.interpolation, self.antialias, device, dtype + ) + pinv_matrix = _compute_pinv_for_resampling(resize_mat) + + # Cache using register_buffer + buffer_name = f"pinv_{new_size[0]}x{new_size[1]}" + if hasattr(self, buffer_name): + delattr(self, buffer_name) + self.register_buffer(buffer_name, pinv_matrix) + self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name + + return pinv_matrix + + def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor: + """ Resamples the patch embedding weights to new_size. + + Args: + patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig). + new_size (List[int]): Target [height, width]. + + Returns: + torch.Tensor: Resampled weights. + """ + assert len(patch_embed.shape) == 4 + assert len(new_size) == 2 + + # Input Validation + input_size = tuple(patch_embed.shape[-2:]) + assert input_size == self.orig_size, \ + f"Input patch_embed spatial size {input_size} does not match " \ + f"module's expected original size {self.orig_size}" + + new_size_tuple: Tuple[int, int] = tuple(new_size) + + # Check no-op case against self.orig_size + if self.orig_size == new_size_tuple: + return patch_embed + + device = patch_embed.device + orig_dtype = patch_embed.dtype + + # Get or compute the required pseudoinverse matrix + pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device) + + # Apply the resampling + resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype) + + return resampled_patch_embed + + +class PatchEmbedInterpolator(nn.Module): + """Dynamically interpolates patch embedding weights for variable patch sizes. + + This module wraps patch embedding weight resampling functionality to support + on-the-fly patch size variation during training. It handles both Conv2d and + Linear patch embeddings. 
+ + Args: + base_patch_size: The original patch size the model was initialized with + in_chans: Number of input channels + embed_dim: Embedding dimension + interpolation: Interpolation mode for resampling + antialias: Whether to use antialiasing during interpolation + """ + + def __init__( + self, + base_patch_size: Tuple[int, int], + in_chans: int = 3, + embed_dim: int = 768, + interpolation: str = 'bicubic', + antialias: bool = True, + ): + super().__init__() + self.base_patch_size = base_patch_size + self.in_chans = in_chans + self.embed_dim = embed_dim + self.interpolation = interpolation + self.antialias = antialias + + def resample_linear_weight( + self, + weight: torch.Tensor, + target_patch_size: Tuple[int, int], + ) -> torch.Tensor: + """Resample linear patch embedding weights for a new patch size. + + Args: + weight: Linear weight tensor of shape [embed_dim, patch_h * patch_w * in_chans] + target_patch_size: Target (patch_h, patch_w) to resample to + + Returns: + Resampled weight tensor + """ + if target_patch_size == self.base_patch_size: + return weight + + embed_dim = weight.shape[0] + base_ph, base_pw = self.base_patch_size + target_ph, target_pw = target_patch_size + + # Reshape linear weight to conv2d format + # [embed_dim, ph*pw*C] -> [embed_dim, C, ph, pw] + weight_conv = weight.reshape(embed_dim, base_ph, base_pw, self.in_chans) + weight_conv = weight_conv.permute(0, 3, 1, 2) + + # Resample using existing function + weight_conv_resampled = resample_patch_embed( + weight_conv, + new_size=[target_ph, target_pw], + interpolation=self.interpolation, + antialias=self.antialias, + verbose=False, + ) + + # Reshape back to linear format + # [embed_dim, C, ph, pw] -> [embed_dim, ph*pw*C] + weight_resampled = weight_conv_resampled.permute(0, 2, 3, 1) + weight_resampled = weight_resampled.reshape(embed_dim, -1) + + return weight_resampled + + def resample_conv_weight( + self, + weight: torch.Tensor, + target_patch_size: Tuple[int, int], + ) -> torch.Tensor: + """Resample conv2d patch embedding weights for a new patch size. + + Args: + weight: Conv2d weight tensor of shape [embed_dim, in_chans, patch_h, patch_w] + target_patch_size: Target (patch_h, patch_w) to resample to + + Returns: + Resampled weight tensor + """ + if target_patch_size == self.base_patch_size: + return weight + + # Resample using existing function + weight_resampled = resample_patch_embed( + weight, + new_size=list(target_patch_size), + interpolation=self.interpolation, + antialias=self.antialias, + verbose=False, + ) + + return weight_resampled + + def forward( + self, + patches: torch.Tensor, + proj_weight: torch.Tensor, + proj_bias: Optional[torch.Tensor] = None, + patch_size: Optional[Tuple[int, int]] = None, + is_linear: bool = True, + ) -> torch.Tensor: + """Apply patch embedding with dynamic weight resampling. 
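+
+        Example (illustrative, linear mode with a non-default patch size):
+            >>> interp = PatchEmbedInterpolator((16, 16), in_chans=3, embed_dim=768)
+            >>> patches = torch.randn(2, 64, 8, 8, 3)   # [B, N, Ph, Pw, C]
+            >>> weight = torch.randn(768, 16 * 16 * 3)  # base linear projection weight
+            >>> interp(patches, weight, patch_size=(8, 8)).shape
+            torch.Size([2, 64, 768])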
+ + Args: + patches: Input patches + - For linear mode with resampling: [B, N, Ph, Pw, C] + - For linear mode without resampling: [B, N, Ph*Pw*C] + - For conv mode: [B, C, H, W] + proj_weight: Original projection weight + proj_bias: Optional projection bias + patch_size: Current patch size (if None, uses base_patch_size) + is_linear: Whether using linear (True) or conv2d (False) projection + + Returns: + Embedded patches + """ + if patch_size is None: + patch_size = self.base_patch_size + + if is_linear: + if patch_size != self.base_patch_size: + # Need to resample - expects unflattened patches + assert patches.ndim == 5, "Patches must be [B, N, Ph, Pw, C] for resampling" + B, N, Ph, Pw, C = patches.shape + + # Resample the weight + weight_resampled = self.resample_linear_weight(proj_weight, patch_size) + + # Flatten patches and apply linear projection + patches_flat = patches.reshape(B, N, -1) + output = torch.nn.functional.linear(patches_flat, weight_resampled, proj_bias) + else: + # No resampling needed, patches can be pre-flattened + if patches.ndim == 5: + B, N, Ph, Pw, C = patches.shape + patches = patches.reshape(B, N, -1) + output = torch.nn.functional.linear(patches, proj_weight, proj_bias) + else: + # Conv mode + if patch_size != self.base_patch_size: + weight_resampled = self.resample_conv_weight(proj_weight, patch_size) + output = torch.nn.functional.conv2d( + patches, weight_resampled, proj_bias, + stride=patch_size, padding=0 + ) + else: + output = torch.nn.functional.conv2d( + patches, proj_weight, proj_bias, + stride=patch_size, padding=0 + ) + + return output + # def divs(n, m=None): # m = m or n // 2 # if m == 1: diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 81db3ff01c..884df366a2 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -42,6 +42,7 @@ from .mobilenetv3 import * from .mobilevit import * from .mvitv2 import * +from .naflexvit import * from .nasnet import * from .nest import * from .nextvit import * diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 6679b38b46..0352d79a5a 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -18,7 +18,7 @@ # Layers we went to treat as leaf modules from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format -from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc +from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, maybe_add_mask from timm.layers.non_local_attn import BilinearAttnTransform from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame from timm.layers.norm_act import ( @@ -79,6 +79,7 @@ def get_notrace_modules(): _autowrap_functions = { resample_abs_pos_embed, resample_abs_pos_embed_nhwc, + maybe_add_mask, } diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py new file mode 100644 index 0000000000..e32e94f396 --- /dev/null +++ b/timm/models/naflexvit.py @@ -0,0 +1,1589 @@ +""" NaFlex Vision Transformer + +An improved version of the Vision Transformer with: +1. Encapsulated embedding and position encoding in a single module +2. Support for linear patch embedding on pre-patchified inputs +3. Support for NaFlex variable aspect, variable resolution +4. Support for FlexiViT variable patch size +5. 
Support for NaViT fractional/factorized position embedding + +Based on ideas from: +- Original Vision Transformer: https://arxiv.org/abs/2010.11929 +- FlexiViT: https://arxiv.org/abs/2212.08013 +- NaViT: https://arxiv.org/abs/2307.06304 +- NaFlex (SigLip-2): https://arxiv.org/abs/2502.14786 + +Hacked together by / Copyright 2025, Ross Wightman, Hugging Face +""" + +import logging +import math +from dataclasses import dataclass, fields, replace +from functools import partial +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Final, Any, Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import ( + AttentionPoolLatent, + Mlp, + to_2tuple, + get_act_layer, + get_norm_layer, + LayerNorm, + LayerType, + _assert, +) +from timm.models._builder import build_model_with_cfg +from timm.models._features import feature_take_indices +from timm.models._features_fx import register_notrace_function, register_notrace_module +from timm.models._registry import register_model, generate_default_cfgs +from timm.models._manipulate import checkpoint_seq, named_apply + +from .vision_transformer import Block, global_pool_nlc + +__all__ = ['NaFlexVitCfg', 'NaFlexVit'] + + +_logger = logging.getLogger(__name__) + + +@dataclass +class NaFlexVitCfg: + """Configuration for FlexVit model. + + This dataclass contains the bulk of model configuration parameters, + with core parameters (img_size, in_chans, num_classes, etc.) remaining + as direct constructor arguments for API compatibility. + """ + # Architecture parameters + patch_size: Union[int, Tuple[int, int]] = 16 + embed_dim: int = 768 + depth: int = 12 + num_heads: int = 12 + mlp_ratio: float = 4.0 + + # Attention parameters + qkv_bias: bool = True + qk_norm: bool = False + proj_bias: bool = True + attn_drop_rate: float = 0.0 + + # Regularization + init_values: Optional[float] = None # Layer-scale init values (layer-scale enabled if not None) + drop_rate: float = 0.0 # Dropout rate for classifier + pos_drop_rate: float = 0.0 # Dropout rate for position embeddings + patch_drop_rate: float = 0.0 # Dropout rate for patch tokens + proj_drop_rate: float = 0.0 # Dropout rate for linear projections + drop_path_rate: float = 0.0 # Stochastic depth drop rate + + # Prefix token configuration + class_token: bool = False # Use class token + reg_tokens: int = 0 # Number of register tokens + + # Position embedding configuration + pos_embed: str = 'learned' # Type of position embedding ('learned', 'factorized', 'rope', 'none') + pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16) # Grid size for position embedding initialization + pos_embed_interp_mode: str = 'bicubic' # Interpolation mode for position embedding resizing + pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation + + # Image processing + dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution + + # Architecture choices + pre_norm: bool = False # Whether to apply normalization before attention/MLP layers (start of blocks) + final_norm: bool = True # Whether to apply final normalization before pooling and classifier (end of blocks) + fc_norm: Optional[bool] = None # Whether to normalize features before final classifier (after pooling) + global_pool: str = 'map' # Type of global pooling for final sequence + pool_include_prefix: bool = 
False # Whether to include class/register prefix tokens in global pooling
+ + # Weight initialization + weight_init: str = '' # Weight initialization scheme + fix_init: bool = True # Apply weight initialization fix (scaling w/ layer index)
+ + # Embedding configuration + embed_proj_type: str = 'linear' # Type of embedding layer ('conv' or 'linear') + input_norm_layer: Optional[str] = None # Normalization layer for embeddings input (before input projection) + embed_norm_layer: Optional[str] = None # Normalization layer for embeddings (after input projection)
+ + # Layer implementations + norm_layer: Optional[str] = None # Normalization layer for transformer blocks + act_layer: Optional[str] = None # Activation layer for MLP blocks + block_fn: Optional[str] = None # Transformer block implementation class name + mlp_layer: Optional[str] = None # MLP implementation class name
+ + # Variable patch size support + enable_patch_interpolator: bool = False # Enable dynamic patch size support
+ + +def _overlay_kwargs(cfg: NaFlexVitCfg, **kwargs) -> NaFlexVitCfg: + """Overlay kwargs onto config, replacing config values with provided kwargs.""" + # Only update fields that exist in the config + config_fields = set(cfg.__dataclass_fields__.keys()) + config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
+ + if config_kwargs: + cfg = replace(cfg, **config_kwargs)
+ + return cfg
+ + +def batch_patchify( + x: torch.Tensor, + patch_size: Tuple[int, int], + pad: bool = True, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Patchify a batch of images.
+ + Args: + x: Input tensor of shape [B, C, H, W]. + patch_size: Patch dimensions (patch_h, patch_w). + pad: Whether to pad images to be divisible by patch size.
+ + Returns: + Tuple of (patches, grid_size) where patches has shape [B, N, P*P*C] + and grid_size is (num_patches_h, num_patches_w). + """ + B, C, H, W = x.shape + ph, pw = patch_size
+ + # Ensure the image is divisible by patch size + if pad and (H % ph != 0 or W % pw != 0): + pad_h = (ph - H % ph) % ph + pad_w = (pw - W % pw) % pw + x = F.pad(x, (0, pad_w, 0, pad_h)) + H, W = x.shape[-2:] # recompute with the padded dims so the view below stays valid
+ + nh, nw = H // ph, W // pw + patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C) + # FIXME confirm we want 'channels last' in the patch channel layout, e.g. ph, pw, C instead of C, ph, pw
+ + return patches, (nh, nw)
+ + +@register_notrace_module +class NaFlexEmbeds(nn.Module): + """NaFlex Embedding module for Vision Transformers.
+ + This module encapsulates the complete embedding process for Vision Transformers, + supporting both standard and NaFlex (NaViT + FlexiViT) functionality:
+ + 1. Patch embedding (via Conv2d or Linear) + 2. Class and register token preparation + 3. Position embedding addition with interpolation support + 4. Pre-normalization (if requested) + 5.
Dropout application + + NaFlex capabilities include: + - Variable aspect ratio and resolution via patch coordinates + - Patch type indicators for handling padding tokens in attention + - Flexible position embedding interpolation for arbitrary grid sizes + - Support for factorized position embeddings + + The patch embedding can be one of two types: + - Conv2d-based (default): For standard image inputs [B, C, H, W] + - Linear-based: For pre-patchified inputs [B, N, P*P*C] + + Args: + patch_size: Size of patches for patch embedding + in_chans: Number of input image channels + embed_dim: Dimensionality of patch embedding + proj_type: Type of embedding projection layer ('conv' or 'linear') + input_norm_layer: Normalization layer applied to input (linear mode only) + proj_norm_layer: Normalization layer applied after projection + pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none') + pos_drop_rate: Dropout rate for position embeddings + patch_drop_rate: Dropout rate for patch tokens + class_token: Whether to include a class token + reg_tokens: Number of register tokens to include + bias: Whether to use bias in projection layers + dynamic_img_pad: Whether to enable dynamic padding for variable resolution + pos_embed_grid_size: Grid size for position embedding initialization + pos_embed_interp_mode: Interpolation mode for position embedding resizing + pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation + default_img_size: Default image size for position embedding grid calculation + """ + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + proj_type: Optional[str] = None, + proj_bias: bool = True, + class_token: bool = True, + reg_tokens: int = 0, + dynamic_img_pad: bool = False, + default_img_size: Optional[Union[int, Tuple[int, int]]] = None, + pos_embed: str = 'learned', + pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14), + pos_embed_interp_mode: str = 'bicubic', + pos_embed_ar_preserving: bool = False, + input_norm_layer: Optional[Type[nn.Module]] = None, + proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None, + norm_layer: Optional[Type[nn.Module]] = None, + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + enable_patch_interpolator: bool = False, + ) -> None: + """Initialize NaFlexEmbeds module. + + Args: + patch_size: Size of patches for patch embedding. + in_chans: Number of input image channels. + embed_dim: Dimensionality of patch embedding. + proj_type: Type of embedding projection layer ('conv' or 'linear'). + proj_bias: Whether to use bias in projection layers. + class_token: Whether to include a class token. + reg_tokens: Number of register tokens to include. + dynamic_img_pad: Whether to enable dynamic padding for variable resolution. + default_img_size: Default image size for position embedding grid calculation. + pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none'). + pos_embed_grid_size: Grid size for position embedding initialization. + pos_embed_interp_mode: Interpolation mode for position embedding resizing. + pos_embed_ar_preserving: Whether to preserve aspect ratio during interpolation. + input_norm_layer: Normalization layer applied to input (linear mode only). + proj_norm_layer: Normalization layer applied after projection. + norm_layer: Default normalization layer. + pos_drop_rate: Dropout rate for position embeddings. + patch_drop_rate: Dropout rate for patch tokens. 
+ enable_patch_interpolator: Enable dynamic patch size support. + """ + super().__init__() + self.has_class_token = class_token + self.num_reg_tokens = reg_tokens + self.pos_embed_interp_mode = pos_embed_interp_mode + self.pos_embed_ar_preserving = pos_embed_ar_preserving + self.patch_size = to_2tuple(patch_size) + self.in_chans = in_chans + self.embed_dim = embed_dim + self.dynamic_img_pad = dynamic_img_pad + self.enable_patch_interpolator = enable_patch_interpolator + + # Calculate number of prefix tokens + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + + # Create class and register tokens + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + + # Calculate grid size and number of patches + self.default_img_size: Optional[Tuple[int, int]] = None + self.pos_embed_grid_size: Optional[Tuple[int, int]] = None # Grid size used for learned pos embed init + if pos_embed_grid_size is not None: + # Highest priority, use provided pos_embed_grid_size + self.pos_embed_grid_size = pos_embed_grid_size + elif default_img_size is not None: + # Fallback to calculating grid size from img_size + patch_size if img size provided. + self.default_img_size = to_2tuple(default_img_size) + self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)]) + + # Determine patch embedding type (linear or conv2d) + if proj_type == 'linear': + # Create linear projection for pre-patchified inputs + # Input dimension is patch_size^2 * in_chans + patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans + assert not (input_norm_layer is True and norm_layer is None), \ + "`norm_layer` must be given when input_norm_layer=True" + input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None) + self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None + self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias) + self.flatten = False + self.is_linear = True + else: + # Default to convolutional patch embedding for image inputs + assert not input_norm_layer + self.norm_input = None + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=proj_bias + ) + self.flatten = True + self.is_linear = False + + # Create patch embedding interpolator if enabled + if self.enable_patch_interpolator: + from timm.layers import PatchEmbedInterpolator + self.patch_interpolator = PatchEmbedInterpolator( + base_patch_size=self.patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + interpolation=pos_embed_interp_mode, + antialias=True, + ) + else: + self.patch_interpolator = None + + # Create normalization layer after the projection + assert not (proj_norm_layer is True and norm_layer is None), \ + "`norm_layer` must be given when proj_norm_layer=True" + proj_norm_layer = norm_layer if proj_norm_layer is True else (proj_norm_layer or None) + self.norm = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity() + + # Create position embedding if needed - only for patches, never for prefix tokens + if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None: + raise ValueError( + "Cannot initialize position embeddings without grid_size." 
+ "Please provide img_size or pos_embed_grid_size.") + self.pos_embed: Optional[torch.Tensor] = None + self.pos_embed_y: Optional[torch.Tensor] = None + self.pos_embed_x: Optional[torch.Tensor] = None + if not pos_embed or pos_embed == 'none': + self.pos_embed_type = 'none' + elif pos_embed == 'rope': + self.pos_embed_type = 'rope' + # Rotary embeddings will be computed on-the-fly in the forward pass + elif pos_embed == 'factorized': + assert self.pos_embed_grid_size is not None + h, w = self.pos_embed_grid_size + self.pos_embed_type = 'factorized' + self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02) + self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02) + else: + assert self.pos_embed_grid_size is not None + h, w = self.pos_embed_grid_size + self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02) + self.pos_embed_type = 'learned' + + # Dropout layers + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + from timm.layers.patch_dropout import PatchDropout + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + + def feature_info(self, location) -> Dict[str, Any]: + """Get feature information for feature extraction. + + Args: + location: Feature extraction location identifier + + Returns: + Dictionary containing feature channel count and reduction factor + """ + return dict(num_chs=self.embed_dim, reduction=self.patch_size) + + def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]: + """Get the feature reduction ratio (stride) of the patch embedding. + + Args: + as_scalar: Whether to return the maximum dimension as a scalar + + Returns: + Feature reduction ratio as scalar or tuple + """ + if as_scalar: + return max(self.patch_size) + else: + return self.patch_size + + def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: + """Calculate grid (feature) size for given image size. + + Takes into account dynamic padding when enabled. + + Args: + img_size: Input image size as (height, width) + + Returns: + Grid size as (grid_height, grid_width) + """ + if self.dynamic_img_pad: + return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1]) + else: + return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] + + #@torch.compiler.disable() + def _apply_learned_naflex_pos_embed( + self, + x: torch.Tensor, + naflex_grid_sizes: List[Tuple[int, int]], + ) -> None: + """Apply learned position embeddings to NaFlex batch in-place. + + Interpolates learned position embeddings for each sample in the batch + based on their individual grid sizes. + + Args: + x: Input tensor to add position embeddings to + naflex_grid_sizes: List of (height, width) grid sizes for each batch element + """ + # Handle each batch element separately with its own grid size + orig_h, orig_w = self.pos_embed.shape[1:3] + pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W + + def _interp2d(size): + """ + Return a flattened positional-embedding grid at an arbitrary spatial resolution. + + Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into + a (1, H*W, C) sequence that matches the requested size. 
+ """ + if (size[0] == orig_h) and (size[1] == orig_w): + pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) + else: + _interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size + pos_embed_flat = F.interpolate( + pos_embed_nchw, + size=_interp_size, + mode=self.pos_embed_interp_mode, + align_corners=False, + antialias=True, + )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2) + return pos_embed_flat.to(dtype=x.dtype) + + # FIXME leaving alternative code commented here for now for comparisons + # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {} + # for i, s in enumerate(naflex_grid_sizes): + # if s in pos_embed_cache: + # pos_embed_flat = pos_embed_cache[s] + # else: + # pos_embed_flat = _interp(s) + # pos_embed_cache[s] = pos_embed_flat + # + # seq_len = min(x.shape[1], pos_embed_flat.shape[1]) + # x[i, :seq_len] += pos_embed_flat[0, :seq_len] + + # Determine unique grid sizes to avoid duplicate interpolation + size_to_indices: Dict[Tuple[int, int], List[int]] = {} + for bi, k in enumerate(naflex_grid_sizes): + # k = h << 16 | w # FIXME can get jit compat with this + size_to_indices.setdefault(k, []).append(bi) + + for k, batch_indices in size_to_indices.items(): + # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this + # Interpolate only once for this (h, w) + pos_embed_flat = _interp2d(k) + seq_len = min(x.shape[1], pos_embed_flat.shape[1]) + x[:, :seq_len].index_add_( + 0, + torch.as_tensor(batch_indices, device=x.device), + pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1) + ) + + def _apply_learned_pos_embed( + self, + x: torch.Tensor, + grid_size: List[int], + ) -> None: + """Apply learned position embeddings to standard batch in-place. + + Interpolates learned position embeddings to match the specified grid size. + + Args: + x: Input tensor to add position embeddings to + grid_size: Target grid size as [height, width] + """ + orig_h, orig_w = self.pos_embed.shape[1:3] + if grid_size[0] == orig_h or grid_size[1] == orig_w: + # No resize needed, just flatten + pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) + else: + # Resize if needed - directly using F.interpolate + pos_embed_flat = F.interpolate( + self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W + size=grid_size, + mode=self.pos_embed_interp_mode, + align_corners=False, + antialias=True, + ).flatten(2).transpose(1, 2) + pos_embed_flat = pos_embed_flat.to(dtype=x.dtype) + + x.add_(pos_embed_flat) + + def _apply_factorized_naflex_pos_embed( + self, + x: torch.Tensor, + naflex_grid_sizes: List[Tuple[int, int]], + ) -> None: + """Apply factorized position embeddings to NaFlex batch in-place. + + Uses separate Y and X position embedding tables that are interpolated + and combined for each sample's grid size. 
+ + Args: + x: Input tensor to add position embeddings to + naflex_grid_sizes: List of (height, width) grid sizes for each batch element + """ + assert len(naflex_grid_sizes) == x.size(0) # one (H,W) per sample + + # Handle each batch element separately with its own grid size + orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] + + # bucket samples that share the same (H,W) so we build each grid once + size_to_indices: Dict[Tuple[int, int], List[int]] = {} + for bi, k in enumerate(naflex_grid_sizes): + size_to_indices.setdefault(k, []).append(bi) + + def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor: + """ + Resample a 1-D positional-embedding table to specified length + and return it in (1, L, C) layout, dtype matching x. + """ + if new_length == orig_length: + return table.to(dtype=x.dtype) + return F.interpolate( + table.permute(0, 2, 1).float(), # (1,C,L) → (1,C,L_out) + size=new_length, + mode='linear', + align_corners=False, + ).permute(0, 2, 1).to(dtype=x.dtype) # → (1,L_out,C) + + for k, batch_indices in size_to_indices.items(): + target_h, target_w = k + if self.pos_embed_ar_preserving: + len_y = len_x = max(target_h, target_w) + else: + len_y, len_x = target_h, target_w + + pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C) + pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C) + + # Broadcast, add and flatten to sequence layout (row major) + pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1,H,W,C) + pos = pos.flatten(1, 2) + + seq_len = min(x.shape[1], pos.shape[1]) + x[:, :seq_len].index_add_( + 0, + torch.as_tensor(batch_indices, device=x.device), + pos[:, :seq_len].expand(len(batch_indices), -1, -1) + ) + + def forward( + self, + x: torch.Tensor, + patch_coord: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass for patch embedding with position encoding. + + Args: + x: Input tensor. Supported formats: + - [B, C, H, W] for conv mode + - [B, N, P*P*C] for pre-patchified linear mode (normal) + - [B, N, Ph, Pw, C] for pre-patchified linear mode (variable patch size) + patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode. + + Returns: + Embedded tensor with position encoding and class/register tokens. 
+ Shape: [B, num_prefix_tokens + N, embed_dim] + """ + # Apply patch embedding + naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None + grid_size: Optional[List[int]] = None + + B = x.shape[0] + if self.is_linear: + # Linear embedding path, works with NaFlex mode or standard 2D mode + if patch_coord is not None: + # Pre-patchified NaFlex mode + # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C] + _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.') + # Calculate the appropriate grid size from coords + max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1 + max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1 + naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] + else: + _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') + x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) + + # Handle variable patch size projection + if self.enable_patch_interpolator and x.ndim == 5: + _assert(self.norm_input is None, 'input norm not supported with patch resizing') + + # Apply projection with interpolation + x = self.patch_interpolator( + x, + self.proj.weight, + self.proj.bias, + patch_size=tuple(x.shape[2:4]), # patch size from [B, N, Ph, Pw, C] shape + is_linear=True, + ) + else: + # Standard projection + x = x.flatten(2) # ensure [B, N, P*P*C], flatten Ph*Pw*C if separate + if self.norm_input is not None: + x = self.norm_input(x) + x = self.proj(x) + else: + _assert(x.ndim == 4, 'Convolutional input must be 4D') + if self.dynamic_img_pad: + H, W = x.shape[-2:] + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + x = self.proj(x) + + grid_size = x.shape[-2:] + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + + # Apply normalization after flattening + x = self.norm(x) + + if self.pos_embed_type == 'learned': + if naflex_grid_sizes is not None: + self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) + else: + assert grid_size is not None + self._apply_learned_pos_embed(x, grid_size=grid_size) + elif self.pos_embed_type == 'factorized': + if naflex_grid_sizes is not None: + self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) + elif self.pos_embed_type == 'rope': + assert False, "ROPE not yet implemented" + + # Prepare and add class and register tokens + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(B, -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(B, -1, -1)) + # Add tokens to the beginning + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + + # Apply dropouts + x = self.pos_drop(x) + x = self.patch_drop(x) + return x + + +@register_notrace_function +def create_attention_mask( + patch_valid: torch.Tensor, + num_prefix_tokens: int = 0, + symmetric: bool = True, + q_len: Optional[int] = None, + dtype: torch.dtype = torch.float32, +) -> Optional[torch.Tensor]: + """Creates an attention mask from patch validity information. + + Supports two modes controlled by `symmetric`: + 1. `symmetric=True` (default): Creates a symmetric mask of shape + [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if + both token i and token j are valid. Suitable for standard self-attention. + 2. `symmetric=False`: Creates a potentially non-square mask of shape + [B, 1, q_len, kv_len]. 
An attention pair (q, k) is allowed only if + the key/value token k is valid. Query token validity is not checked + in the mask itself. Useful for cross-attention or specific self-attention + implementations `q_len` can be specified. + + Used for NaFlex mode to handle variable token counts and padding tokens. + + Args: + patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding. + num_prefix_tokens: Number of prefix tokens (class token, register tokens) + to prepend, which are always considered valid. + symmetric: If True, create a symmetric mask. + If False, create an expanded mask based only on key/value validity. + q_len: Query sequence length override. Only used when `symmetric` is False. + Defaults to the key/value sequence length (`kv_len`) if None. + dtype: Dtype of the output attention mask (e.g., torch.float32). + + Returns: + Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked). + Shape is [B, 1, seq_len, seq_len] if symmetric=True, + or [B, 1, q_len, kv_len] if symmetric=False. + """ + if patch_valid is None: + return None + + patch_valid = patch_valid.bool() # Ensure boolean type + B, N = patch_valid.shape + kv_len = N # Initial key/value length is the number of patches + + # Prepend prefix tokens if any + if num_prefix_tokens > 0: + # Create prefix validity tensor on the same device/dtype base as patch_valid + prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool) + # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N] + patch_valid = torch.cat([prefix_valid, patch_valid], dim=1) + kv_len += num_prefix_tokens # Update total key/value sequence length + + if symmetric: + # Symmetric mask is True where BOTH query and key are valid + mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1) + mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len] + else: + # Expanded mask + q_len = q_len or kv_len + mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len) + + # Create the float mask and apply masking using additive mask convention + mask_float = torch.zeros_like(mask_bool, dtype=dtype) + # Fill with negative infinity where mask_bool is False (masked positions) + mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min) + + return mask_float + + +@register_notrace_function +def global_pool_naflex( + x: torch.Tensor, + patch_valid: Optional[torch.Tensor] = None, + pool_type: str = 'token', + num_prefix_tokens: int = 1, + reduce_include_prefix: bool = False, +) -> torch.Tensor: + """Global pooling with NaFlex support for masked tokens. + + Applies global pooling while respecting patch validity masks to exclude + padding tokens from pooling operations. 
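+
+    Example (illustrative sketch; batch and shape values are arbitrary):
+
+        x = torch.randn(1, 1 + 6, 768)                        # [B, prefix + N, C] with one class token
+        patch_valid = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.bool)
+        pooled = global_pool_naflex(x, patch_valid, pool_type='avg', num_prefix_tokens=1)
+        # pooled: [1, 768], averaged over the four valid patch tokens only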
+ + Args: + x: Input tensor with shape [B, N, C] + patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens] + pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max') + num_prefix_tokens: Number of prefix tokens (class/register) + reduce_include_prefix: Whether to include prefix tokens in pooling reduction + + Returns: + Pooled tensor with shape [B, C] + """ + if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'): + # Fall back to standard pooling + x = global_pool_nlc( + x, + pool_type=pool_type, + num_prefix_tokens=num_prefix_tokens, + reduce_include_prefix=reduce_include_prefix, + ) + return x + + # For NaFlex mode, we need to apply masked pooling to exclude padding tokens + if num_prefix_tokens > 0: + if reduce_include_prefix: + # Include prefix tokens in pooling - they are always considered valid + # patch_valid only covers patch tokens, so create combined validity mask + prefix_valid = patch_valid.new_ones(x.shape[0], num_prefix_tokens) + patch_valid = torch.cat([prefix_valid, patch_valid], dim=1) + else: + # Exclude prefix tokens from pooling (default behavior) + x = x[:, num_prefix_tokens:] + + patch_valid_float = patch_valid.to(x.dtype) + if pool_type == 'avg': + # Compute masked average pooling, sum valid tokens and divide by count of valid tokens + masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) + valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) + pooled = masked_sums / valid_counts + return pooled + elif pool_type == 'avgmax': + # For avgmax, compute masked average and masked max + masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) + valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) + masked_avg = masked_sums / valid_counts + + # For max pooling we set masked positions to large negative value + masked_x = x.clone() + masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min + masked_max = masked_x.amax(dim=1) + + # Combine average and max + return 0.5 * (masked_avg + masked_max) + elif pool_type == 'max': + # For max pooling we set masked positions to large negative value + masked_x = x.clone() + masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min + return masked_x.amax(dim=1) + else: + assert False + + +class NaFlexVit(nn.Module): + """NaFlexVit: Vision Transformer with NaFlex support for flexible input handling. + + A flexible implementation of Vision Transformer that supports: + - Standard image classification with various pooling strategies + - NaFlex functionality for variable aspect ratios and resolutions + - Linear patch embedding for pre-patchified inputs + - Multiple position embedding strategies (learned, factorized, rope) + - Comprehensive attention masking for efficient batch processing + - Encapsulated embedding and position encoding in FlexEmbeds module + - Compatible with standard ViT checkpoints through checkpoint filtering + """ + + def __init__( + self, + cfg: Optional[NaFlexVitCfg] = None, + in_chans: int = 3, + num_classes: int = 1000, + img_size: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ) -> None: + """Initialize NaFlexVit model. + + Args: + cfg: Model configuration. If None, uses default NaFlexVitCfg. + in_chans: Number of input image channels. + num_classes: Number of classification classes. + img_size: Input image size for backwards compatibility. + **kwargs: Additional config parameters to override cfg values. 
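+
+        Example (illustrative sketch; the config values are arbitrary, not a released variant):
+
+            cfg = NaFlexVitCfg(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', fc_norm=True)
+            model = NaFlexVit(cfg=cfg, in_chans=3, num_classes=1000, img_size=256)
+            logits = model(torch.randn(2, 3, 256, 256))  # standard [B, C, H, W] path; NaFlex dict input is also accepted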
+ """ + super().__init__() + + # Initialize config + cfg = cfg or NaFlexVitCfg() + if kwargs: + cfg = _overlay_kwargs(cfg, **kwargs) + + # Validate configuration + assert cfg.global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + assert cfg.class_token or cfg.global_pool != 'token' + assert cfg.pos_embed in ('', 'none', 'learned', 'factorized') + + # Resolve layer implementations + norm_layer = get_norm_layer(cfg.norm_layer) or LayerNorm + embed_norm_layer = get_norm_layer(cfg.embed_norm_layer) + act_layer = get_act_layer(cfg.act_layer) or nn.GELU + block_fn = Block # TODO: Support configurable block_fn via string lookup + mlp_layer = Mlp # TODO: Support configurable mlp_layer via string lookup + + # Store instance variables + self.num_classes = num_classes + self.global_pool = cfg.global_pool + self.num_features = self.head_hidden_size = self.embed_dim = cfg.embed_dim # for consistency with other models + self.num_prefix_tokens = 1 if cfg.class_token else 0 + self.num_prefix_tokens += cfg.reg_tokens + self.num_reg_tokens = cfg.reg_tokens + self.has_class_token = cfg.class_token + self.pool_include_prefix = cfg.pool_include_prefix + self.grad_checkpointing = False + + # Initialize embedding module (includes patch, position embedding, and class/reg tokens) + # FlexEmbeds is always used - handles both linear and conv embedding + self.embeds = NaFlexEmbeds( + patch_size=cfg.patch_size, + in_chans=in_chans, + embed_dim=cfg.embed_dim, + proj_type=cfg.embed_proj_type, + proj_bias=not cfg.pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + class_token=cfg.class_token, + reg_tokens=cfg.reg_tokens, + default_img_size=img_size, + dynamic_img_pad=cfg.dynamic_img_pad, + pos_embed=cfg.pos_embed, + pos_embed_grid_size=cfg.pos_embed_grid_size, + pos_embed_interp_mode=cfg.pos_embed_interp_mode, + pos_embed_ar_preserving=cfg.pos_embed_ar_preserving, + proj_norm_layer=embed_norm_layer, + pos_drop_rate=cfg.pos_drop_rate, + patch_drop_rate=cfg.patch_drop_rate, + enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False), + ) + self.norm_pre = norm_layer(cfg.embed_dim) if cfg.pre_norm else nn.Identity() + + # Transformer blocks + dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=cfg.embed_dim, + num_heads=cfg.num_heads, + mlp_ratio=cfg.mlp_ratio, + qkv_bias=cfg.qkv_bias, + qk_norm=cfg.qk_norm, + proj_bias=cfg.proj_bias, + init_values=cfg.init_values, + proj_drop=cfg.proj_drop_rate, + attn_drop=cfg.attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(cfg.depth)]) + + # Feature info for downstream tasks + patch_reduction = self.embeds.feat_ratio(as_scalar=True) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=cfg.embed_dim, reduction=patch_reduction) + for i in range(cfg.depth) + ] + + self.norm = norm_layer(cfg.embed_dim) if cfg.final_norm and not cfg.fc_norm else nn.Identity() + + # Classifier Head + if cfg.global_pool == 'map': + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=cfg.num_heads, + mlp_ratio=cfg.mlp_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + ) + else: + self.attn_pool = None + + # Handle fc_norm default value + fc_norm = cfg.fc_norm + if fc_norm is None: + fc_norm = cfg.global_pool == 'avg' + self.fc_norm = norm_layer(cfg.embed_dim) if cfg.final_norm and fc_norm else nn.Identity() + self.head_drop = nn.Dropout(cfg.drop_rate) + self.head = 
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if cfg.weight_init != 'skip': + self.init_weights(cfg.weight_init) + if cfg.fix_init: + self.fix_init_weight() + + def fix_init_weight(self) -> None: + """Apply initialization weight fix with layer-wise scaling.""" + def rescale(param: torch.Tensor, _layer_id: int) -> None: + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = '') -> None: + """Initialize model weights according to specified scheme. + + Args: + mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '') + """ + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + named_apply(get_init_weights_vit(mode, head_bias), self) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: + # Custom loading for the new model structure + from .vision_transformer import _load_weights as _orig_load_weights + + def _load_weights_adapter(model, checkpoint_path, prefix=''): + """Adapter function to handle the different model structure""" + state_dict = torch.load(checkpoint_path, map_location='cpu') + if isinstance(state_dict, dict) and 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + # Map original keys to new structure + for k in list(state_dict.keys()): + if k.startswith('cls_token'): + state_dict['embeds.' + k] = state_dict.pop(k) + elif k.startswith('reg_token'): + state_dict['embeds.' + k] = state_dict.pop(k) + elif k.startswith('pos_embed'): + state_dict['embeds.' + k] = state_dict.pop(k) + elif k.startswith('patch_embed'): + state_dict['embeds.' + k[12:]] = state_dict.pop(k) + + return _orig_load_weights(model, state_dict, prefix) + + _load_weights_adapter(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + """Get set of parameter names that should not have weight decay applied. + + Returns: + Set of parameter names to skip during weight decay + """ + skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'} + return skip_list + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + """Get parameter group matcher for optimizer parameter grouping. + + Args: + coarse: Whether to use coarse-grained grouping + + Returns: + Dictionary mapping group names to regex patterns + """ + return dict( + stem=r'^embeds', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + """Enable or disable gradient checkpointing for memory efficiency. + + Args: + enable: Whether to enable gradient checkpointing + """ + self.grad_checkpointing = enable + if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'): + self.embeds.patch_embed.set_grad_checkpointing(enable) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + """Get the classification head module. + + Returns: + Classification head module + """ + return self.head + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Reset the classification head with new number of classes and pooling. 
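+
+        Example (illustrative): `model.reset_classifier(10)` swaps in a fresh 10-way head, while
+        `model.reset_classifier(0, global_pool='avg')` replaces the head with nn.Identity and
+        switches to average pooling.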
+ + Args: + num_classes: Number of classes for new classification head + global_pool: Optional new global pooling type + """ + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + if global_pool == 'map' and self.attn_pool is None: + assert False, "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != 'map' and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_intermediates( + self, + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + indices: Optional[Union[int, List[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + output_dict: bool = False, + patch_coord: Optional[torch.Tensor] = None, + patch_valid: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys + patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode + patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex + mask: Optional attention mask + Returns: + A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing + 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') + """ + + # FIXME unfinished / untested + + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + if isinstance(x, Dict): + # Handle dictionary input from NaFlex collator + patch_coord = x['patch_coord'] + patch_valid = x['patch_valid'] + patches = x['patches'] + assert False, 'WIP, patch mode needs more work' + else: + patches = x + height, width = x.shape[-2:] + H, W = self.embeds.dynamic_feat_size((height, width)) + + # Create attention mask if patch_type is provided and mask is not + if mask is None and patch_valid is not None: + mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) + + # Forward pass through embedding + x = self.embeds(patches, patch_coord=patch_coord) + x = self.norm_pre(x) + + # Forward pass through blocks + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + + for i, blk in enumerate(blocks): + x = blk(x, attn_mask=mask) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # Process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. 
class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + else: + prefix_tokens = None + + if reshape: + # reshape to BCHW output format + intermediates = [ + y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous() + for y in intermediates + ] + + # For dictionary output + if output_dict: + result_dict = {} + # Intermediates are always included + result_dict['image_intermediates'] = intermediates + if prefix_tokens is not None and return_prefix_tokens: + result_dict['image_intermediates_prefix'] = prefix_tokens + + # Only include features if not intermediates_only + if not intermediates_only: + x_final = self.norm(x) + result_dict['image_features'] = x_final + + return result_dict + + # For non-dictionary output, maintain the original behavior + if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def forward_features( + self, + x: torch.Tensor, + patch_coord: Optional[torch.Tensor] = None, + patch_valid: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if attn_mask is None: + attn_mask = create_attention_mask( + patch_valid, + num_prefix_tokens=self.num_prefix_tokens, + dtype=x.dtype + ) + + # Pass through embedding module with patch coordinate/type support + x = self.embeds(x, patch_coord=patch_coord) + x = self.norm_pre(x) + # Apply transformer blocks with masked attention if mask provided + if attn_mask is not None: + # We need to apply blocks one by one with mask + for blk in self.blocks: + x = blk(x, attn_mask=attn_mask) + elif self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + x = self.norm(x) + return x + + def _pool( + self, + x: torch.Tensor, + pool_type: Optional[str] = None, + patch_valid: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.attn_pool is not None: + # For attention pooling, we need to pass the mask for NaFlex models + if self.pool_include_prefix: + # Include all tokens in attention pooling - create mask for all tokens including prefix + attn_mask = create_attention_mask( + patch_valid, + num_prefix_tokens=self.num_prefix_tokens, + symmetric=False, + q_len=1, + dtype=x.dtype, + ) + x = self.attn_pool(x, attn_mask=attn_mask) + else: + # Exclude prefix tokens from attention pooling (default behavior) + attn_mask = create_attention_mask( + patch_valid, + num_prefix_tokens=0, # No prefix tokens when we slice them off + symmetric=False, + q_len=1, + dtype=x.dtype, + ) + x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask) + return x + + pool_type = self.global_pool if pool_type is None else pool_type + + x = global_pool_naflex( + x, + patch_valid, + pool_type=pool_type, + num_prefix_tokens=self.num_prefix_tokens, + reduce_include_prefix=self.pool_include_prefix, + ) + return x + + def forward_head( + self, + x: torch.Tensor, + pre_logits: bool = False, + patch_valid: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = self._pool(x, patch_valid=patch_valid) + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward( + self, + x: Union[torch.Tensor, Dict[str, 
torch.Tensor]], + patch_coord: Optional[torch.Tensor] = None, + patch_valid: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with optional NaFlex support. + + Args: + x: Input tensor. Supported formats: + - [B, C, H, W] standard image input + - [B, N, P*P*C] pre-patchified tensor (flattened patches) + - [B, N, Ph, Pw, C] pre-patchified tensor (variable patch size) + - Dict from NaFlex collator + patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode. + patch_valid: Optional patch validity indicators for NaFlex. + + Returns: + Model output tensor. + """ + if isinstance(x, Dict): + # Handle dictionary input from NaFlex collator + patch_coord = x['patch_coord'] + patch_valid = x['patch_valid'] + patches = x['patches'] + + # DEBUG, reconstruct patches + # for i in range(len(patches)): + # patch = patches[i][patch_valid[i]] + # h = (patch_coord[i, :, 0].max() + 1).item() + # w = (patch_coord[i, :, 1].max() + 1).item() + # patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3) + # patch = patch.reshape(3, h*16, w*16) + # from torchvision.utils import save_image + # save_image(patch, f'patch_{i}.jpg', normalize=True) + else: + patches = x + + # Create attention mask if patch_type is provided + attn_mask = create_attention_mask( + patch_valid, + num_prefix_tokens=self.num_prefix_tokens, + dtype=patches.dtype, + ) + + # Forward features with mask + x = self.forward_features( + patches, + patch_coord=patch_coord, + patch_valid=patch_valid, + attn_mask=attn_mask, + ) + + # Pass mask to forward_head for masked pooling + x = self.forward_head( + x, + patch_valid=patch_valid, + ) + return x + + +def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: + """Function imported from vision_transformer.py to maintain compatibility""" + from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm + + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +def checkpoint_filter_fn(state_dict: Dict[str, Any], model: NaFlexVit) -> Dict[str, Any]: + """Handle state dict conversion from original ViT to the new version with combined embedding.""" + from .vision_transformer import checkpoint_filter_fn as orig_filter_fn + + # Handle CombinedEmbed module pattern + out_dict = {} + for k, v in state_dict.items(): + # Convert tokens and embeddings to combined_embed structure + if k == 'pos_embed': + # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C) + if hasattr(model.embeds, 'pos_embed') and v.ndim == 3: + num_cls_token = 0 + num_reg_token = 0 + if 'reg_token' in state_dict: + num_reg_token = state_dict['reg_token'].shape[1] + if 'cls_token' in state_dict: + num_cls_token = state_dict['cls_token'].shape[1] + num_prefix_tokens = num_cls_token + num_reg_token + + # Original format is (1, N, C), need to reshape to (1, H, W, C) + num_patches = v.shape[1] + num_patches_no_prefix = num_patches - num_prefix_tokens + grid_size_no_prefix = math.sqrt(num_patches_no_prefix) + grid_size = math.sqrt(num_patches) + if (grid_size_no_prefix != grid_size + and (grid_size_no_prefix.is_integer() and not grid_size.is_integer()) + ): + # make a decision, did the pos_embed of the original include the prefix tokens? 
+ num_patches = num_patches_no_prefix + cls_token_emb = v[:, 0:num_cls_token] + if cls_token_emb.numel(): + state_dict['cls_token'] += cls_token_emb + reg_token_emb = v[:, num_cls_token:num_reg_token] + if reg_token_emb.numel(): + state_dict['reg_token'] += reg_token_emb + v = v[:, num_prefix_tokens:] + grid_size = grid_size_no_prefix + grid_size = int(grid_size) + + # Check if it's a perfect square for a standard grid + if grid_size * grid_size == num_patches: + # Reshape from (1, N, C) to (1, H, W, C) + v = v.reshape(1, grid_size, grid_size, v.shape[2]) + else: + # Not a square grid, we need to get the actual dimensions + if hasattr(model.embeds.patch_embed, 'grid_size'): + h, w = model.embeds.patch_embed.grid_size + if h * w == num_patches: + # We have the right dimensions + v = v.reshape(1, h, w, v.shape[2]) + else: + # Dimensions don't match, use interpolation + _logger.warning( + f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. " + f"Using default initialization and will resize in forward pass." + ) + # Keep v as is, the forward pass will handle resizing + + out_dict['embeds.pos_embed'] = v + elif k == 'cls_token': + out_dict['embeds.cls_token'] = v + elif k == 'reg_token': + out_dict['embeds.reg_token'] = v + # Convert patch_embed.X to embeds.patch_embed.X + elif k.startswith('patch_embed.'): + suffix = k[12:] + if suffix == 'proj.weight': + v = v.permute(0, 2, 3, 1).flatten(1) + new_key = 'embeds.' + suffix + out_dict[new_key] = v + else: + out_dict[k] = v + + return out_dict + + +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 256, 256), + 'pool_size': None, + 'crop_pct': 0.95, + 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, + 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'embeds.proj', + 'classifier': 'head', + 'license': 'apache-2.0', + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'naflexvit_base_patch16_gap': _cfg(), + 'naflexvit_base_patch16_map': _cfg(), + + 'naflexvit_base_patch16_siglip': _cfg(), + 'naflexvit_so400m_patch16_siglip': _cfg(), +}) + + +def _create_naflexvit(variant: str, pretrained: bool = False, **kwargs) -> NaFlexVit: + out_indices = kwargs.pop('out_indices', 3) + cfg = kwargs.pop('cfg', NaFlexVitCfg()) + cfg_field_names = {f.name for f in fields(NaFlexVitCfg)} + # pop in-place so the original kwargs is emptied of cfg-specific keys + cfg_updates = {k: kwargs.pop(k) for k in list(kwargs) if k in cfg_field_names} + if cfg_updates: + cfg = _overlay_kwargs(cfg, **cfg_updates) + + model = build_model_with_cfg( + NaFlexVit, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + cfg=cfg, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + return model + + +def _create_naflexvit_from_classic( + variant: str, + pretrained: bool = False, + **kwargs, +) -> NaFlexVit: + """Create FlexVit model from classic VisionTransformer configuration. + + This function handles the parameter mapping and configuration logic needed + to create FlexVit models that are compatible with classic VisionTransformer + configurations and pretrained weights. 
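+
+    Example (illustrative): classic kwargs such as `global_pool='avg'` with `fc_norm` unset map to
+    `fc_norm=True`, `class_token` defaults to True, and `pos_embed_grid_size` is forced to None so
+    the position embedding grid is derived from `img_size // patch_size`, matching original
+    VisionTransformer behaviour.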
+ + Args: + variant: Model variant name + pretrained: Whether to load pretrained weights + **kwargs: Classic VisionTransformer parameters + + Returns: + FlexVit model instance + """ + # Remove VisionTransformer-specific parameters that don't apply to FlexVit + kwargs.pop('no_embed_class', None) + kwargs.pop('dynamic_img_size', None) + + # Handle global pooling and fc_norm defaults that differ between ViT and FlexVit + gp = kwargs.pop('global_pool', 'token') # Original ViTs default to cls token pooling + fc_norm = kwargs.pop('fc_norm', None) # Original ViTs used fc_norm when not set and avg pooling used + if fc_norm is None and gp == 'avg': + fc_norm = True + + # Set FlexVit-specific defaults that differ from VisionTransformer + flex_kwargs = { + 'pos_embed_grid_size': None, # rely on img_size (// patch_size) that will be passed through + 'class_token': kwargs.get('class_token', True), + 'global_pool': gp, + 'fc_norm': fc_norm, + **kwargs # User overrides take precedence + } + + return _create_naflexvit(variant, pretrained, **flex_kwargs) + + +@register_model +def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit: + """ViT-Base with NaFlex functionality and global average pooling. + """ + cfg = NaFlexVitCfg( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + init_values=1e-5, + global_pool='avg', + reg_tokens=4, + fc_norm=True, + ) + model = _create_naflexvit('naflexvit_base_patch16_gap', pretrained=pretrained, cfg=cfg, **kwargs) + return model + + +@register_model +def naflexvit_base_patch16_map(pretrained: bool = False, **kwargs) -> NaFlexVit: + """ViT-Base with NaFlex functionality and MAP attention pooling. + """ + cfg = NaFlexVitCfg( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + init_values=1e-5, + global_pool='map', + reg_tokens=1, + ) + model = _create_naflexvit('naflexvit_base_patch16_map', pretrained=pretrained, cfg=cfg, **kwargs) + return model + + +@register_model +def naflexvit_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> NaFlexVit: + """ViT-SO150M2 with NaFlex functionality for variable aspect ratios and resolutions. + + This model supports: + 1. Variable aspect ratios and resolutions via patch coordinates + 2. Position embedding interpolation for arbitrary grid sizes + 3. Explicit patch coordinates and valid token masking + """ + cfg = NaFlexVitCfg( + patch_size=16, + embed_dim=832, + depth=21, + num_heads=13, + mlp_ratio=34/13, + init_values=1e-5, + qkv_bias=False, + reg_tokens=1, + global_pool='avg', + fc_norm=True, + ) + model = _create_naflexvit('naflexvit_so150m2_patch16_reg1_gap', pretrained=pretrained, cfg=cfg, **kwargs) + return model + + +@register_model +def naflexvit_so150m2_patch16_reg1_map(pretrained: bool = False, **kwargs) -> NaFlexVit: + """ViT-SO150M2 with NaFlex functionality for variable aspect ratios and resolutions. + + This model supports: + 1. Variable aspect ratios and resolutions via patch coordinates + 2. Position embedding interpolation for arbitrary grid sizes + 3. 
Explicit patch coordinates and valid token masking + """ + cfg = NaFlexVitCfg( + patch_size=16, + embed_dim=832, + depth=21, + num_heads=13, + mlp_ratio=34/13, + init_values=1e-5, + qkv_bias=False, + reg_tokens=1, + global_pool='map', + ) + model = _create_naflexvit('naflexvit_so150m2_patch16_reg1_map', pretrained=pretrained, cfg=cfg, **kwargs) + return model + + +@register_model +def naflexvit_base_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit: + """ViT-Base with NaFlex functionality and SigLIP-style configuration. + """ + cfg = NaFlexVitCfg( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + act_layer='gelu_tanh', + global_pool='map', + ) + model = _create_naflexvit('naflexvit_base_patch16_siglip', pretrained=pretrained, cfg=cfg, **kwargs) + return model + + +@register_model +def naflexvit_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit: + """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions. + """ + cfg = NaFlexVitCfg( + patch_size=16, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.7362, + act_layer='gelu_tanh', + global_pool='map', + ) + model = _create_naflexvit('naflexvit_so400m_patch16_siglip', pretrained=pretrained, cfg=cfg, **kwargs) + return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 594415493b..3fcf400eef 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -26,6 +26,7 @@ import copy import logging import math +import os from collections import OrderedDict from functools import partial from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List @@ -45,15 +46,16 @@ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD ) from timm.layers import ( + Attention, + AttentionPoolLatent, PatchEmbed, Mlp, - DropPath, - AttentionPoolLatent, + SwiGLUPacked, + SwiGLU, LayerNorm, RmsNorm, + DropPath, PatchDropout, - SwiGLUPacked, - SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, @@ -61,6 +63,7 @@ use_fused_attn, get_act_layer, get_norm_layer, + maybe_add_mask, LayerType, ) from ._builder import build_model_with_cfg @@ -74,61 +77,6 @@ _logger = logging.getLogger(__name__) -class Attention(nn.Module): - fused_attn: Final[bool] - - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = False, - scale_norm: bool = False, - proj_bias: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: Type[nn.Module] = LayerNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.fused_attn = use_fused_attn() - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.norm = norm_layer(dim) if scale_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - q, k = self.q_norm(q), self.k_norm(k) - - if self.fused_attn: - x = F.scaled_dot_product_attention( - q, k, v, - dropout_p=self.attn_drop.p if self.training else 0., - ) - else: - q = q * self.scale - attn = q 
@ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.norm(x) - x = self.proj(x) - x = self.proj_drop(x) - return x - - class LayerScale(nn.Module): def __init__( self, @@ -191,8 +139,8 @@ def __init__( self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x @@ -252,8 +200,8 @@ def init_weights(self) -> None: nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.drop_path1(self.norm1(self.attn(x))) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = x + self.drop_path1(self.norm1(self.attn(x, attn_mask=attn_mask))) x = x + self.drop_path2(self.norm2(self.mlp(x))) return x @@ -315,7 +263,7 @@ def __init__( self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: B, N, C = x.shape # Combined MLP fc1 & qkv projections @@ -335,14 +283,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.fused_attn: x_attn = F.scaled_dot_product_attention( q, k, v, + attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0., ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) + attn = maybe_add_mask(attn, attn_mask) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x_attn = attn @ v + x_attn = x_attn.transpose(1, 2).reshape(B, N, C) x_attn = self.attn_out_proj(x_attn) @@ -416,23 +367,21 @@ def __init__( ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) ]))) - def _forward_jit(self, x: torch.Tensor) -> torch.Tensor: - x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if attn_mask is not None: + attn_out = [] + for attn in self.attns: + x_attn = attn.norm(x) + x_attn = attn.attn(x_attn, attn_mask=attn_mask) + x_attn = attn.ls(x_attn) + x_attn = attn.drop_path(x_attn) + attn_out.append(x_attn) + x = x + torch.stack(attn_out).sum(dim=0) + else: + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) return x - @torch.jit.ignore - def _forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + sum(attn(x) for attn in self.attns) - x = x + sum(ffn(x) for ffn in self.ffns) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return self._forward_jit(x) - else: - return self._forward(x) - def global_pool_nlc( x: torch.Tensor, @@ -491,6 +440,7 @@ def __init__( pre_norm: bool = False, final_norm: bool = True, fc_norm: Optional[bool] = None, + pool_include_prefix: bool = False, dynamic_img_size: bool = False, dynamic_img_pad: bool = False, drop_rate: float = 0., @@ -555,7 +505,8 @@ def __init__( self.num_prefix_tokens += reg_tokens self.num_reg_tokens = reg_tokens self.has_class_token = class_token - self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) + self.no_embed_class = no_embed_class + self.pool_include_prefix = pool_include_prefix self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False @@ -769,7 +720,9 @@ def forward_intermediates( stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, - ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + output_dict: bool = False, + attn_mask: Optional[torch.Tensor] = None, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: """ Forward features that returns intermediates. Args: @@ -780,8 +733,11 @@ def forward_intermediates( stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs intermediates_only: Only return intermediate features + output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys + attn_mask: Optional attention mask for masked attention (e.g., for NaFlex) Returns: - + A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing + 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') """ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' 
reshape = output_fmt == 'NCHW' @@ -800,7 +756,7 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + x = blk(x, attn_mask=attn_mask) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) @@ -817,6 +773,23 @@ def forward_intermediates( # reshape to BCHW output format H, W = self.patch_embed.dynamic_feat_size((height, width)) intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + + # For dictionary output, handle prefix tokens separately + if output_dict: + result_dict = {} + # Intermediates are always included + result_dict['image_intermediates'] = intermediates + if prefix_tokens is not None and return_prefix_tokens: + result_dict['image_intermediates_prefix'] = prefix_tokens + + # Only include features if not intermediates_only + if not intermediates_only: + x_final = self.norm(x) + result_dict['image_features'] = x_final + + return result_dict + + # For non-dictionary output, maintain the original behavior if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) @@ -852,6 +825,7 @@ def get_intermediate_layers( reshape: bool = False, return_prefix_tokens: bool = False, norm: bool = False, + attn_mask: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """ Intermediate layer accessor inspired by DINO / DINOv2 interface. NOTE: This API is for backwards compat, favour using forward_intermediates() directly. @@ -862,26 +836,40 @@ def get_intermediate_layers( norm=norm, output_fmt='NCHW' if reshape else 'NLC', intermediates_only=True, + attn_mask=attn_mask, ) - def forward_features(self, x: torch.Tensor) -> torch.Tensor: + def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - if self.grad_checkpointing and not torch.jit.is_scripting(): + + if attn_mask is not None: + # If mask provided, we need to apply blocks one by one + for blk in self.blocks: + x = blk(x, attn_mask=attn_mask) + elif self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) + x = self.norm(x) return x def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: if self.attn_pool is not None: + if not self.pool_include_prefix: + x = x[:, self.num_prefix_tokens:] x = self.attn_pool(x) return x pool_type = self.global_pool if pool_type is None else pool_type - x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens) + x = global_pool_nlc( + x, + pool_type=pool_type, + num_prefix_tokens=self.num_prefix_tokens, + reduce_include_prefix=self.pool_include_prefix, + ) return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: @@ -890,8 +878,8 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso x = self.head_drop(x) return x if pre_logits else self.head(x) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.forward_features(x) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.forward_features(x, attn_mask=attn_mask) x = self.forward_head(x) return x @@ -2531,7 +2519,23 @@ def _cfg(url: str = '', **kwargs) -> 
Dict[str, Any]: default_cfgs = generate_default_cfgs(default_cfgs) -def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer: +# Global flag to use NaFlexVit instead of VisionTransformer +_USE_NAFLEX_DEFAULT = os.environ.get('TIMM_USE_NAFLEXVIT', 'false').lower() == 'true' + +def _create_vision_transformer( + variant: str, + pretrained: bool = False, + use_naflex: Optional[bool] = None, + **kwargs, +) -> Union[VisionTransformer, 'NaFlexVit']: + # Check if we should use NaFlexVit instead + if use_naflex is None: + use_naflex = _USE_NAFLEX_DEFAULT + if use_naflex: + # Import here to avoid circular imports + from .naflexvit import _create_naflexvit_from_classic + return _create_naflexvit_from_classic(variant, pretrained, **kwargs) + out_indices = kwargs.pop('out_indices', 3) if 'flexi' in variant: # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed @@ -4177,7 +4181,7 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer: """ model_args = dict( patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=2, - class_token=False, reg_tokens=1, global_pool='map', init_values=1e-5) + class_token=False, reg_tokens=1, global_pool='map', pool_include_prefix=True, init_values=1e-5) model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/train.py b/train.py index 11b783e8e8..eaa2d213e9 100755 --- a/train.py +++ b/train.py @@ -33,7 +33,8 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils -from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \ + Mixup, FastCollateMixup, AugMixDataset from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters @@ -396,6 +397,20 @@ group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID', help='If resuming a run, the id of the run in wandb') +# NaFlex scheduled loader arguments +group.add_argument('--naflex-loader', action='store_true', default=False, + help='Use NaFlex loader (Requires NaFlex compatible model)') +group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128, 256, 576, 784, 1024], + help='Sequence lengths to use for NaFlex loader') +group.add_argument('--naflex-max-seq-len', type=int, default=576, + help='Fixed maximum sequence length for NaFlex loader (validation)') +group.add_argument('--naflex-patch-sizes', type=int, nargs='+', default=None, + help='List of patch sizes for variable patch size training (e.g., 8 12 16 24 32)') +group.add_argument('--naflex-patch-size-probs', type=float, nargs='+', default=None, + help='Probabilities for each patch size (must sum to 1.0, uniform if not specified)') +group.add_argument('--naflex-loss-scale', default='linear', type=str, + help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")') + def _parse_args(): # Do we have a config file to parse? 
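As a rough sketch of how the --naflex-* training flags above are expected to plug into the data pipeline (not part of the patch; keyword names follow the create_naflex_loader calls later in this diff, while the dataset root, batch size, and normalization values are placeholders):

import torch
from timm import create_model
from timm.data import create_dataset, create_naflex_loader

model = create_model('naflexvit_base_patch16_gap', num_classes=1000)
dataset_train = create_dataset('', root='/data/imagenet', split='train', is_training=True)

loader_train = create_naflex_loader(
    dataset=dataset_train,
    patch_size=model.embeds.patch_size,         # NaFlexVit models expose their patch size via embeds.patch_size
    train_seq_lens=[128, 256, 576, 784, 1024],  # mirrors the --naflex-train-seq-lens default
    batch_size=256,                             # reference batch size, used at the largest seq len
    is_training=True,
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
)

The effective per-step batch size in this mode appears to vary with the sampled sequence length, which is why the training loop later in this patch rescales the loss by batch_size / args.batch_size when --naflex-loss-scale is enabled.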
@@ -669,6 +684,7 @@ def main(): trust_remote_code=args.dataset_trust_remote_code, ) + dataset_eval = None if args.val_split: dataset_eval = create_dataset( args.dataset, @@ -685,38 +701,23 @@ def main(): trust_remote_code=args.dataset_trust_remote_code, ) - # setup mixup / cutmix - collate_fn = None - mixup_fn = None - mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None - if mixup_active: - mixup_args = dict( - mixup_alpha=args.mixup, - cutmix_alpha=args.cutmix, - cutmix_minmax=args.cutmix_minmax, - prob=args.mixup_prob, - switch_prob=args.mixup_switch_prob, - mode=args.mixup_mode, - label_smoothing=args.smoothing, - num_classes=args.num_classes - ) - if args.prefetcher: - assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup) - collate_fn = FastCollateMixup(**mixup_args) - else: - mixup_fn = Mixup(**mixup_args) - - # wrap dataset in AugMix helper - if num_aug_splits > 1: - dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) - # create data loaders w/ augmentation pipeline train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] - loader_train = create_loader( - dataset_train, - input_size=data_config['input_size'], + + # Check if we should use the NaFlex scheduled loader + common_loader_kwargs = dict( + mean=data_config['mean'], + std=data_config['std'], + pin_memory=args.pin_mem, + img_dtype=model_dtype or torch.float32, + device=device, + distributed=args.distributed, + use_prefetcher=args.prefetcher, + ) + + train_loader_kwargs = dict( batch_size=args.batch_size, is_training=True, no_aug=args.no_aug, @@ -737,42 +738,135 @@ def main(): num_aug_repeats=args.aug_repeats, num_aug_splits=num_aug_splits, interpolation=train_interpolation, - mean=data_config['mean'], - std=data_config['std'], num_workers=args.workers, - distributed=args.distributed, - collate_fn=collate_fn, - pin_memory=args.pin_mem, - img_dtype=model_dtype or torch.float32, - device=device, - use_prefetcher=args.prefetcher, - use_multi_epochs_loader=args.use_multi_epochs_loader, worker_seeding=args.worker_seeding, ) + mixup_fn = None + mixup_args = {} + mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None + if mixup_active: + mixup_args = dict( + mixup_alpha=args.mixup, + cutmix_alpha=args.cutmix, + cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, + switch_prob=args.mixup_switch_prob, + mode=args.mixup_mode, + label_smoothing=args.smoothing, + num_classes=args.num_classes + ) + + naflex_mode = False + model_patch_size = None + if args.naflex_loader: + if utils.is_primary(args): + _logger.info('Using NaFlex loader') + + assert num_aug_splits <= 1, 'Augmentation splits not supported in NaFlex mode' + naflex_mixup_fn = None + if mixup_active: + from timm.data import NaFlexMixup + mixup_args.pop('mode') # not supported + mixup_args.pop('cutmix_minmax') # not supported + naflex_mixup_fn = NaFlexMixup(**mixup_args) + + # Extract model's patch size for NaFlex mode + if hasattr(model, 'embeds') and hasattr(model.embeds, 'patch_size'): + # NaFlexVit models have embeds.patch_size + model_patch_size = model.embeds.patch_size + else: + # Fallback to default + model_patch_size = (16, 16) + if utils.is_primary(args): + _logger.warning(f'Could not determine model patch size, using default: {model_patch_size}') + + # Configure patch sizes for NaFlex loader + patch_loader_kwargs = {} + if args.naflex_patch_sizes: + # Variable patch size mode + patch_loader_kwargs['patch_size_choices'] = args.naflex_patch_sizes + if args.naflex_patch_size_probs: + if len(args.naflex_patch_size_probs) != len(args.naflex_patch_sizes): + parser.error('--naflex-patch-size-probs must have same length as --naflex-patch-sizes') + patch_loader_kwargs['patch_size_choice_probs'] = args.naflex_patch_size_probs + if utils.is_primary(args): + _logger.info(f'Using variable patch sizes: {args.naflex_patch_sizes}') + else: + # Single patch size mode - use model's patch size + patch_loader_kwargs['patch_size'] = model_patch_size + if utils.is_primary(args): + _logger.info(f'Using model patch size: {model_patch_size}') + + naflex_mode = True + loader_train = create_naflex_loader( + dataset=dataset_train, + train_seq_lens=args.naflex_train_seq_lens, + mixup_fn=naflex_mixup_fn, + rank=args.rank, + world_size=args.world_size, + **patch_loader_kwargs, + **common_loader_kwargs, + **train_loader_kwargs, + ) + else: + # setup mixup / cutmix + collate_fn = None + if mixup_active: + if args.prefetcher: + assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup) + collate_fn = FastCollateMixup(**mixup_args) + else: + mixup_fn = Mixup(**mixup_args) + + # wrap dataset in AugMix helper + if num_aug_splits > 1: + dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + + # Use standard loader + loader_train = create_loader( + dataset_train, + input_size=data_config['input_size'], + collate_fn=collate_fn, + use_multi_epochs_loader=args.use_multi_epochs_loader, + **common_loader_kwargs, + **train_loader_kwargs, + ) + loader_eval = None if args.val_split: + assert dataset_eval is not None eval_workers = args.workers if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset): # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training eval_workers = min(2, args.workers) - loader_eval = create_loader( - dataset_eval, - input_size=data_config['input_size'], + + eval_loader_kwargs = dict( batch_size=args.validation_batch_size or args.batch_size, is_training=False, interpolation=data_config['interpolation'], - mean=data_config['mean'], - std=data_config['std'], num_workers=eval_workers, - 
distributed=args.distributed, crop_pct=data_config['crop_pct'], - pin_memory=args.pin_mem, - img_dtype=model_dtype or torch.float32, - device=device, - use_prefetcher=args.prefetcher, ) + if args.naflex_loader: + # Use largest sequence length for validation + loader_eval = create_naflex_loader( + dataset=dataset_eval, + patch_size=model_patch_size, # Use model's native patch size (already determined above) + max_seq_len=args.naflex_max_seq_len, + **common_loader_kwargs, + **eval_loader_kwargs + ) + else: + # Use standard loader + loader_eval = create_loader( + dataset_eval, + input_size=data_config['input_size'], + **common_loader_kwargs, + **eval_loader_kwargs, + ) + # setup loss function if args.jsd_loss: assert num_aug_splits > 1 # JSD only valid with aug splits set @@ -902,6 +996,7 @@ def main(): model_ema=model_ema, mixup_fn=mixup_fn, num_updates_total=num_epochs * updates_per_epoch, + naflex_mode=naflex_mode, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -1004,6 +1099,7 @@ def train_one_epoch( model_ema=None, mixup_fn=None, num_updates_total=None, + naflex_mode=False, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -1049,10 +1145,10 @@ def _forward(): with amp_autocast(): output = model(input) - loss = loss_fn(output, target) + _loss = loss_fn(output, target) if accum_steps > 1: - loss /= accum_steps - return loss + _loss /= accum_steps + return _loss def _backward(_loss): if loss_scaler is not None: @@ -1076,16 +1172,53 @@ def _backward(_loss): ) optimizer.step() - if has_no_sync and not need_update: - with model.no_sync(): + if naflex_mode: + assert isinstance(input, dict) + batch_size = input['patches'].shape[0] + + # scale gradient vs the minimum batch size (for max seq len) + if not args.naflex_loss_scale or args.naflex_loss_scale == 'none': + local_scale = 1.0 + else: + local_scale = (batch_size / args.batch_size) + if args.naflex_loss_scale == 'sqrt': + local_scale = local_scale ** 0.5 + + if args.distributed: + # scale gradient between distributed ranks, each one can have a different batch size + global_batch_size = utils.reduce_tensor( + torch.tensor(batch_size, device=device, dtype=torch.float32), + 1 # SUM + ) + dist_scale = args.world_size * batch_size / global_batch_size + else: + dist_scale = None + + if has_no_sync and not need_update: + with model.no_sync(): + loss = _forward() + scaled_loss = local_scale * loss + if dist_scale is not None: + scaled_loss *= dist_scale + _backward(scaled_loss) + else: loss = _forward() - _backward(loss) + scaled_loss = local_scale * loss + if dist_scale is not None: + scaled_loss *= dist_scale + _backward(scaled_loss) else: - loss = _forward() - _backward(loss) + batch_size = input.shape[0] + if has_no_sync and not need_update: + with model.no_sync(): + loss = _forward() + _backward(loss) + else: + loss = _forward() + _backward(loss) - losses_m.update(loss.item() * accum_steps, input.size(0)) - update_sample_count += input.size(0) + losses_m.update(loss.item() * accum_steps, batch_size) + update_sample_count += batch_size if not need_update: data_start_time = time.time() @@ -1102,7 +1235,8 @@ def _backward(_loss): elif device.type == 'npu': torch.npu.synchronize() time_now = time.time() - update_time_m.update(time.time() - update_start_time) + + update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count) update_start_time = time_now if update_idx % args.log_interval == 0: @@ -1121,8 +1255,8 @@ def 
_backward(_loss): f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} ' f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] ' f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) ' - f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s ' - f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) ' + f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s ' + f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) ' f'LR: {lr:.3e} ' f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})' ) @@ -1211,9 +1345,10 @@ def validate( elif device.type == "npu": torch.npu.synchronize() - losses_m.update(reduced_loss.item(), input.size(0)) - top1_m.update(acc1.item(), output.size(0)) - top5_m.update(acc5.item(), output.size(0)) + batch_size = output.shape[0] + losses_m.update(reduced_loss.item(), batch_size) + top1_m.update(acc1.item(), batch_size) + top5_m.update(acc5.item(), batch_size) batch_time_m.update(time.time() - end) end = time.time() diff --git a/validate.py b/validate.py index f757855e4e..59e78a91fd 100755 --- a/validate.py +++ b/validate.py @@ -158,6 +158,12 @@ parser.add_argument('--retry', default=False, action='store_true', help='Enable batch size decay & retry for single model validation') +# NaFlex loader arguments +parser.add_argument('--naflex-loader', action='store_true', default=False, + help='Use NaFlex loader (Requires NaFlex compatible model)') +parser.add_argument('--naflex-max-seq-len', type=int, default=576, + help='Fixed maximum sequence length for NaFlex loader (validation)') + def validate(args): # might as well try to validate something @@ -293,23 +299,43 @@ def validate(args): real_labels = None crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] - loader = create_loader( - dataset, - input_size=data_config['input_size'], - batch_size=args.batch_size, - use_prefetcher=args.prefetcher, - interpolation=data_config['interpolation'], - mean=data_config['mean'], - std=data_config['std'], - num_workers=args.workers, - crop_pct=crop_pct, - crop_mode=data_config['crop_mode'], - crop_border_pixels=args.crop_border_pixels, - pin_memory=args.pin_mem, - device=device, - img_dtype=model_dtype or torch.float32, - tf_preprocessing=args.tf_preprocessing, - ) + if args.naflex_loader: + from timm.data import create_naflex_loader + loader = create_naflex_loader( + dataset, + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + crop_mode=data_config['crop_mode'], + crop_border_pixels=args.crop_border_pixels, + pin_memory=args.pin_mem, + device=device, + img_dtype=model_dtype or torch.float32, + patch_size=16, # Could be derived from model config + max_seq_len=args.naflex_max_seq_len, + ) + else: + loader = create_loader( + dataset, + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + crop_mode=data_config['crop_mode'], + crop_border_pixels=args.crop_border_pixels, + pin_memory=args.pin_mem, + device=device, + img_dtype=model_dtype or torch.float32, + tf_preprocessing=args.tf_preprocessing, + ) batch_time = AverageMeter() losses = AverageMeter() @@ -345,10 +371,11 @@ def validate(args): real_labels.add_result(output) # measure accuracy and record loss + 
batch_size = output.shape[0] acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) - losses.update(loss.item(), input.size(0)) - top1.update(acc1.item(), input.size(0)) - top5.update(acc5.item(), input.size(0)) + losses.update(loss.item(), batch_size) + top1.update(acc1.item(), batch_size) + top5.update(acc5.item(), batch_size) # measure elapsed time batch_time.update(time.time() - end) @@ -364,7 +391,7 @@ def validate(args): batch_idx, len(loader), batch_time=batch_time, - rate_avg=input.size(0) / batch_time.avg, + rate_avg=batch_size / batch_time.avg, loss=losses, top1=top1, top5=top5
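For completeness, a minimal sketch of exercising the NaFlex validation path outside of validate.py (not part of the patch; the model name, dataset root, and numeric values below are placeholders, with keyword names taken from the loader calls in this diff):

import torch
from timm import create_model
from timm.data import create_dataset, create_naflex_loader

model = create_model('naflexvit_base_patch16_gap', pretrained=False).eval()
dataset = create_dataset('', root='/data/imagenet', split='validation')

loader = create_naflex_loader(
    dataset,
    batch_size=128,
    patch_size=16,         # validate.py currently hardcodes 16; could instead come from model.embeds.patch_size
    max_seq_len=576,       # matches the --naflex-max-seq-len default
    mean=(0.5, 0.5, 0.5),  # naflexvit default _cfg uses the Inception mean/std
    std=(0.5, 0.5, 0.5),
    crop_pct=0.95,
)

top1 = top5 = n = 0
with torch.inference_mode():
    for input, target in loader:
        output = model(input)  # in NaFlex mode the loader yields a dict of patches, coords, and valid masks
        _, pred = output.topk(5, dim=1)
        correct = pred.eq(target.unsqueeze(1))
        top1 += correct[:, 0].sum().item()
        top5 += correct.any(dim=1).sum().item()
        n += target.size(0)

print(f'top1: {top1 / n:.4f}  top5: {top5 / n:.4f}')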