Skip to content

Commit f5aced8

Browse files
krishnakalyan3krishnakalyan3
andauthored
Refactor YesNo dataset (#1127)
Co-authored-by: krishnakalyan3 <[email protected]>
1 parent e43a8e7 commit f5aced8

File tree

1 file changed

+29
-36
lines changed

1 file changed

+29
-36
lines changed

torchaudio/datasets/yesno.py

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,14 @@
1111
extract_archive,
1212
)
1313

14-
URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
15-
FOLDER_IN_ARCHIVE = "waves_yesno"
16-
_CHECKSUMS = {
17-
"http://www.openslr.org/resources/1/waves_yesno.tar.gz":
18-
"962ff6e904d2df1126132ecec6978786"
19-
}
20-
21-
22-
def load_yesno_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, List[int]]:
23-
# Read label
24-
labels = [int(c) for c in fileid.split("_")]
2514

26-
# Read wav
27-
file_audio = os.path.join(path, fileid + ext_audio)
28-
waveform, sample_rate = torchaudio.load(file_audio)
29-
30-
return waveform, sample_rate, labels
15+
_RELEASE_CONFIGS = {
16+
"release1": {
17+
"folder_in_archive": "waves_yesno",
18+
"url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
19+
"checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
20+
}
21+
}
3122

3223

3324
class YESNO(Dataset):
@@ -43,25 +34,26 @@ class YESNO(Dataset):
4334
Whether to download the dataset if it is not found at root path. (default: ``False``).
4435
"""
4536

46-
_ext_audio = ".wav"
47-
48-
def __init__(self,
49-
root: Union[str, Path],
50-
url: str = URL,
51-
folder_in_archive: str = FOLDER_IN_ARCHIVE,
52-
download: bool = False) -> None:
37+
def __init__(
38+
self,
39+
root: Union[str, Path],
40+
url: str = _RELEASE_CONFIGS["release1"]["url"],
41+
folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
42+
download: bool = False
43+
) -> None:
5344

54-
# Get string representation of 'root' in case Path object is passed
55-
root = os.fspath(root)
45+
self._parse_filesystem(root, url, folder_in_archive, download)
5646

47+
def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
48+
root = Path(root)
5749
archive = os.path.basename(url)
58-
archive = os.path.join(root, archive)
59-
self._path = os.path.join(root, folder_in_archive)
50+
archive = root / archive
6051

52+
self._path = root / folder_in_archive
6153
if download:
6254
if not os.path.isdir(self._path):
6355
if not os.path.isfile(archive):
64-
checksum = _CHECKSUMS.get(url, None)
56+
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
6557
download_url(url, root, hash_value=checksum, hash_type="md5")
6658
extract_archive(archive)
6759

@@ -70,7 +62,13 @@ def __init__(self,
7062
"Dataset not found. Please use `download=True` to download it."
7163
)
7264

73-
self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*' + self._ext_audio))
65+
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav"))
66+
67+
def _load_item(self, fileid: str, path: str):
68+
labels = [int(c) for c in fileid.split("_")]
69+
file_audio = os.path.join(path, fileid + ".wav")
70+
waveform, sample_rate = torchaudio.load(file_audio)
71+
return waveform, sample_rate, labels
7472

7573
def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
7674
"""Load the n-th sample from the dataset.
@@ -82,13 +80,8 @@ def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
8280
tuple: ``(waveform, sample_rate, labels)``
8381
"""
8482
fileid = self._walker[n]
85-
item = load_yesno_item(fileid, self._path, self._ext_audio)
86-
87-
# TODO Upon deprecation, uncomment line below and remove following code
88-
# return item
89-
90-
waveform, sample_rate, labels = item
91-
return waveform, sample_rate, labels
83+
item = self._load_item(fileid, self._path)
84+
return item
9285

9386
def __len__(self) -> int:
9487
return len(self._walker)

0 commit comments

Comments
 (0)