diff --git a/test/assets/commonvoice/train.tsv b/test/assets/commonvoice/train.tsv index e1f4623200..2b677dbf7f 100644 --- a/test/assets/commonvoice/train.tsv +++ b/test/assets/commonvoice/train.tsv @@ -1,2 +1,3 @@ client_id path sentence up_votes down_votes age gender accent 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female +00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001 common_voice_tt_00000000.mp3 test. 1 0 thirties female diff --git a/test/test_datasets.py b/test/test_datasets.py index 8f35a27b56..e27dde3051 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -3,7 +3,7 @@ from torchaudio.datasets.commonvoice import COMMONVOICE from torchaudio.datasets.librispeech import LIBRISPEECH -from torchaudio.datasets.utils import DiskCache +from torchaudio.datasets.utils import diskcache_iterator, bg_iterator from torchaudio.datasets.vctk import VCTK from torchaudio.datasets.yesno import YESNO @@ -34,12 +34,19 @@ def test_commonvoice(self): def test_commonvoice_diskcache(self): path = os.path.join(self.path, "commonvoice") data = COMMONVOICE(path, "train.tsv", "tatar") - data = DiskCache(data) + data = diskcache_iterator(data) # Save data[0] # Load data[0] + def test_commonvoice_bg(self): + path = os.path.join(self.path, "commonvoice") + data = COMMONVOICE(path, "train.tsv", "tatar") + data = bg_iterator(data, 5) + for d in data: + pass + if __name__ == "__main__": unittest.main() diff --git a/torchaudio/datasets/__init__.py b/torchaudio/datasets/__init__.py index ef395a0ca4..67f6a77d2a 100644 --- a/torchaudio/datasets/__init__.py +++ b/torchaudio/datasets/__init__.py @@ -1,7 +1,14 @@ from .commonvoice import COMMONVOICE from .librispeech import LIBRISPEECH +from .utils import bg_iterator, diskcache_iterator from .vctk import VCTK from .yesno import YESNO -from .utils import DiskCache -__all__ = ("COMMONVOICE", "LIBRISPEECH", "VCTK", "YESNO", "DiskCache") +__all__ = ( + "COMMONVOICE", + "LIBRISPEECH", + "VCTK", + "YESNO", + "diskcache_iterator", + "bg_iterator", +) diff --git a/torchaudio/datasets/utils.py b/torchaudio/datasets/utils.py index 25a30293aa..f3daebcd12 100644 --- a/torchaudio/datasets/utils.py +++ b/torchaudio/datasets/utils.py @@ -6,7 +6,9 @@ import os import sys import tarfile +import threading import zipfile +from queue import Queue import six import torch @@ -192,7 +194,7 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False): yield f -class DiskCache(Dataset): +class _DiskCache(Dataset): """ Wrap a dataset so that, whenever a new item is returned, it is saved to disk. """ @@ -221,3 +223,45 @@ def __getitem__(self, n): def __len__(self): return len(self.dataset) + + +def diskcache_iterator(dataset, location=".cached"): + return _DiskCache(dataset, location) + + +class _ThreadedIterator(threading.Thread): + """ + Prefetch the next queue_length items from iterator in a background thread. + + Example: + >> for i in bg_iterator(range(10)): + >> print(i) + """ + + class _End: + pass + + def __init__(self, generator, maxsize): + threading.Thread.__init__(self) + self.queue = Queue(maxsize) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(self._End) + + def __iter__(self): + return self + + def __next__(self): + next_item = self.queue.get() + if next_item == self._End: + raise StopIteration + return next_item + + +def bg_iterator(iterable, maxsize): + return _ThreadedIterator(iterable, maxsize=maxsize)