Skip to content

Commit 61ce622

Browse files
vincentqbfmassa
authored andcommitted
[fbsync] fix test_extract_(zip|tar|tar_xz|gzip) on windows (#3542)
Summary: * fix test_extract_(zip|tar|tar_xz|gzip) on windows * lint Reviewed By: fmassa Differential Revision: D27127988 fbshipit-source-id: 62394146aef72ca5baf86ae86d52cf82f77c07aa Co-authored-by: Francisco Massa <[email protected]>
1 parent 83b8905 commit 61ce622

File tree

1 file changed

+79
-48
lines changed

1 file changed

+79
-48
lines changed

test/test_datasets_utils.py

Lines changed: 79 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import os
2-
import sys
3-
import tempfile
42
import torchvision.datasets.utils as utils
53
import unittest
64
import unittest.mock
@@ -102,62 +100,95 @@ def test_download_url_dispatch_download_from_google_drive(self, mock):
102100

103101
mock.assert_called_once_with(id, root, filename, md5)
104102

105-
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
106103
def test_extract_zip(self):
104+
def create_archive(root, content="this is the content"):
105+
file = os.path.join(root, "dst.txt")
106+
archive = os.path.join(root, "archive.zip")
107+
108+
with zipfile.ZipFile(archive, "w") as zf:
109+
zf.writestr(os.path.basename(file), content)
110+
111+
return archive, file, content
112+
107113
with get_tmp_dir() as temp_dir:
108-
with tempfile.NamedTemporaryFile(suffix='.zip') as f:
109-
with zipfile.ZipFile(f, 'w') as zf:
110-
zf.writestr('file.tst', 'this is the content')
111-
utils.extract_archive(f.name, temp_dir)
112-
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
113-
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
114-
data = nf.read()
115-
self.assertEqual(data, 'this is the content')
116-
117-
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
114+
archive, file, content = create_archive(temp_dir)
115+
116+
utils.extract_archive(archive, temp_dir)
117+
118+
self.assertTrue(os.path.exists(file))
119+
120+
with open(file, "r") as fh:
121+
self.assertEqual(fh.read(), content)
122+
118123
def test_extract_tar(self):
124+
def create_archive(root, ext, mode, content="this is the content"):
125+
src = os.path.join(root, "src.txt")
126+
dst = os.path.join(root, "dst.txt")
127+
archive = os.path.join(root, f"archive{ext}")
128+
129+
with open(src, "w") as fh:
130+
fh.write(content)
131+
132+
with tarfile.open(archive, mode=mode) as fh:
133+
fh.add(src, arcname=os.path.basename(dst))
134+
135+
return archive, dst, content
136+
119137
for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']):
120138
with get_tmp_dir() as temp_dir:
121-
with tempfile.NamedTemporaryFile() as bf:
122-
bf.write("this is the content".encode())
123-
bf.seek(0)
124-
with tempfile.NamedTemporaryFile(suffix=ext) as f:
125-
with tarfile.open(f.name, mode=mode) as zf:
126-
zf.add(bf.name, arcname='file.tst')
127-
utils.extract_archive(f.name, temp_dir)
128-
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
129-
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
130-
data = nf.read()
131-
self.assertEqual(data, 'this is the content')
132-
133-
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
139+
archive, file, content = create_archive(temp_dir, ext, mode)
140+
141+
utils.extract_archive(archive, temp_dir)
142+
143+
self.assertTrue(os.path.exists(file))
144+
145+
with open(file, "r") as fh:
146+
self.assertEqual(fh.read(), content)
147+
134148
def test_extract_tar_xz(self):
149+
def create_archive(root, ext, mode, content="this is the content"):
150+
src = os.path.join(root, "src.txt")
151+
dst = os.path.join(root, "dst.txt")
152+
archive = os.path.join(root, f"archive{ext}")
153+
154+
with open(src, "w") as fh:
155+
fh.write(content)
156+
157+
with tarfile.open(archive, mode=mode) as fh:
158+
fh.add(src, arcname=os.path.basename(dst))
159+
160+
return archive, dst, content
161+
135162
for ext, mode in zip(['.tar.xz'], ['w:xz']):
136163
with get_tmp_dir() as temp_dir:
137-
with tempfile.NamedTemporaryFile() as bf:
138-
bf.write("this is the content".encode())
139-
bf.seek(0)
140-
with tempfile.NamedTemporaryFile(suffix=ext) as f:
141-
with tarfile.open(f.name, mode=mode) as zf:
142-
zf.add(bf.name, arcname='file.tst')
143-
utils.extract_archive(f.name, temp_dir)
144-
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
145-
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
146-
data = nf.read()
147-
self.assertEqual(data, 'this is the content')
148-
149-
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
164+
archive, file, content = create_archive(temp_dir, ext, mode)
165+
166+
utils.extract_archive(archive, temp_dir)
167+
168+
self.assertTrue(os.path.exists(file))
169+
170+
with open(file, "r") as fh:
171+
self.assertEqual(fh.read(), content)
172+
150173
def test_extract_gzip(self):
174+
def create_compressed(root, content="this is the content"):
175+
file = os.path.join(root, "file")
176+
compressed = f"{file}.gz"
177+
178+
with gzip.GzipFile(compressed, "wb") as fh:
179+
fh.write(content.encode())
180+
181+
return compressed, file, content
182+
151183
with get_tmp_dir() as temp_dir:
152-
with tempfile.NamedTemporaryFile(suffix='.gz') as f:
153-
with gzip.GzipFile(f.name, 'wb') as zf:
154-
zf.write('this is the content'.encode())
155-
utils.extract_archive(f.name, temp_dir)
156-
f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
157-
self.assertTrue(os.path.exists(f_name))
158-
with open(os.path.join(f_name), 'r') as nf:
159-
data = nf.read()
160-
self.assertEqual(data, 'this is the content')
184+
compressed, file, content = create_compressed(temp_dir)
185+
186+
utils.extract_archive(compressed, temp_dir)
187+
188+
self.assertTrue(os.path.exists(file))
189+
190+
with open(file, "r") as fh:
191+
self.assertEqual(fh.read(), content)
161192

162193
def test_verify_str_arg(self):
163194
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))

0 commit comments

Comments
 (0)