diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 14d484ba9f..f2710fc0d8 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -30,6 +30,7 @@ def __init__( input_img_mode='RGB', transform=None, target_transform=None, + additional_features=None, **kwargs, ): if reader is None or isinstance(reader, str): @@ -38,6 +39,7 @@ def __init__( root=root, split=split, class_map=class_map, + additional_features=additional_features, **kwargs, ) self.reader = reader @@ -45,10 +47,11 @@ def __init__( self.input_img_mode = input_img_mode self.transform = transform self.target_transform = target_transform + self.additional_features = additional_features self._consecutive_errors = 0 def __getitem__(self, index): - img, target = self.reader[index] + img, target, *features = self.reader[index] try: img = img.read() if self.load_bytes else Image.open(img) @@ -71,7 +74,10 @@ def __getitem__(self, index): elif self.target_transform is not None: target = self.target_transform(target) - return img, target + if self.additional_features is None: + return img, target + else: + return img, target, *features def __len__(self): return len(self.reader) diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py index cfe1910f5a..623fe6e5c8 100644 --- a/timm/data/readers/reader_factory.py +++ b/timm/data/readers/reader_factory.py @@ -19,11 +19,14 @@ def create_reader( prefix = name[0] name = name[-1] + # FIXME the additional features are only supported by ReaderHfds for now. + additional_features = kwargs.pop("additional_features", None) + # FIXME improve the selection right now just tfds prefix or fallback path, will need options to # explicitly select other options shortly if prefix == 'hfds': from .reader_hfds import ReaderHfds # defer Hf datasets import - reader = ReaderHfds(name=name, root=root, split=split, **kwargs) + reader = ReaderHfds(name=name, root=root, split=split, additional_features=additional_features, **kwargs) elif prefix == 'hfids': from .reader_hfids import ReaderHfids # defer HF datasets import reader = ReaderHfids(name=name, root=root, split=split, **kwargs) diff --git a/timm/data/readers/reader_hfds.py b/timm/data/readers/reader_hfds.py index 13f8e24488..f7f552039c 100644 --- a/timm/data/readers/reader_hfds.py +++ b/timm/data/readers/reader_hfds.py @@ -37,6 +37,7 @@ def __init__( class_map: dict = None, input_key: str = 'image', target_key: str = 'label', + additional_features: Optional[list[str]] = None, download: bool = False, trust_remote_code: bool = False ): @@ -65,18 +66,33 @@ def __init__( self.split_info = self.dataset.info.splits[split] self.num_samples = self.split_info.num_examples + if additional_features is not None: + if isinstance(additional_features, list): + self.additional_features = additional_features + else: + self.additional_features = [additional_features] + else: + self.additional_features = None + def __getitem__(self, index): item = self.dataset[index] image = item[self.image_key] + if 'bytes' in image and image['bytes']: image = io.BytesIO(image['bytes']) else: assert 'path' in image and image['path'] image = open(image['path'], 'rb') + label = item[self.label_key] if self.remap_class: label = self.class_to_idx[label] - return image, label + + if self.additional_features is not None: + features = [item[feat] for feat in self.additional_features] + return image, label, *features + else: + return image, label def __len__(self): return len(self.dataset)