diff --git a/README.md b/README.md
index e1746928ba..82ea5d08b6 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,25 @@
 
 ## What's New
 
+## June 5, 2025
+* Initial NaFlexVit model code. NaFlexVit is a Vision Transformer with:
+  1. Encapsulated embedding and position encoding in a single module
+  2. Support for nn.Linear patch embedding on pre-patchified (dictionary) inputs
+  3. Support for NaFlex variable aspect, variable resolution (SigLIP-2: https://arxiv.org/abs/2502.14786)
+  4. Support for FlexiViT variable patch size (https://arxiv.org/abs/2212.08013)
+  5. Support for NaViT fractional/factorized position embedding (https://arxiv.org/abs/2307.06304)
+* Existing ViT models in `vision_transformer.py` can be loaded into the NaFlexVit model by adding the `use_naflex=True` flag to `create_model`
+  * Some native weights coming soon
+* A full NaFlex data pipeline is available that allows training / fine-tuning / evaluating with variable aspect ratio / variable resolution images
+  * To enable it in `train.py` and `validate.py`, add the `--naflex-loader` arg; it must be used with a NaFlexVit model
+* To evaluate an existing (classic) ViT loaded into the NaFlexVit model w/ the NaFlex data pipeline:
+  * `python validate.py /imagenet --amp -j 8 --model vit_base_patch16_224 --model-kwargs use_naflex=True --naflex-loader --naflex-max-seq-len 256`
+* The train script has some extra NaFlex args worth noting
+  * The `--naflex-train-seq-lens` argument specifies which sequence lengths to randomly pick from per batch during training
+  * The `--naflex-max-seq-len` argument sets the target sequence length for validation
+  * Adding `--model-kwargs enable_patch_interpolator=True --naflex-patch-sizes 12 16 24` will enable random per-batch patch size selection w/ patch embedding interpolation
+  * The `--naflex-loss-scale` arg changes the per-batch loss scaling mode relative to the batch size, since the `timm` NaFlex loader changes the batch size for each seq len
+
 ## May 28, 2025
 * Add a number of small/fast models thanks to https://github.com/brianhou0208
 * SwiftFormer - [(ICCV2023) SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications](https://github.com/Amshaker/SwiftFormer)
diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py
index e32e94f396..ac1d964c95 100644
--- a/timm/models/naflexvit.py
+++ b/timm/models/naflexvit.py
@@ -1192,27 +1192,18 @@ def _pool(
             patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.attn_pool is not None:
-            # For attention pooling, we need to pass the mask for NaFlex models
-            if self.pool_include_prefix:
-                # Include all tokens in attention pooling - create mask for all tokens including prefix
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=self.num_prefix_tokens,
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x, attn_mask=attn_mask)
-            else:
-                # Exclude prefix tokens from attention pooling (default behavior)
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=0,  # No prefix tokens when we slice them off
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
+            # Mask NaFlex padding tokens out of attention pooling; prefix tokens
+            # stay in the pooled sequence only when pool_include_prefix is set.
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
+            if not self.pool_include_prefix:
+                x = x[:, self.num_prefix_tokens:]
+            x = self.attn_pool(x, attn_mask=attn_mask)
             return x
 
         pool_type = self.global_pool if pool_type is None else pool_type
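
For reference, a standalone sketch of the masking behavior in the refactored `_pool` hunk above. `make_pool_mask` is a hypothetical stand-in for `create_attention_mask`, written under the assumption that it returns an additive float mask (large negative values at padding key positions) broadcastable to `(B, num_heads, q_len, k_len)`; shapes and values are illustrative only.

```python
import torch

def make_pool_mask(
        patch_valid: torch.Tensor,  # (B, N) bool, False at NaFlex padding positions
        num_prefix_tokens: int,     # prefix (e.g. class) tokens, always valid
        q_len: int,                 # 1 for attention pooling with a single query
        dtype: torch.dtype,
) -> torch.Tensor:
    B = patch_valid.shape[0]
    # Prefix tokens are always attended to, so prepend always-valid slots
    prefix_valid = patch_valid.new_ones(B, num_prefix_tokens)
    valid = torch.cat([prefix_valid, patch_valid], dim=1)  # (B, num_prefix + N)
    # Additive mask: 0 where valid, large negative where padded
    mask = torch.zeros(valid.shape, dtype=dtype)
    mask = mask.masked_fill(~valid, torch.finfo(dtype).min)
    # Broadcastable to (B, num_heads, q_len, k_len)
    return mask[:, None, None, :].expand(B, 1, q_len, -1)

patch_valid = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])
mask = make_pool_mask(patch_valid, num_prefix_tokens=1, q_len=1, dtype=torch.float32)
print(mask.shape)     # torch.Size([2, 1, 1, 7])
print(mask[1, 0, 0])  # last two key positions carry a large negative value
```

With `pool_include_prefix=False`, the refactored code slices the prefix tokens off `x` and builds the mask with `num_prefix_tokens=0`, so both branches reduce to the same mask construction plus an optional slice.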
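
On the README side, a minimal usage sketch for the `use_naflex=True` path described above. The model name comes from the README's `validate.py` example; the plain-tensor forward pass (rather than the NaFlex dictionary input) and the output shape are assumptions.

```python
import timm
import torch

# Load an existing (classic) ViT into the NaFlexVit implementation by passing
# use_naflex=True through create_model, as described in the README entry.
model = timm.create_model('vit_base_patch16_224', pretrained=False, use_naflex=True)
model.eval()

# Assumption: a standard 4D image tensor still works for a fixed-size forward
# pass; the NaFlex loader is only needed for variable aspect/resolution input.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 1000])
```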
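
Finally, a hypothetical illustration of why `--naflex-loss-scale` exists, not the `timm` implementation: because the NaFlex loader changes the batch size for each seq len, the loss can be rescaled per batch relative to a reference batch size to keep gradient magnitudes comparable across batches. The names and values below are made up for demonstration.

```python
import torch

BASE_BATCH_SIZE = 256  # illustrative reference batch size

def scale_loss(loss: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Scale the per-batch loss in proportion to the actual batch size
    return loss * (batch_size / BASE_BATCH_SIZE)

# e.g. a longer seq len might halve the batch relative to the reference
print(scale_loss(torch.tensor(2.0), batch_size=128))  # tensor(1.)
```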