class DiskCache(Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to disk.

    The first access to index ``n`` reads the item from the wrapped dataset,
    writes it with ``torch.save`` to ``<location>/<id(self)>-<n>`` and returns
    it. Later accesses to the same index through this instance return
    ``torch.load`` of that file instead of touching the wrapped dataset.

    The table of cached paths lives in memory and starts empty, so files
    written by another instance (or an earlier run) are never read back.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        location: Directory that receives the cached files. It is created
            on the first cache miss, not at construction time.
    """

    def __init__(self, dataset, location=".cached"):
        self.dataset = dataset
        self.location = location

        # Prefix for this instance's file names; id() is only unique among
        # objects that are alive at the same time in the same process.
        self._id = id(self)
        # One slot per item: None until the item has been written to disk,
        # afterwards the path of its cache file.
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        """Return item ``n``, from its cache file if one was written, else from the dataset."""
        cached_path = self._cache[n]
        if cached_path is not None:
            # NOTE: torch.load unpickles the file, so ``location`` must be a
            # directory that only trusted code can write to.
            return torch.load(cached_path)

        item = self.dataset[n]
        path = os.path.join(self.location, "{}-{}".format(self._id, n))

        makedir_exist_ok(self.location)
        torch.save(item, path)
        # Record the path only after the save succeeded. Registering it
        # earlier would leave the slot pointing at a missing or partial file
        # whenever directory creation or serialization fails, and every later
        # access to this index would then raise inside torch.load.
        self._cache[n] = path

        return item

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)