|
1 | 1 | import os |
2 | | -import sys |
3 | | -import tempfile |
4 | 2 | import torchvision.datasets.utils as utils |
5 | 3 | import unittest |
6 | 4 | import unittest.mock |
@@ -102,62 +100,95 @@ def test_download_url_dispatch_download_from_google_drive(self, mock): |
102 | 100 |
|
103 | 101 | mock.assert_called_once_with(id, root, filename, md5) |
104 | 102 |
|
105 | | - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') |
106 | 103 | def test_extract_zip(self): |
| 104 | + def create_archive(root, content="this is the content"): |
| 105 | + file = os.path.join(root, "dst.txt") |
| 106 | + archive = os.path.join(root, "archive.zip") |
| 107 | + |
| 108 | + with zipfile.ZipFile(archive, "w") as zf: |
| 109 | + zf.writestr(os.path.basename(file), content) |
| 110 | + |
| 111 | + return archive, file, content |
| 112 | + |
107 | 113 | with get_tmp_dir() as temp_dir: |
108 | | - with tempfile.NamedTemporaryFile(suffix='.zip') as f: |
109 | | - with zipfile.ZipFile(f, 'w') as zf: |
110 | | - zf.writestr('file.tst', 'this is the content') |
111 | | - utils.extract_archive(f.name, temp_dir) |
112 | | - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) |
113 | | - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: |
114 | | - data = nf.read() |
115 | | - self.assertEqual(data, 'this is the content') |
116 | | - |
117 | | - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') |
| 114 | + archive, file, content = create_archive(temp_dir) |
| 115 | + |
| 116 | + utils.extract_archive(archive, temp_dir) |
| 117 | + |
| 118 | + self.assertTrue(os.path.exists(file)) |
| 119 | + |
| 120 | + with open(file, "r") as fh: |
| 121 | + self.assertEqual(fh.read(), content) |
| 122 | + |
118 | 123 | def test_extract_tar(self): |
| 124 | + def create_archive(root, ext, mode, content="this is the content"): |
| 125 | + src = os.path.join(root, "src.txt") |
| 126 | + dst = os.path.join(root, "dst.txt") |
| 127 | + archive = os.path.join(root, f"archive{ext}") |
| 128 | + |
| 129 | + with open(src, "w") as fh: |
| 130 | + fh.write(content) |
| 131 | + |
| 132 | + with tarfile.open(archive, mode=mode) as fh: |
| 133 | + fh.add(src, arcname=os.path.basename(dst)) |
| 134 | + |
| 135 | + return archive, dst, content |
| 136 | + |
119 | 137 | for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']): |
120 | 138 | with get_tmp_dir() as temp_dir: |
121 | | - with tempfile.NamedTemporaryFile() as bf: |
122 | | - bf.write("this is the content".encode()) |
123 | | - bf.seek(0) |
124 | | - with tempfile.NamedTemporaryFile(suffix=ext) as f: |
125 | | - with tarfile.open(f.name, mode=mode) as zf: |
126 | | - zf.add(bf.name, arcname='file.tst') |
127 | | - utils.extract_archive(f.name, temp_dir) |
128 | | - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) |
129 | | - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: |
130 | | - data = nf.read() |
131 | | - self.assertEqual(data, 'this is the content') |
132 | | - |
133 | | - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') |
| 139 | + archive, file, content = create_archive(temp_dir, ext, mode) |
| 140 | + |
| 141 | + utils.extract_archive(archive, temp_dir) |
| 142 | + |
| 143 | + self.assertTrue(os.path.exists(file)) |
| 144 | + |
| 145 | + with open(file, "r") as fh: |
| 146 | + self.assertEqual(fh.read(), content) |
| 147 | + |
134 | 148 | def test_extract_tar_xz(self): |
| 149 | + def create_archive(root, ext, mode, content="this is the content"): |
| 150 | + src = os.path.join(root, "src.txt") |
| 151 | + dst = os.path.join(root, "dst.txt") |
| 152 | + archive = os.path.join(root, f"archive{ext}") |
| 153 | + |
| 154 | + with open(src, "w") as fh: |
| 155 | + fh.write(content) |
| 156 | + |
| 157 | + with tarfile.open(archive, mode=mode) as fh: |
| 158 | + fh.add(src, arcname=os.path.basename(dst)) |
| 159 | + |
| 160 | + return archive, dst, content |
| 161 | + |
135 | 162 | for ext, mode in zip(['.tar.xz'], ['w:xz']): |
136 | 163 | with get_tmp_dir() as temp_dir: |
137 | | - with tempfile.NamedTemporaryFile() as bf: |
138 | | - bf.write("this is the content".encode()) |
139 | | - bf.seek(0) |
140 | | - with tempfile.NamedTemporaryFile(suffix=ext) as f: |
141 | | - with tarfile.open(f.name, mode=mode) as zf: |
142 | | - zf.add(bf.name, arcname='file.tst') |
143 | | - utils.extract_archive(f.name, temp_dir) |
144 | | - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) |
145 | | - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: |
146 | | - data = nf.read() |
147 | | - self.assertEqual(data, 'this is the content') |
148 | | - |
149 | | - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') |
| 164 | + archive, file, content = create_archive(temp_dir, ext, mode) |
| 165 | + |
| 166 | + utils.extract_archive(archive, temp_dir) |
| 167 | + |
| 168 | + self.assertTrue(os.path.exists(file)) |
| 169 | + |
| 170 | + with open(file, "r") as fh: |
| 171 | + self.assertEqual(fh.read(), content) |
| 172 | + |
150 | 173 | def test_extract_gzip(self): |
| 174 | + def create_compressed(root, content="this is the content"): |
| 175 | + file = os.path.join(root, "file") |
| 176 | + compressed = f"{file}.gz" |
| 177 | + |
| 178 | + with gzip.GzipFile(compressed, "wb") as fh: |
| 179 | + fh.write(content.encode()) |
| 180 | + |
| 181 | + return compressed, file, content |
| 182 | + |
151 | 183 | with get_tmp_dir() as temp_dir: |
152 | | - with tempfile.NamedTemporaryFile(suffix='.gz') as f: |
153 | | - with gzip.GzipFile(f.name, 'wb') as zf: |
154 | | - zf.write('this is the content'.encode()) |
155 | | - utils.extract_archive(f.name, temp_dir) |
156 | | - f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) |
157 | | - self.assertTrue(os.path.exists(f_name)) |
158 | | - with open(os.path.join(f_name), 'r') as nf: |
159 | | - data = nf.read() |
160 | | - self.assertEqual(data, 'this is the content') |
| 184 | + compressed, file, content = create_compressed(temp_dir) |
| 185 | + |
| 186 | + utils.extract_archive(compressed, temp_dir) |
| 187 | + |
| 188 | + self.assertTrue(os.path.exists(file)) |
| 189 | + |
| 190 | + with open(file, "r") as fh: |
| 191 | + self.assertEqual(fh.read(), content) |
161 | 192 |
|
162 | 193 | def test_verify_str_arg(self): |
163 | 194 | self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",))) |
|
0 commit comments