|
20 | 20 | "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": |
21 | 21 | "6b74f3901214cb2c2934e98196829835", |
22 | 22 | } |
| 23 | +VALIDATION_LIST = "validation_list.txt" |
| 24 | +TESTING_LIST = "testing_list.txt" |
23 | 25 |
|
24 | 26 |
|
25 | 27 | def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]: |
@@ -90,29 +92,22 @@ def __init__(self, |
90 | 92 | download_url(url, root, hash_value=checksum, hash_type="md5") |
91 | 93 | extract_archive(archive, self._path) |
92 | 94 |
|
93 | | - walker = walk_files(self._path, suffix=".wav", prefix=True) |
94 | | - walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker) |
95 | | - |
96 | | - if subset in ["training", "validation"]: |
97 | | - filepath = os.path.join(self._path, "validation_list.txt") |
98 | | - with open(filepath) as f: |
99 | | - validation_list = [os.path.join(self._path, l.strip()) for l in f.readlines()] |
100 | | - |
101 | | - if subset in ["training", "testing"]: |
102 | | - filepath = os.path.join(self._path, "testing_list.txt") |
103 | | - with open(filepath) as f: |
104 | | - testing_list = [os.path.join(self._path, l.strip()) for l in f.readlines()] |
| 95 | + def load_list(filename): |
| 96 | + filepath = os.path.join(self._path, filename) |
| 97 | + with open(filepath) as fileobj: |
| 98 | + return [os.path.join(self._path, line.strip()) for line in fileobj] |
105 | 99 |
|
106 | 100 | if subset == "validation": |
107 | | - walker = validation_list |
| 101 | + self._walker = load_list(VALIDATION_LIST) |
108 | 102 | elif subset == "testing": |
109 | | - walker = testing_list |
| 103 | + self._walker = load_list(TESTING_LIST) |
110 | 104 | elif subset == "training": |
111 | | - walker = filter( |
112 | | - lambda w: not (w in validation_list or w in testing_list), walker |
113 | | - ) |
114 | | - |
115 | | - self._walker = list(walker) |
| 105 | + excludes = load_list(VALIDATION_LIST) + load_list(TESTING_LIST) |
| 106 | + walker = walk_files(self._path, suffix=".wav", prefix=True) |
| 107 | + self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w and w not in excludes] |
| 108 | + else: |
| 109 | + walker = walk_files(self._path, suffix=".wav", prefix=True) |
| 110 | + self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w] |
116 | 111 |
|
117 | 112 | def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]: |
118 | 113 | """Load the n-th sample from the dataset. |
|
0 commit comments