Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import os
import unittest

from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.utils import DiskCache
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO
from torchaudio.datasets import COMMONVOICE, LIBRISPEECH, VCTK, YESNO, DiskCache


class TestDatasets(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from .librispeech import LIBRISPEECH
from .vctk import VCTK
from .yesno import YESNO
from .utils import DiskCache
from .cache import DiskCache

# Public names of the torchaudio.datasets package: this tuple controls what
# `from torchaudio.datasets import *` binds.
__all__ = ("COMMONVOICE", "LIBRISPEECH", "VCTK", "YESNO", "DiskCache")
36 changes: 36 additions & 0 deletions torchaudio/datasets/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os

import torch
from torch.utils.data import Dataset
from torchaudio.datasets.utils import makedir_exist_ok


class DiskCache(Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to disk.

    The first access to index ``n`` fetches the item from the wrapped dataset,
    serializes it with ``torch.save`` into ``location`` and returns it. Later
    accesses to the same index load the item back with ``torch.load`` instead
    of querying the wrapped dataset again.

    The index of cached files is kept in memory only and file names are
    prefixed with ``id(self)``, so files written by one instance are not
    picked up by another instance.

    Args:
        dataset: Wrapped dataset; must support ``len()`` and integer indexing.
        location: Directory in which cached items are written. It is created
            on the first cache miss.
    """

    def __init__(self, dataset, location=".cached"):
        self.dataset = dataset
        self.location = location

        # Prefix for every cache file name written by this instance.
        self._id = id(self)
        # Slot n holds the path of the cached file for item n, or None while
        # item n has not been written to disk yet.
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        """Return item ``n``, reading it from disk if it was cached before."""
        cached_path = self._cache[n]
        if cached_path is not None:
            # NOTE: torch.load unpickles the file, so ``location`` should only
            # ever point at a directory whose contents are trusted.
            return torch.load(cached_path)

        item = self.dataset[n]

        path = os.path.join(self.location, str(self._id) + "-" + str(n))
        makedir_exist_ok(self.location)
        torch.save(item, path)
        # Record the path only once the file has been written: if creating the
        # directory or saving fails, the slot stays empty and the next access
        # retries instead of trying to load a file that does not exist.
        self._cache[n] = path

        return item

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)
32 changes: 0 additions & 32 deletions torchaudio/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import six
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm


Expand Down Expand Up @@ -190,34 +189,3 @@ def walk_files(root, suffix, prefix=False, remove_suffix=False):
f = os.path.join(root, f)

yield f


class DiskCache(Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to disk.

    The first access to index ``n`` fetches the item from the wrapped dataset,
    serializes it with ``torch.save`` into ``location`` and returns it. Later
    accesses to the same index load the item back with ``torch.load`` instead
    of querying the wrapped dataset again.

    The index of cached files is kept in memory only and file names are
    prefixed with ``id(self)``, so files written by one instance are not
    picked up by another instance.

    Args:
        dataset: Wrapped dataset; must support ``len()`` and integer indexing.
        location: Directory in which cached items are written. It is created
            on the first cache miss.
    """

    def __init__(self, dataset, location=".cached"):
        self.dataset = dataset
        self.location = location

        # Prefix for every cache file name written by this instance.
        self._id = id(self)
        # Slot n holds the path of the cached file for item n, or None while
        # item n has not been written to disk yet.
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        """Return item ``n``, reading it from disk if it was cached before."""
        cached_path = self._cache[n]
        if cached_path is not None:
            # NOTE: torch.load unpickles the file, so ``location`` should only
            # ever point at a directory whose contents are trusted.
            return torch.load(cached_path)

        item = self.dataset[n]

        path = os.path.join(self.location, str(self._id) + "-" + str(n))
        makedir_exist_ok(self.location)
        torch.save(item, path)
        # Record the path only once the file has been written: if creating the
        # directory or saving fails, the slot stays empty and the next access
        # retries instead of trying to load a file that does not exist.
        self._cache[n] = path

        return item

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)