1111 extract_archive ,
1212)
1313
14- URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
15- FOLDER_IN_ARCHIVE = "waves_yesno"
16- _CHECKSUMS = {
17- "http://www.openslr.org/resources/1/waves_yesno.tar.gz" :
18- "962ff6e904d2df1126132ecec6978786"
19- }
20-
21-
22- def load_yesno_item (fileid : str , path : str , ext_audio : str ) -> Tuple [Tensor , int , List [int ]]:
23- # Read label
24- labels = [int (c ) for c in fileid .split ("_" )]
2514
26- # Read wav
27- file_audio = os .path .join (path , fileid + ext_audio )
28- waveform , sample_rate = torchaudio .load (file_audio )
29-
30- return waveform , sample_rate , labels
15+ _RELEASE_CONFIGS = {
16+ "release1" : {
17+ "folder_in_archive" : "waves_yesno" ,
18+ "url" : "http://www.openslr.org/resources/1/waves_yesno.tar.gz" ,
19+ "checksum" : "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27" ,
20+ }
21+ }
3122
3223
3324class YESNO (Dataset ):
@@ -43,25 +34,26 @@ class YESNO(Dataset):
4334 Whether to download the dataset if it is not found at root path. (default: ``False``).
4435 """
4536
46- _ext_audio = ".wav"
47-
48- def __init__ ( self ,
49- root : Union [ str , Path ],
50- url : str = URL ,
51- folder_in_archive : str = FOLDER_IN_ARCHIVE ,
52- download : bool = False ) -> None :
37+ def __init__ (
38+ self ,
39+ root : Union [ str , Path ] ,
40+ url : str = _RELEASE_CONFIGS [ "release1" ][ "url" ],
41+ folder_in_archive : str = _RELEASE_CONFIGS [ "release1" ][ "folder_in_archive" ] ,
42+ download : bool = False
43+ ) -> None :
5344
54- # Get string representation of 'root' in case Path object is passed
55- root = os .fspath (root )
45+ self ._parse_filesystem (root , url , folder_in_archive , download )
5646
47+ def _parse_filesystem (self , root : str , url : str , folder_in_archive : str , download : bool ) -> None :
48+ root = Path (root )
5749 archive = os .path .basename (url )
58- archive = os .path .join (root , archive )
59- self ._path = os .path .join (root , folder_in_archive )
50+ archive = root / archive
6051
52+ self ._path = root / folder_in_archive
6153 if download :
6254 if not os .path .isdir (self ._path ):
6355 if not os .path .isfile (archive ):
64- checksum = _CHECKSUMS . get ( url , None )
56+ checksum = _RELEASE_CONFIGS [ "release1" ][ "checksum" ]
6557 download_url (url , root , hash_value = checksum , hash_type = "md5" )
6658 extract_archive (archive )
6759
@@ -70,7 +62,13 @@ def __init__(self,
7062 "Dataset not found. Please use `download=True` to download it."
7163 )
7264
73- self ._walker = sorted (str (p .stem ) for p in Path (self ._path ).glob ('*' + self ._ext_audio ))
65+ self ._walker = sorted (str (p .stem ) for p in Path (self ._path ).glob ("*.wav" ))
66+
67+ def _load_item (self , fileid : str , path : str ):
68+ labels = [int (c ) for c in fileid .split ("_" )]
69+ file_audio = os .path .join (path , fileid + ".wav" )
70+ waveform , sample_rate = torchaudio .load (file_audio )
71+ return waveform , sample_rate , labels
7472
7573 def __getitem__ (self , n : int ) -> Tuple [Tensor , int , List [int ]]:
7674 """Load the n-th sample from the dataset.
@@ -82,13 +80,8 @@ def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
8280 tuple: ``(waveform, sample_rate, labels)``
8381 """
8482 fileid = self ._walker [n ]
85- item = load_yesno_item (fileid , self ._path , self ._ext_audio )
86-
87- # TODO Upon deprecation, uncomment line below and remove following code
88- # return item
89-
90- waveform , sample_rate , labels = item
91- return waveform , sample_rate , labels
83+ item = self ._load_item (fileid , self ._path )
84+ return item
9285
9386 def __len__ (self ) -> int :
9487 return len (self ._walker )
0 commit comments