diff --git a/tinyimagenetloader.py b/tinyimagenetloader.py index d5362f2..af5e59b 100644 --- a/tinyimagenetloader.py +++ b/tinyimagenetloader.py @@ -39,7 +39,13 @@ def __getitem__(self, idx): image = read_image(img_path) if image.shape[0] == 1: image = read_image(img_path,ImageReadMode.RGB) - label = self.id_dict[img_path.split('/')[4]] + # label = self.id_dict[img_path.split('/')[4]] + + norm_path = os.path.normpath(img_path) + path_parts = norm_path.split(os.path.sep) + class_id = path_parts[-3] + label = self.id_dict[class_id] + if self.transform: image = self.transform(image.type(torch.FloatTensor)) return image, label @@ -64,7 +70,10 @@ def __getitem__(self, idx): image = read_image(img_path) if image.shape[0] == 1: image = read_image(img_path,ImageReadMode.RGB) - label = self.cls_dic[img_path.split('/')[-1]] + # label = self.cls_dic[img_path.split('/')[-1]] + + filename = os.path.basename(img_path) + label = self.cls_dic[filename] if self.transform: image = self.transform(image.type(torch.FloatTensor)) return image, label