Merged
34 commits
0893f5d
Initial NaFlex ViT model and training support
rwightman Apr 8, 2025
825edcc
Type fixes, remove old comments
rwightman Apr 8, 2025
9b23d6d
Exclude naflex models from jit tests
rwightman Apr 8, 2025
6675590
Fix ParallelThingsBlock w/ attn_mask
rwightman Apr 8, 2025
13e0f3a
Add loss scale arg, initial distributed loss scale. Maybe fix FX for …
rwightman Apr 9, 2025
b4bb0f4
Exclude embeds module and mask attn functions from tracing
rwightman Apr 9, 2025
97341fe
A much faster resample_patch_embed, can be used at train/validation time
rwightman Apr 10, 2025
ea728f6
Improve several typing issues for flex vit, can (almost) work with ji…
rwightman Apr 14, 2025
c527c37
Optimizations for pos embed resize, merge different mask helper fns
rwightman Apr 21, 2025
3dc90ed
Add naflex loader support to validate.py, fix bug in naflex pos embed…
rwightman Apr 25, 2025
ee27b73
Further pos embed tweaks, rejig model defs for testing
rwightman Apr 28, 2025
39eb56f
Starting to test distributed train, fix issue with batch_size reduce
rwightman Apr 28, 2025
e2073e3
Move NaFlexCollate with dataset, remove stand alone collate_fn and re…
rwightman Apr 29, 2025
8fcbceb
Add a WIP NaFlex compatible mixup/cutmix for testing
rwightman May 10, 2025
7624389
Mixup cleanup, add prob support and train script integration. Add wor…
rwightman May 20, 2025
f001b15
NaFlex random erasing performance improvements, python loops were slo…
rwightman May 21, 2025
7bfe606
Merge remote-tracking branch 'origin/main' into naflex
rwightman May 23, 2025
d7d3538
Add so400m model size for test, few tweaks.
rwightman May 24, 2025
2ad75e8
Fix issue w/ MAP attention mask and no patch_valid
rwightman May 24, 2025
162f492
Move naflex global pool into one fn that can be marked notrace
rwightman May 24, 2025
dd2c141
Fix tracing of attention module with attn_mask support
rwightman May 25, 2025
842a786
A few more maybe_add_mask situations
rwightman May 25, 2025
b7ced7c
torch.fx.wrap not working with older pytorch, trying register_notrace…
rwightman May 25, 2025
72858c1
Add siglip2 compatible naflex encoders. Add support to factorized pos…
rwightman May 30, 2025
fe2867c
Significant naflex refactor. Rename classes, models. Support flag for…
rwightman Jun 3, 2025
2bf71f5
Merge remote-tracking branch 'origin/main' into naflex
rwightman Jun 3, 2025
b3ca8fd
Add naflex vit exceptions to tests
rwightman Jun 4, 2025
dd3b96c
Fix features intermediates for NCHW inputs, patch variable size input…
rwightman Jun 4, 2025
d78cbf4
Rename dataset wrapper to NaFlexMapDatasetWrapper
rwightman Jun 4, 2025
0d43942
Add variable patch size to naflex training, improve patch size arg ha…
rwightman Jun 5, 2025
dac2ec6
Add missing patch embed interpolator
rwightman Jun 5, 2025
4ff865c
A bit of docstring and comment consistency cleanup, remove some debug…
rwightman Jun 5, 2025
99a09eb
Update old FastCollateMixup to accept torch tensor inputs instead of …
rwightman Jun 5, 2025
a0b5bcc
Fix another low use path where only numpy arrays are supported
rwightman Jun 5, 2025
6 changes: 3 additions & 3 deletions tests/test_models.py
@@ -56,12 +56,12 @@
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet',
'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit'
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'vit_*', 'naflexvit*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'swiftformer_*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
@@ -81,7 +81,7 @@
EXCLUDE_FILTERS = ['*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']

EXCLUDE_JIT_FILTERS = ['hiera_*']
EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']

TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
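For reference, a minimal sketch of how glob-style exclude lists like the ones above are typically consumed; `is_excluded` is a hypothetical helper for illustration only, not part of the timm test suite.

```python
import fnmatch

EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']

def is_excluded(model_name, patterns):
    # Hypothetical helper: a model is skipped if any glob pattern matches its name.
    return any(fnmatch.fnmatch(model_name, pat) for pat in patterns)

assert is_excluded('naflexvit_base_patch16', EXCLUDE_JIT_FILTERS)   # skipped by '*naflex*'
assert not is_excluded('resnet50', EXCLUDE_JIT_FILTERS)             # still JIT-tested
```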
12 changes: 12 additions & 0 deletions timm/data/__init__.py
@@ -8,6 +8,18 @@
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .naflex_dataset import NaFlexMapDatasetWrapper, calculate_naflex_batch_size
from .naflex_loader import create_naflex_loader
from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
from .naflex_transforms import (
ResizeToSequence,
CenterCropToSequence,
RandomCropToSequence,
RandomResizedCropToSequence,
ResizeKeepRatioToSequence,
Patchify,
patchify_image,
)
from .readers import create_reader
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
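The new exports center on NaFlex-style patch-sequence data handling. As a rough illustration of what a patchify step produces, here is a generic sketch that turns a CHW image into a sequence of flattened patches plus its grid size; it is an assumption-laden stand-in, not the actual `Patchify`/`patchify_image` implementation (which also handles resizing to a target sequence length).

```python
import torch

def patchify_sketch(img: torch.Tensor, patch_size: int = 16):
    # Generic sketch: (C, H, W) image -> (num_patches, patch_dim) token sequence.
    # Assumes H and W are divisible by patch_size.
    c, h, w = img.shape
    gh, gw = h // patch_size, w // patch_size
    patches = img.reshape(c, gh, patch_size, gw, patch_size)
    patches = patches.permute(1, 3, 2, 4, 0).reshape(gh * gw, -1)
    return patches, (gh, gw)

tokens, grid = patchify_sketch(torch.rand(3, 224, 224), patch_size=16)
assert tokens.shape == (14 * 14, 16 * 16 * 3) and grid == (14, 14)
```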
6 changes: 5 additions & 1 deletion timm/data/loader.py
@@ -33,6 +33,7 @@ def fast_collate(batch):
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
is_np = isinstance(batch[0][0], np.ndarray)
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
@@ -41,7 +42,10 @@
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
if is_np:
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
else:
tensor[i + j * batch_size] += batch[i][0][j]
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
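The change above lets `fast_collate` accept torch tensor inputs alongside numpy arrays. A standalone sketch of that dual-path idea (a minimal illustration, not the timm function itself):

```python
import numpy as np
import torch

def collate_uint8_sketch(batch):
    # batch: list of (image, target) where image is a uint8 numpy array
    # or a uint8 torch tensor, all of identical shape. Minimal sketch only.
    targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
    out = torch.zeros((len(batch), *batch[0][0].shape), dtype=torch.uint8)
    for i, (img, _) in enumerate(batch):
        if isinstance(img, np.ndarray):
            out[i] += torch.from_numpy(img)   # numpy path, as before
        else:
            out[i] += img                     # torch tensor path added by this PR
    return out, targets
```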
71 changes: 52 additions & 19 deletions timm/data/mixup.py
@@ -229,29 +229,41 @@ def _mix_elem_collate(self, output, batch, half=False):
num_elem = batch_size // 2 if half else batch_size
assert len(output) == num_elem
lam_batch, use_cutmix = self._params_per_elem(num_elem)
is_np = isinstance(batch[0][0], np.ndarray)

for i in range(num_elem):
j = batch_size - i - 1
lam = lam_batch[i]
mixed = batch[i][0]
if lam != 1.:
if use_cutmix[i]:
if not half:
mixed = mixed.copy()
mixed = mixed.copy() if is_np else mixed.clone()
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
output.shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam,
)
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
if is_np:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
else:
mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
torch.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
if half:
lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
return torch.tensor(lam_batch).unsqueeze(1)

def _mix_pair_collate(self, output, batch):
batch_size = len(batch)
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
is_np = isinstance(batch[0][0], np.ndarray)

for i in range(batch_size // 2):
j = batch_size - i - 1
lam = lam_batch[i]
@@ -261,39 +273,60 @@ def _mix_pair_collate(self, output, batch):
if lam < 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
patch_i = mixed_i[:, yl:yh, xl:xh].copy()
output.shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam,
)
patch_i = mixed_i[:, yl:yh, xl:xh].copy() if is_np else mixed_i[:, yl:yh, xl:xh].clone()
mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
mixed_j[:, yl:yh, xl:xh] = patch_i
lam_batch[i] = lam
else:
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
mixed_i = mixed_temp
np.rint(mixed_j, out=mixed_j)
np.rint(mixed_i, out=mixed_i)
output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
if is_np:
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
mixed_i = mixed_temp
np.rint(mixed_j, out=mixed_j)
np.rint(mixed_i, out=mixed_i)
else:
mixed_temp = mixed_i.float() * lam + mixed_j.float() * (1 - lam)
mixed_j = mixed_j.float() * lam + mixed_i.float() * (1 - lam)
mixed_i = mixed_temp
torch.round(mixed_j, out=mixed_j)
torch.round(mixed_i, out=mixed_i)
output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) if is_np else mixed_i.byte()
output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) if is_np else mixed_j.byte()
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
return torch.tensor(lam_batch).unsqueeze(1)

def _mix_batch_collate(self, output, batch):
batch_size = len(batch)
lam, use_cutmix = self._params_per_batch()
is_np = isinstance(batch[0][0], np.ndarray)

if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
output.shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam,
)
for i in range(batch_size):
j = batch_size - i - 1
mixed = batch[i][0]
if lam != 1.:
if use_cutmix:
mixed = mixed.copy() # don't want to modify the original while iterating
mixed = mixed.copy() if is_np else mixed.clone() # don't want to modify the original while iterating
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
if is_np:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
else:
mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
torch.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
return lam

def __call__(self, batch, _=None):
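The updated collate methods now blend either numpy arrays or torch tensors. A condensed sketch of the shared pattern (assuming uint8 inputs and both operands of the same type), not the FastCollateMixup code itself:

```python
import numpy as np
import torch

def blend_uint8_sketch(a, b, lam):
    # Mix two uint8 images with weight lam, rounding back to uint8.
    # The numpy path mirrors the original behaviour; the tensor path mirrors the new branch.
    if isinstance(a, np.ndarray):
        mixed = a.astype(np.float32) * lam + b.astype(np.float32) * (1 - lam)
        np.rint(mixed, out=mixed)
        return torch.from_numpy(mixed.astype(np.uint8))
    mixed = a.float() * lam + b.float() * (1 - lam)
    torch.round(mixed, out=mixed)
    return mixed.byte()

x_np = blend_uint8_sketch(np.full((3, 8, 8), 200, np.uint8), np.zeros((3, 8, 8), np.uint8), 0.5)
x_pt = blend_uint8_sketch(torch.full((3, 8, 8), 200, dtype=torch.uint8), torch.zeros(3, 8, 8, dtype=torch.uint8), 0.5)
assert torch.equal(x_np, x_pt)  # both paths produce the same rounded uint8 result
```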