Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class TestDatasets(unittest.TestCase):
path = os.path.join(test_dirpath, "assets")

def test_yesno(self):
    """Smoke test: YESNO can be constructed from the asset path and yields item 0."""
    data = YESNO(self.path)
    data[0]

def test_vctk(self):
    """Smoke test: VCTK can be constructed from the asset path and yields item 0."""
    data = VCTK(self.path)
    data[0]

def test_librispeech(self):
Expand Down
17 changes: 13 additions & 4 deletions torchaudio/datasets/commonvoice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os

import torchaudio
from torch.utils.data import Dataset

import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader

# Default TSV should be one of
Expand All @@ -17,20 +18,28 @@


def load_commonvoice_item(line, header, path, folder_audio):
    """Load one CommonVoice example.

    Each TSV line has the following columns:
    client_id, path, sentence, up_votes, down_votes, age, gender, accent

    Returns:
        (waveform, sample_rate, dictionary) where the dictionary maps the
        TSV header names to this line's values.
    """
    # The audio file name comes from the "path" column of the TSV.
    assert header[1] == "path"
    fileid = line[1]

    filename = os.path.join(path, folder_audio, fileid)
    waveform, sample_rate = torchaudio.load(filename)

    # Dictionary of the raw TSV fields only; audio is returned separately.
    dic = dict(zip(header, line))

    return waveform, sample_rate, dic


class COMMONVOICE(Dataset):
"""
Create a Dataset for CommonVoice. Each item is a tuple of the form:
(waveform, sample_rate, dictionary)
where dictionary is a dictionary built from the tsv file with the following keys:
client_id, path, sentence, up_votes, down_votes, age, gender, accent.
"""

_ext_txt = ".txt"
_ext_audio = ".mp3"
Expand Down
50 changes: 27 additions & 23 deletions torchaudio/datasets/librispeech.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os

from torch.utils.data import Dataset

import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
Expand All @@ -16,38 +15,43 @@

def load_librispeech_item(fileid, path, ext_audio, ext_txt):
    """Load one LibriSpeech example.

    ``fileid`` has the form "<speaker>-<chapter>-<utterance>".

    Returns:
        (waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id)
        with the three ids converted to int (they come from file names as strings).

    Raises:
        FileNotFoundError: if no transcript line matches this utterance.
    """
    speaker_id, chapter_id, utterance_id = fileid.split("-")

    # Transcript file "<speaker>-<chapter>.trans.txt" is shared by all
    # utterances of the chapter.
    file_text = speaker_id + "-" + chapter_id + ext_txt
    file_text = os.path.join(path, speaker_id, chapter_id, file_text)

    fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
    file_audio = fileid_audio + ext_audio
    file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio)

    # Load text: scan the shared transcript for this utterance's line.
    # `with` guarantees the file is closed even on error.
    with open(file_text) as ft:
        for line in ft:
            fileid_text, utterance = line.strip().split(" ", 1)
            if fileid_audio == fileid_text:
                break
        else:
            # for/else: no transcript line matched this utterance id.
            raise FileNotFoundError("Translation not found for " + fileid_audio)

    return (
        waveform,
        sample_rate,
        utterance,
        int(speaker_id),
        int(chapter_id),
        int(utterance_id),
    )


class LIBRISPEECH(Dataset):
"""
Create a Dataset for LibriSpeech. Each item is a tuple of the form:
waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id
"""

_ext_txt = ".trans.txt"
_ext_audio = ".flac"
Expand Down
48 changes: 14 additions & 34 deletions torchaudio/datasets/vctk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
def load_vctk_item(
fileid, path, ext_audio, ext_txt, folder_audio, folder_txt, downsample=False
):
speaker, utterance = fileid.split("_")
speaker_id, utterance_id = fileid.split("_")

# Read text
file_txt = os.path.join(path, folder_txt, speaker, fileid + ext_txt)
file_txt = os.path.join(path, folder_txt, speaker_id, fileid + ext_txt)
with open(file_txt) as file_text:
content = file_text.readlines()[0]
utterance = file_text.readlines()[0]

# Read wav
file_audio = os.path.join(path, folder_audio, speaker, fileid + ext_audio)
file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
if downsample:
# Legacy
E = torchaudio.sox_effects.SoxEffectsChain()
Expand All @@ -34,16 +34,14 @@ def load_vctk_item(
else:
waveform, sample_rate = torchaudio.load(file_audio)

return {
"speaker_id": speaker,
"utterance_id": utterance,
"utterance": content,
"waveform": waveform,
"sample_rate": sample_rate,
}
return waveform, sample_rate, utterance, speaker_id, utterance_id


class VCTK(Dataset):
"""
Create a Dataset for VCTK. Each item is a tuple of the form:
(waveform, sample_rate, utterance, speaker_id, utterance_id)
"""

_folder_txt = "txt"
_folder_audio = "wav48"
Expand All @@ -59,17 +57,8 @@ def __init__(
downsample=False,
transform=None,
target_transform=None,
return_dict=False,
):

if not return_dict:
warnings.warn(
"In the next version, the item returned will be a dictionary. "
"Please use `return_dict=True` to enable this behavior now, "
"and suppress this warning.",
DeprecationWarning,
)

if downsample:
warnings.warn(
"In the next version, transforms will not be part of the dataset. "
Expand All @@ -89,7 +78,6 @@ def __init__(
self.downsample = downsample
self.transform = transform
self.target_transform = target_transform
self.return_dict = return_dict

archive = os.path.basename(url)
archive = os.path.join(root, archive)
Expand Down Expand Up @@ -122,23 +110,15 @@ def __getitem__(self, n):
self._folder_txt,
)

# Legacy
waveform = item["waveform"]
# TODO Upon deprecation, uncomment line below and remove following code
# return item

waveform, sample_rate, utterance, speaker_id, utterance_id = item
if self.transform is not None:
waveform = self.transform(waveform)
item["waveform"] = waveform

# Legacy
utterance = item["utterance"]
if self.target_transform is not None:
utterance = self.target_transform(utterance)
item["utterance"] = utterance

if self.return_dict:
return item

# Legacy
return item["waveform"], item["utterance"]
return waveform, sample_rate, utterance, speaker_id, utterance_id

def __len__(self):
    # Number of examples found when walking the extracted VCTK archive.
    return len(self._walker)
35 changes: 12 additions & 23 deletions torchaudio/datasets/yesno.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,20 @@

def load_yesno_item(fileid, path, ext_audio):
    """Load one YesNo example.

    ``fileid`` is an underscore-separated string of 0/1 digits taken from the
    file name (e.g. "0_1_0_1_1_0_1_1"); each digit is one yes/no label.

    Returns:
        (waveform, sample_rate, labels) where labels is a list of ints.
    """
    # Read label: convert each digit to int, as they come from the file name.
    labels = [int(c) for c in fileid.split("_")]

    # Read wav
    file_audio = os.path.join(path, fileid + ext_audio)
    waveform, sample_rate = torchaudio.load(file_audio)

    return waveform, sample_rate, labels


class YESNO(Dataset):
"""
Create a Dataset for YesNo. Each item is a tuple of the form:
(waveform, sample_rate, labels)
"""

_ext_audio = ".wav"

Expand All @@ -32,17 +36,8 @@ def __init__(
download=False,
transform=None,
target_transform=None,
return_dict=False,
):

if not return_dict:
warnings.warn(
"In the next version, the item returned will be a dictionary. "
"Please use `return_dict=True` to enable this behavior now, "
"and suppress this warning.",
DeprecationWarning,
)

if transform is not None or target_transform is not None:
warnings.warn(
"In the next version, transforms will not be part of the dataset. "
Expand All @@ -53,7 +48,6 @@ def __init__(

self.transform = transform
self.target_transform = target_transform
self.return_dict = return_dict

archive = os.path.basename(url)
archive = os.path.join(root, archive)
def __getitem__(self, n):
    """Return the n-th example as (waveform, sample_rate, labels),
    applying the (deprecated) transform/target_transform if set."""
    fileid = self._walker[n]
    item = load_yesno_item(fileid, self._path, self._ext_audio)

    waveform, sample_rate, labels = item
    if self.transform is not None:
        waveform = self.transform(waveform)
    if self.target_transform is not None:
        labels = self.target_transform(labels)

    return waveform, sample_rate, labels

def __len__(self):
    # Number of .wav files found when walking the extracted YesNo archive.
    return len(self._walker)