diff --git a/torchaudio/datasets/yesno.py b/torchaudio/datasets/yesno.py
index 6d0d2c0a5f..614cd4c7a5 100644
--- a/torchaudio/datasets/yesno.py
+++ b/torchaudio/datasets/yesno.py
@@ -1,14 +1,16 @@
 import os
 from pathlib import Path
+from dataclasses import dataclass
 from typing import List, Tuple, Union
 
 from torch import Tensor
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, IterDataPipe
 
 import torchaudio
 from torchaudio.datasets.utils import (
     download_url,
     extract_archive,
+
 )
 
 
@@ -17,10 +19,108 @@
         "folder_in_archive": "waves_yesno",
         "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
         "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
+        "files": [
+            "0_0_0_0_1_1_1_1.wav",
+            "0_0_0_1_0_0_0_1.wav",
+            "0_0_0_1_0_1_1_0.wav",
+            "0_0_1_0_0_0_1_0.wav",
+            "0_0_1_0_0_1_1_0.wav",
+            "0_0_1_0_0_1_1_1.wav",
+            "0_0_1_0_1_0_0_0.wav",
+            "0_0_1_0_1_0_0_1.wav",
+            "0_0_1_0_1_0_1_1.wav",
+            "0_0_1_1_0_0_0_1.wav",
+            "0_0_1_1_0_1_0_0.wav",
+            "0_0_1_1_0_1_1_0.wav",
+            "0_0_1_1_0_1_1_1.wav",
+            "0_0_1_1_1_0_0_0.wav",
+            "0_0_1_1_1_0_0_1.wav",
+            "0_0_1_1_1_1_0_0.wav",
+            "0_0_1_1_1_1_1_0.wav",
+            "0_1_0_0_0_1_0_0.wav",
+            "0_1_0_0_0_1_1_0.wav",
+            "0_1_0_0_1_0_1_0.wav",
+            "0_1_0_0_1_0_1_1.wav",
+            "0_1_0_1_0_0_0_0.wav",
+            "0_1_0_1_1_0_1_0.wav",
+            "0_1_0_1_1_1_0_0.wav",
+            "0_1_1_0_0_1_1_0.wav",
+            "0_1_1_0_0_1_1_1.wav",
+            "0_1_1_1_0_0_0_0.wav",
+            "0_1_1_1_0_0_1_0.wav",
+            "0_1_1_1_0_1_0_1.wav",
+            "0_1_1_1_1_0_1_0.wav",
+            "0_1_1_1_1_1_1_1.wav",
+            "1_0_0_0_0_0_0_0.wav",
+            "1_0_0_0_0_0_0_1.wav",
+            "1_0_0_0_0_0_1_1.wav",
+            "1_0_0_0_1_0_0_1.wav",
+            "1_0_0_1_0_1_1_1.wav",
+            "1_0_1_0_1_0_0_1.wav",
+            "1_0_1_1_0_1_1_1.wav",
+            "1_0_1_1_1_0_1_0.wav",
+            "1_0_1_1_1_1_0_1.wav",
+            "1_1_0_0_0_0_0_1.wav",
+            "1_1_0_0_0_1_1_1.wav",
+            "1_1_0_0_1_0_1_0.wav",
+            "1_1_0_0_1_0_1_1.wav",
+            "1_1_0_0_1_1_1_0.wav",
+            "1_1_0_1_0_1_0_0.wav",
+            "1_1_0_1_0_1_1_0.wav",
+            "1_1_0_1_1_0_0_1.wav",
+            "1_1_0_1_1_0_1_1.wav",
+            "1_1_0_1_1_1_1_0.wav",
+            "1_1_1_0_0_0_0_1.wav",
+            "1_1_1_0_0_1_0_1.wav",
+            "1_1_1_0_0_1_1_1.wav",
+            "1_1_1_0_1_0_1_0.wav",
+            "1_1_1_0_1_0_1_1.wav",
+            "1_1_1_1_0_0_1_0.wav",
+            "1_1_1_1_0_1_0_0.wav",
+            "1_1_1_1_1_0_0_0.wav",
+            "1_1_1_1_1_1_0_0.wav",
+            "1_1_1_1_1_1_1_1.wav",
+        ]
     }
 }
 
 
+@dataclass
+class YesNoItem:
+    path: str
+    label: List[int]
+    waveform: Tensor
+    sample_rate: int
+
+
+class ListYesNoItems(IterDataPipe):
+    """Given a root directory, yield (path, label) pairs for the YesNo files."""
+    def __init__(self, root):
+        self.data_dir = os.path.join(root, "waves_yesno")
+        self.files = _RELEASE_CONFIGS["release1"]["files"]
+
+    def __iter__(self):
+        for filename in self.files:
+            path = os.path.join(self.data_dir, filename)
+            label = [int(c) for c in os.path.splitext(filename)[0].split("_")]
+            yield path, label
+
+
+class LoadYesNoItem(IterDataPipe):
+    def __init__(self, data_pipe):
+        self.data_pipe = data_pipe
+
+    def __iter__(self):
+        for path, label in self.data_pipe:
+            waveform, sample_rate = torchaudio.load(path)
+            yield YesNoItem(path, label, waveform, sample_rate)
+
+
+def get_yesno_dataset(root_dir: str):  # , download=False, url=_RELEASE_CONFIGS["release1"]["url"]):
+    # TODO: download dataset if necessary
+    return LoadYesNoItem(ListYesNoItems(root_dir))
+
+
 class YESNO(Dataset):
     """Create a Dataset for YesNo.