Skip to content

Commit c38ba86

Browse files
authored
Fix dataloader tutorial (#1542)
1 parent fd97ed4 commit c38ba86

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

beginner_source/basics/data_tutorial.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ def __getitem__(self, idx):
140140
image = self.transform(image)
141141
if self.target_transform:
142142
label = self.target_transform(label)
143-
sample = {"image": image, "label": label}
144-
return sample
143+
return image, label
145144

146145

147146
#################################################################
@@ -187,7 +186,7 @@ def __len__(self):
187186
# The __getitem__ function loads and returns a sample from the dataset at the given index ``idx``.
188187
# Based on the index, it identifies the image's location on disk, converts that to a tensor using ``read_image``, retrieves the
189188
# corresponding label from the csv data in ``self.img_labels``, calls the transform functions on them (if applicable), and returns the
190-
# tensor image and corresponding label in a Python dict.
189+
# tensor image and corresponding label in a tuple.
191190

192191
def __getitem__(self, idx):
193192
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
@@ -197,8 +196,7 @@ def __getitem__(self, idx):
197196
image = self.transform(image)
198197
if self.target_transform:
199198
label = self.target_transform(label)
200-
sample = {"image": image, "label": label}
201-
return sample
199+
return image, label
202200

203201

204202
######################################################################

0 commit comments

Comments
 (0)