diff --git a/torchaudio/datasets/yesno.py b/torchaudio/datasets/yesno.py index 7178b8332c..6d0d2c0a5f 100644 --- a/torchaudio/datasets/yesno.py +++ b/torchaudio/datasets/yesno.py @@ -11,23 +11,14 @@ extract_archive, ) -URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz" -FOLDER_IN_ARCHIVE = "waves_yesno" -_CHECKSUMS = { - "http://www.openslr.org/resources/1/waves_yesno.tar.gz": - "962ff6e904d2df1126132ecec6978786" -} - - -def load_yesno_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, List[int]]: - # Read label - labels = [int(c) for c in fileid.split("_")] - # Read wav - file_audio = os.path.join(path, fileid + ext_audio) - waveform, sample_rate = torchaudio.load(file_audio) - - return waveform, sample_rate, labels +_RELEASE_CONFIGS = { + "release1": { + "folder_in_archive": "waves_yesno", + "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz", + "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27", + } +} class YESNO(Dataset): @@ -43,25 +34,26 @@ class YESNO(Dataset): Whether to download the dataset if it is not found at root path. (default: ``False``). """ - _ext_audio = ".wav" - - def __init__(self, - root: Union[str, Path], - url: str = URL, - folder_in_archive: str = FOLDER_IN_ARCHIVE, - download: bool = False) -> None: + def __init__( + self, + root: Union[str, Path], + url: str = _RELEASE_CONFIGS["release1"]["url"], + folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"], + download: bool = False + ) -> None: - # Get string representation of 'root' in case Path object is passed - root = os.fspath(root) + self._parse_filesystem(root, url, folder_in_archive, download) + def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None: + root = Path(root) archive = os.path.basename(url) - archive = os.path.join(root, archive) - self._path = os.path.join(root, folder_in_archive) + archive = root / archive + self._path = root / folder_in_archive if download: if not os.path.isdir(self._path): if not os.path.isfile(archive): - checksum = _CHECKSUMS.get(url, None) + checksum = _RELEASE_CONFIGS["release1"]["checksum"] download_url(url, root, hash_value=checksum, hash_type="md5") extract_archive(archive) @@ -70,7 +62,13 @@ def __init__(self, "Dataset not found. Please use `download=True` to download it." ) - self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*' + self._ext_audio)) + self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav")) + + def _load_item(self, fileid: str, path: str): + labels = [int(c) for c in fileid.split("_")] + file_audio = os.path.join(path, fileid + ".wav") + waveform, sample_rate = torchaudio.load(file_audio) + return waveform, sample_rate, labels def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]: """Load the n-th sample from the dataset. @@ -82,13 +80,8 @@ def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]: tuple: ``(waveform, sample_rate, labels)`` """ fileid = self._walker[n] - item = load_yesno_item(fileid, self._path, self._ext_audio) - - # TODO Upon deprecation, uncomment line below and remove following code - # return item - - waveform, sample_rate, labels = item - return waveform, sample_rate, labels + item = self._load_item(fileid, self._path) + return item def __len__(self) -> int: return len(self._walker)