Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/assets/commonvoice/train.tsv
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
client_id path sentence up_votes down_votes age gender accent
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 common_voice_tt_00000000.mp3 test. 1 0 thirties female
00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001 common_voice_tt_00000000.mp3 test. 1 0 thirties female
11 changes: 9 additions & 2 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.utils import DiskCache
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO

Expand Down Expand Up @@ -34,12 +34,19 @@ def test_commonvoice(self):
def test_commonvoice_diskcache(self):
    """Read the same Common Voice item twice through ``diskcache_iterator``.

    Per the inline markers below, the first access is the "save" pass and
    the second is the "load" pass.
    """
    path = os.path.join(self.path, "commonvoice")
    data = COMMONVOICE(path, "train.tsv", "tatar")
    # Wrap exactly once, through the public factory. The former ``DiskCache``
    # class is now the private ``_DiskCache`` in ``torchaudio.datasets.utils``,
    # so the test must not reference it directly.
    data = diskcache_iterator(data)
    # Save
    data[0]
    # Load
    data[0]

def test_commonvoice_bg(self):
    """Drain the Common Voice fixture through ``bg_iterator``."""
    root = os.path.join(self.path, "commonvoice")
    dataset = COMMONVOICE(root, "train.tsv", "tatar")
    prefetched = bg_iterator(dataset, 5)
    # Nothing is asserted on the items; the test only requires that the
    # iteration runs to completion.
    for _ in prefetched:
        pass


if __name__ == "__main__":
unittest.main()
11 changes: 9 additions & 2 deletions torchaudio/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .utils import bg_iterator, diskcache_iterator
from .vctk import VCTK
from .yesno import YESNO
from .utils import DiskCache

# Public names exported by ``torchaudio.datasets``. A single definition is
# kept: the earlier one-line tuple was immediately overwritten by this one and
# still listed ``DiskCache``, which is no longer part of the public API.
__all__ = (
    "COMMONVOICE",
    "LIBRISPEECH",
    "VCTK",
    "YESNO",
    "diskcache_iterator",
    "bg_iterator",
)
46 changes: 45 additions & 1 deletion torchaudio/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import os
import sys
import tarfile
import threading
import zipfile
from queue import Queue

import six
import torch
Expand Down Expand Up @@ -192,7 +194,7 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False):
yield f


class DiskCache(Dataset):
class _DiskCache(Dataset):
"""
Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
"""
Expand Down Expand Up @@ -221,3 +223,45 @@ def __getitem__(self, n):

def __len__(self):
return len(self.dataset)


def diskcache_iterator(dataset, location=".cached"):
    """Return ``dataset`` wrapped in a ``_DiskCache``.

    Args:
        dataset: The dataset to wrap.
        location: Forwarded unchanged to ``_DiskCache`` (defaults to
            ``".cached"``).

    Returns:
        The ``_DiskCache`` instance wrapping ``dataset``.
    """
    cached_dataset = _DiskCache(dataset, location)
    return cached_dataset


class _ThreadedIterator(threading.Thread):
    """
    Prefetch the next queue_length items from iterator in a background thread.

    The producer thread pulls items from ``generator`` and pushes them onto a
    bounded queue; the consumer pops them through the iterator protocol.

    Example:
        >>> for i in bg_iterator(range(10), maxsize=2):
        >>>     print(i)
    """

    class _End:
        """Sentinel placed on the queue once the producer has finished."""

    def __init__(self, generator, maxsize):
        """Start prefetching ``generator`` into a queue of size ``maxsize``.

        Args:
            generator: Any iterable; it is consumed in the background thread.
            maxsize: Passed to ``queue.Queue``; bounds how many items are
                buffered ahead of the consumer.
        """
        threading.Thread.__init__(self)
        self.queue = Queue(maxsize)
        self.generator = generator
        # Both flags must exist before the thread starts, since ``run`` and
        # ``__next__`` read/write them.
        self._failure = None
        self._finished = False
        self.daemon = True
        self.start()

    def run(self):
        """Producer loop executed in the background thread."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as error:
            # Hand the failure to the consumer instead of dying silently.
            self._failure = error
        finally:
            # Always enqueue the sentinel; otherwise a producer error would
            # leave the consumer blocked on ``queue.get()`` forever.
            self.queue.put(self._End)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next prefetched item.

        Raises:
            StopIteration: Once the underlying iterable is exhausted (and on
                every later call).
            Exception: Whatever the underlying iterable raised while being
                consumed in the background thread.
        """
        if self._finished:
            # The sentinel was already consumed; do not block on an empty queue.
            raise StopIteration
        next_item = self.queue.get()
        # Identity check: ``==`` would dispatch to the item's own ``__eq__``,
        # which may be elementwise (arrays/tensors) or arbitrarily overridden.
        if next_item is self._End:
            self._finished = True
            if self._failure is not None:
                failure, self._failure = self._failure, None
                raise failure
            raise StopIteration
        return next_item


def bg_iterator(iterable, maxsize):
    """Iterate over ``iterable`` while prefetching items in a background thread.

    Args:
        iterable: The iterable to consume in the background.
        maxsize: Maximum number of items buffered ahead of the consumer.

    Returns:
        An iterator yielding the items of ``iterable`` in order.
    """
    return _ThreadedIterator(iterable, maxsize=maxsize)