|
6 | 6 | import PIL |
7 | 7 | import torch |
8 | 8 |
|
9 | | -from .utils import download_file_from_google_drive, check_integrity, verify_str_arg |
| 9 | +from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive |
10 | 10 | from .vision import VisionDataset |
11 | 11 |
|
12 | 12 | CSV = namedtuple("CSV", ["header", "index", "data"]) |
@@ -142,17 +142,14 @@ def _check_integrity(self) -> bool: |
142 | 142 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) |
143 | 143 |
|
144 | 144 | def download(self) -> None: |
145 | | - import zipfile |
146 | | - |
147 | 145 | if self._check_integrity(): |
148 | 146 | print("Files already downloaded and verified") |
149 | 147 | return |
150 | 148 |
|
151 | 149 | for (file_id, md5, filename) in self.file_list: |
152 | 150 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) |
153 | 151 |
|
154 | | - with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: |
155 | | - f.extractall(os.path.join(self.root, self.base_folder)) |
| 152 | + extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip")) |
156 | 153 |
|
157 | 154 | def __getitem__(self, index: int) -> Tuple[Any, Any]: |
158 | 155 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) |
|
0 commit comments