Skip to content

Commit 05dcf50

Browse files
pmeierprabhat00155
andauthored
use helper function to extract archive in CelebA (#4557)
Co-authored-by: Prabhat Roy <[email protected]>
1 parent 32df801 commit 05dcf50

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

torchvision/datasets/celeba.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import PIL
77
import torch
88

9-
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
9+
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
1010
from .vision import VisionDataset
1111

1212
CSV = namedtuple("CSV", ["header", "index", "data"])
@@ -142,17 +142,14 @@ def _check_integrity(self) -> bool:
142142
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
143143

144144
def download(self) -> None:
145-
import zipfile
146-
147145
if self._check_integrity():
148146
print("Files already downloaded and verified")
149147
return
150148

151149
for (file_id, md5, filename) in self.file_list:
152150
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
153151

154-
with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
155-
f.extractall(os.path.join(self.root, self.base_folder))
152+
extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
156153

157154
def __getitem__(self, index: int) -> Tuple[Any, Any]:
158155
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

0 commit comments

Comments
 (0)