diff --git a/test/test_datasets.py b/test/test_datasets.py
index 8f35a27b56..54b611244c 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -15,11 +15,11 @@ class TestDatasets(unittest.TestCase):
     path = os.path.join(test_dirpath, "assets")
 
     def test_yesno(self):
-        data = YESNO(self.path, return_dict=True)
+        data = YESNO(self.path)
         data[0]
 
     def test_vctk(self):
-        data = VCTK(self.path, return_dict=True)
+        data = VCTK(self.path)
         data[0]
 
     def test_librispeech(self):
diff --git a/torchaudio/datasets/commonvoice.py b/torchaudio/datasets/commonvoice.py
index 94f19cb72a..6caf3f6e5f 100644
--- a/torchaudio/datasets/commonvoice.py
+++ b/torchaudio/datasets/commonvoice.py
@@ -1,7 +1,8 @@
 import os
 
-import torchaudio
 from torch.utils.data import Dataset
+
+import torchaudio
 from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
 
 # Default TSV should be one of
@@ -17,6 +18,10 @@
 
 
 def load_commonvoice_item(line, header, path, folder_audio):
+    # Each line as the following data:
+    # client_id, path, sentence, up_votes, down_votes, age, gender, accent
+
+    assert header[1] == "path"
     fileid = line[1]
 
     filename = os.path.join(path, folder_audio, fileid)
@@ -24,13 +29,17 @@ def load_commonvoice_item(line, header, path, folder_audio):
     waveform, sample_rate = torchaudio.load(filename)
 
     dic = dict(zip(header, line))
-    dic["waveform"] = waveform
-    dic["sample_rate"] = sample_rate
 
-    return dic
+    return waveform, sample_rate, dic
 
 
 class COMMONVOICE(Dataset):
+    """
+    Create a Dataset for CommonVoice. Each item is a tuple of the form:
+    (waveform, sample_rate, dictionary)
+    where dictionary is a dictionary built from the tsv file with the following keys:
+    client_id, path, sentence, up_votes, down_votes, age, gender, accent.
+    """
 
     _ext_txt = ".txt"
     _ext_audio = ".mp3"
diff --git a/torchaudio/datasets/librispeech.py b/torchaudio/datasets/librispeech.py
index e9919ba1a3..f3d92467f8 100644
--- a/torchaudio/datasets/librispeech.py
+++ b/torchaudio/datasets/librispeech.py
@@ -1,8 +1,7 @@
 import os
 
-from torch.utils.data import Dataset
-
 import torchaudio
+from torch.utils.data import Dataset
 from torchaudio.datasets.utils import (
     download_url,
     extract_archive,
@@ -16,38 +15,43 @@
 
 
 def load_librispeech_item(fileid, path, ext_audio, ext_txt):
-    speaker, chapter, utterance = fileid.split("-")
+    speaker_id, chapter_id, utterance_id = fileid.split("-")
 
-    file_text = speaker + "-" + chapter + ext_txt
-    file_text = os.path.join(path, speaker, chapter, file_text)
+    file_text = speaker_id + "-" + chapter_id + ext_txt
+    file_text = os.path.join(path, speaker_id, chapter_id, file_text)
 
-    fileid_audio = speaker + "-" + chapter + "-" + utterance
+    fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
     file_audio = fileid_audio + ext_audio
-    file_audio = os.path.join(path, speaker, chapter, file_audio)
+    file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
 
     # Load audio
     waveform, sample_rate = torchaudio.load(file_audio)
 
     # Load text
-    for line in open(file_text):
-        fileid_text, content = line.strip().split(" ", 1)
-        if fileid_audio == fileid_text:
-            break
-    else:
-        # Translation not found
-        raise FileNotFoundError("Translation not found for " + fileid_audio)
-
-    return {
-        "speaker_id": speaker,
-        "chapter_id": chapter,
-        "utterance_id": utterance,
-        "utterance": content,
-        "waveform": waveform,
-        "sample_rate": sample_rate,
-    }
+    with open(file_text) as ft:
+        for line in ft:
+            fileid_text, utterance = line.strip().split(" ", 1)
+            if fileid_audio == fileid_text:
+                break
+        else:
+            # Translation not found
+            raise FileNotFoundError("Translation not found for " + fileid_audio)
+
+    return (
+        waveform,
+        sample_rate,
+        utterance,
+        int(speaker_id),
+        int(chapter_id),
+        int(utterance_id),
+    )
 
 
 class LIBRISPEECH(Dataset):
+    """
+    Create a Dataset for LibriSpeech. Each item is a tuple of the form:
+    waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id
+    """
 
     _ext_txt = ".trans.txt"
     _ext_audio = ".flac"
diff --git a/torchaudio/datasets/vctk.py b/torchaudio/datasets/vctk.py
index 279e83fa1d..813a9df62a 100644
--- a/torchaudio/datasets/vctk.py
+++ b/torchaudio/datasets/vctk.py
@@ -12,15 +12,15 @@
 def load_vctk_item(
     fileid, path, ext_audio, ext_txt, folder_audio, folder_txt, downsample=False
 ):
-    speaker, utterance = fileid.split("_")
+    speaker_id, utterance_id = fileid.split("_")
 
     # Read text
-    file_txt = os.path.join(path, folder_txt, speaker, fileid + ext_txt)
+    file_txt = os.path.join(path, folder_txt, speaker_id, fileid + ext_txt)
     with open(file_txt) as file_text:
-        content = file_text.readlines()[0]
+        utterance = file_text.readlines()[0]
 
     # Read wav
-    file_audio = os.path.join(path, folder_audio, speaker, fileid + ext_audio)
+    file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
     if downsample:
         # Legacy
         E = torchaudio.sox_effects.SoxEffectsChain()
@@ -34,16 +34,14 @@ def load_vctk_item(
     else:
         waveform, sample_rate = torchaudio.load(file_audio)
 
-    return {
-        "speaker_id": speaker,
-        "utterance_id": utterance,
-        "utterance": content,
-        "waveform": waveform,
-        "sample_rate": sample_rate,
-    }
+    return waveform, sample_rate, utterance, speaker_id, utterance_id
 
 
 class VCTK(Dataset):
+    """
+    Create a Dataset for VCTK. Each item is a tuple of the form:
+    (waveform, sample_rate, utterance, speaker_id, utterance_id)
+    """
 
     _folder_txt = "txt"
     _folder_audio = "wav48"
@@ -59,17 +57,8 @@ def __init__(
         downsample=False,
         transform=None,
         target_transform=None,
-        return_dict=False,
     ):
 
-        if not return_dict:
-            warnings.warn(
-                "In the next version, the item returned will be a dictionary. "
-                "Please use `return_dict=True` to enable this behavior now, "
-                "and suppress this warning.",
-                DeprecationWarning,
-            )
-
         if downsample:
             warnings.warn(
                 "In the next version, transforms will not be part of the dataset. "
@@ -89,7 +78,6 @@ def __init__(
         self.downsample = downsample
         self.transform = transform
         self.target_transform = target_transform
-        self.return_dict = return_dict
 
         archive = os.path.basename(url)
         archive = os.path.join(root, archive)
@@ -122,23 +110,15 @@ def __getitem__(self, n):
             self._folder_txt,
         )
 
-        # Legacy
-        waveform = item["waveform"]
+        # TODO Upon deprecation, uncomment line below and remove following code
+        # return item
+
+        waveform, sample_rate, utterance, speaker_id, utterance_id = item
         if self.transform is not None:
             waveform = self.transform(waveform)
-        item["waveform"] = waveform
-
-        # Legacy
-        utterance = item["utterance"]
         if self.target_transform is not None:
             utterance = self.target_transform(utterance)
-        item["utterance"] = utterance
-
-        if self.return_dict:
-            return item
-
-        # Legacy
-        return item["waveform"], item["utterance"]
+        return waveform, sample_rate, utterance, speaker_id, utterance_id
 
     def __len__(self):
         return len(self._walker)
diff --git a/torchaudio/datasets/yesno.py b/torchaudio/datasets/yesno.py
index dd1b9e180e..01bf8a3d15 100644
--- a/torchaudio/datasets/yesno.py
+++ b/torchaudio/datasets/yesno.py
@@ -11,16 +11,20 @@
 
 def load_yesno_item(fileid, path, ext_audio):
     # Read label
-    label = fileid.split("_")
+    labels = [int(c) for c in fileid.split("_")]
 
     # Read wav
     file_audio = os.path.join(path, fileid + ext_audio)
     waveform, sample_rate = torchaudio.load(file_audio)
 
-    return {"label": label, "waveform": waveform, "sample_rate": sample_rate}
+    return waveform, sample_rate, labels
 
 
 class YESNO(Dataset):
+    """
+    Create a Dataset for YesNo. Each item is a tuple of the form:
+    (waveform, sample_rate, labels)
+    """
 
     _ext_audio = ".wav"
 
@@ -32,17 +36,8 @@ def __init__(
         download=False,
         transform=None,
         target_transform=None,
-        return_dict=False,
     ):
 
-        if not return_dict:
-            warnings.warn(
-                "In the next version, the item returned will be a dictionary. "
-                "Please use `return_dict=True` to enable this behavior now, "
-                "and suppress this warning.",
-                DeprecationWarning,
-            )
-
         if transform is not None or target_transform is not None:
             warnings.warn(
                 "In the next version, transforms will not be part of the dataset. "
@@ -53,7 +48,6 @@ def __init__(
 
         self.transform = transform
         self.target_transform = target_transform
-        self.return_dict = return_dict
 
         archive = os.path.basename(url)
         archive = os.path.join(root, archive)
@@ -79,20 +73,15 @@ def __getitem__(self, n):
         fileid = self._walker[n]
         item = load_yesno_item(fileid, self._path, self._ext_audio)
 
-        waveform = item["waveform"]
+        # TODO Upon deprecation, uncomment line below and remove following code
+        # return item
+
+        waveform, sample_rate, labels = item
         if self.transform is not None:
             waveform = self.transform(waveform)
-        item["waveform"] = waveform
-
-        label = item["label"]
         if self.target_transform is not None:
-            label = self.target_transform(label)
-        item["label"] = label
-
-        if self.return_dict:
-            return item
-
-        return item["waveform"], item["label"]
+            labels = self.target_transform(labels)
+        return waveform, sample_rate, labels
 
     def __len__(self):
         return len(self._walker)