@@ -37,6 +37,7 @@ def __init__(
3737 class_map : dict = None ,
3838 input_key : str = 'image' ,
3939 target_key : str = 'label' ,
40+ additional_features : Optional [list [str ]] = None ,
4041 download : bool = False ,
4142 trust_remote_code : bool = False
4243 ):
@@ -65,9 +66,18 @@ def __init__(
6566 self .split_info = self .dataset .info .splits [split ]
6667 self .num_samples = self .split_info .num_examples
6768
69+ if isinstance (additional_features , str ):
70+ self .additional_features = [additional_features ]
71+ elif isinstance (additional_features , list ):
72+ self .additional_features = additional_features
73+ else :
74+ self .additional_features = []
75+
# Fetch one record by index and return its image, its label, and the values of
# any configured additional feature columns.
6876 def __getitem__ (self , index ):
6977 item = self .dataset [index ]
7078 image = item [self .image_key ]
# Added in this change: read each extra column named in self.additional_features
# from the same record, keeping the configured order.
79+ features = [item [feat ] for feat in self .additional_features ]
80+
# When the image entry carries a non-empty 'bytes' payload, wrap it in an
# in-memory binary stream.
7181 if 'bytes' in image and image ['bytes' ]:
7282 image = io .BytesIO (image ['bytes' ])
7383 else :
# NOTE(review): the body of this fallback branch is elided by the hunk boundary
# below and is not visible in this view.
@@ -76,7 +86,8 @@ def __getitem__(self, index):
7686 label = item [self .label_key ]
# The raw label is translated through class_to_idx only when remapping is enabled.
7787 if self .remap_class :
7888 label = self .class_to_idx [label ]
# NOTE(review): the return value changes from a fixed (image, label) pair to a
# tuple of length 2 + len(self.additional_features). Callers that unpack exactly
# two values should be checked for the case where extra features are configured.
# The unparenthesized starred return also requires Python 3.8 or newer.
79- return image , label
89+
90+ return image , label , * features
8091
# Length of the reader, delegated to len() of the wrapped self.dataset object.
8192 def __len__ (self ):
8293 return len (self .dataset )
0 commit comments