Skip to content

Commit 6db2ad1

Browse files
authored
Background generator (#323)
* BackgroundGenerator * renaming disk cache.
1 parent 99c5260 commit 6db2ad1

File tree

4 files changed

+64
-5
lines changed

4 files changed

+64
-5
lines changed

test/assets/commonvoice/train.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
client_id path sentence up_votes down_votes age gender accent
22
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female
3+
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001 common_voice_tt_00000000.mp3 test. 1 0 thirties female

test/test_datasets.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from torchaudio.datasets.commonvoice import COMMONVOICE
55
from torchaudio.datasets.librispeech import LIBRISPEECH
6-
from torchaudio.datasets.utils import DiskCache
6+
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
77
from torchaudio.datasets.vctk import VCTK
88
from torchaudio.datasets.yesno import YESNO
99

@@ -34,12 +34,19 @@ def test_commonvoice(self):
3434
def test_commonvoice_diskcache(self):
3535
path = os.path.join(self.path, "commonvoice")
3636
data = COMMONVOICE(path, "train.tsv", "tatar")
37-
data = DiskCache(data)
37+
data = diskcache_iterator(data)
3838
# Save
3939
data[0]
4040
# Load
4141
data[0]
4242

43+
def test_commonvoice_bg(self):
44+
path = os.path.join(self.path, "commonvoice")
45+
data = COMMONVOICE(path, "train.tsv", "tatar")
46+
data = bg_iterator(data, 5)
47+
for d in data:
48+
pass
49+
4350

4451
if __name__ == "__main__":
4552
unittest.main()

torchaudio/datasets/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from .commonvoice import COMMONVOICE
22
from .librispeech import LIBRISPEECH
3+
from .utils import bg_iterator, diskcache_iterator
34
from .vctk import VCTK
45
from .yesno import YESNO
5-
from .utils import DiskCache
66

7-
__all__ = ("COMMONVOICE", "LIBRISPEECH", "VCTK", "YESNO", "DiskCache")
7+
__all__ = (
8+
"COMMONVOICE",
9+
"LIBRISPEECH",
10+
"VCTK",
11+
"YESNO",
12+
"diskcache_iterator",
13+
"bg_iterator",
14+
)

torchaudio/datasets/utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import os
77
import sys
88
import tarfile
9+
import threading
910
import zipfile
11+
from queue import Queue
1012

1113
import six
1214
import torch
@@ -192,7 +194,7 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False):
192194
yield f
193195

194196

195-
class DiskCache(Dataset):
197+
class _DiskCache(Dataset):
196198
"""
197199
Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
198200
"""
@@ -221,3 +223,45 @@ def __getitem__(self, n):
221223

222224
def __len__(self):
223225
return len(self.dataset)
226+
227+
228+
def diskcache_iterator(dataset, location=".cached"):
229+
return _DiskCache(dataset, location)
230+
231+
232+
class _ThreadedIterator(threading.Thread):
233+
"""
234+
Prefetch the next queue_length items from iterator in a background thread.
235+
236+
Example:
237+
>> for i in bg_iterator(range(10)):
238+
>> print(i)
239+
"""
240+
241+
class _End:
242+
pass
243+
244+
def __init__(self, generator, maxsize):
245+
threading.Thread.__init__(self)
246+
self.queue = Queue(maxsize)
247+
self.generator = generator
248+
self.daemon = True
249+
self.start()
250+
251+
def run(self):
252+
for item in self.generator:
253+
self.queue.put(item)
254+
self.queue.put(self._End)
255+
256+
def __iter__(self):
257+
return self
258+
259+
def __next__(self):
260+
next_item = self.queue.get()
261+
if next_item == self._End:
262+
raise StopIteration
263+
return next_item
264+
265+
266+
def bg_iterator(iterable, maxsize):
267+
return _ThreadedIterator(iterable, maxsize=maxsize)

0 commit comments

Comments
 (0)