Skip to content

Commit 90ccc57

Browse files
committed
Make walk_files traverse in alphabetically breadth-first order.
1 parent 3cdcd7b commit 90ccc57

File tree

4 files changed

+53
-8
lines changed

4 files changed

+53
-8
lines changed

test/datasets/libritts_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,13 @@ def setUpClass(cls):
4949

5050
def test_libritts(self):
5151
dataset = LIBRITTS(self.root_dir)
52-
samples = list(dataset)
53-
samples.sort(key=lambda s: s[4])
54-
5552
for i, (waveform,
5653
sample_rate,
5754
original_text,
5855
normalized_text,
5956
speaker_id,
6057
chapter_id,
61-
utterance_id) in enumerate(samples):
58+
utterance_id) in enumerate(dataset):
6259

6360
expected_ids = self.utterance_ids[i]
6461
expected_data = self.data[i]

test/datasets/utils_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
from pathlib import Path
3+
4+
from torchaudio.datasets import utils as dataset_utils
5+
6+
from ..common_utils import (
7+
TempDirMixin,
8+
TorchaudioTestCase,
9+
)
10+
11+
12+
class TestWalkFiles(TempDirMixin, TorchaudioTestCase):
13+
root = None
14+
expected = None
15+
16+
def _add_file(self, *parts):
17+
path = self.get_temp_path(*parts)
18+
self.expected.append(path)
19+
Path(path).touch()
20+
21+
def setUp(self):
22+
self.root = self.get_temp_path()
23+
self.expected = []
24+
25+
# level 1
26+
for filename in ['a.txt', 'b.txt', 'c.txt']:
27+
self._add_file(filename)
28+
29+
# level 2
30+
for dir1 in ['d1', 'd2', 'd3']:
31+
for filename in ['d.txt', 'e.txt', 'f.txt']:
32+
self._add_file(dir1, filename)
33+
# level 3
34+
for dir2 in ['d1', 'd2', 'd3']:
35+
for filename in ['g.txt', 'h.txt', 'i.txt']:
36+
self._add_file(dir1, dir2, filename)
37+
38+
print('\n'.join(self.expected))
39+
40+
def test_walk_files(self):
41+
"""walk_files should traverse files in alphabetical order"""
42+
for i, path in enumerate(dataset_utils.walk_files(self.root, '.txt', prefix=True)):
43+
found = os.path.join(self.root, path)
44+
assert found == self.expected[i]

test/datasets/yesno_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ def setUpClass(cls):
3838

3939
def test_yesno(self):
4040
dataset = yesno.YESNO(self.root_dir)
41-
samples = list(dataset)
42-
samples.sort(key=lambda s: s[2])
43-
for i, (waveform, sample_rate, label) in enumerate(samples):
41+
for i, (waveform, sample_rate, label) in enumerate(dataset):
4442
expected_label = self.labels[i]
4543
expected_data = self.data[i]
4644
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)

torchaudio/datasets/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,13 @@ def walk_files(root: str,
264264

265265
root = os.path.expanduser(root)
266266

267-
for dirpath, _, files in os.walk(root):
267+
for dirpath, dirs, files in os.walk(root):
268+
dirs.sort()
269+
# `dirs` is the list used in os.walk function and by sorting it in-place here, we change the
270+
# behavior of os.walk to traverse sub directory alphabetically
271+
# see also
272+
# https://stackoverflow.com/questions/6670029/can-i-force-python3s-os-walk-to-visit-directories-in-alphabetical-order-how#comment71993866_6670926
273+
files.sort()
268274
for f in files:
269275
if f.endswith(suffix):
270276

0 commit comments

Comments
 (0)