test/torchaudio_unittest/datasets/speechcommands_test.py (114 additions, 29 deletions)

@@ -53,47 +53,60 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):

root_dir = None
samples = []
+    train_samples = []
+    valid_samples = []
+    test_samples = []

@classmethod
-    def setUp(cls):
+    def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join(
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
)
os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz sample rate
seed = 0
-        for label in LABELS:
-            path = os.path.join(dataset_dir, label)
-            os.makedirs(path, exist_ok=True)
-            for j in range(2):
-                # generate hash ID for speaker
-                speaker = "{:08x}".format(j)
-
-                for utterance in range(3):
-                    filename = f"{speaker}{speechcommands.HASH_DIVIDER}{utterance}.wav"
-                    file_path = os.path.join(path, filename)
-                    seed += 1
-                    data = get_whitenoise(
-                        sample_rate=sample_rate,
-                        duration=0.01,
-                        n_channels=1,
-                        dtype="int16",
-                        seed=seed,
-                    )
-                    save_wav(file_path, data, sample_rate)
-                    sample = (
-                        normalize_wav(data),
-                        sample_rate,
-                        label,
-                        speaker,
-                        utterance,
-                    )
-                    cls.samples.append(sample)
+        valid_file = os.path.join(dataset_dir, "validation_list.txt")
+        test_file = os.path.join(dataset_dir, "testing_list.txt")
+        with open(valid_file, "w") as valid, open(test_file, "w") as test:
+            for label in LABELS:
+                path = os.path.join(dataset_dir, label)
+                os.makedirs(path, exist_ok=True)
+                for j in range(6):
+                    # generate hash ID for speaker
+                    speaker = "{:08x}".format(j)
+
+                    for utterance in range(3):
+                        filename = f"{speaker}{speechcommands.HASH_DIVIDER}{utterance}.wav"
+                        file_path = os.path.join(path, filename)
+                        seed += 1
+                        data = get_whitenoise(
+                            sample_rate=sample_rate,
+                            duration=0.01,
+                            n_channels=1,
+                            dtype="int16",
+                            seed=seed,
+                        )
+                        save_wav(file_path, data, sample_rate)
+                        sample = (
+                            normalize_wav(data),
+                            sample_rate,
+                            label,
+                            speaker,
+                            utterance,
+                        )
+                        cls.samples.append(sample)
+                        if j < 2:
+                            cls.train_samples.append(sample)
+                        elif j < 4:
+                            valid.write(f'{label}/{filename}\n')
+                            cls.valid_samples.append(sample)
+                        elif j < 6:
+                            test.write(f'{label}/{filename}\n')
+                            cls.test_samples.append(sample)

def testSpeechCommands(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
print(dataset._path)

num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
@@ -107,3 +120,75 @@ def testSpeechCommands(self):
num_samples += 1

assert num_samples == len(self.samples)

+    def testSpeechCommandsNone(self):
+        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset=None)
+
+        num_samples = 0
+        for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
+            dataset
+        ):
+            self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8)
+            assert sample_rate == self.samples[i][1]
+            assert label == self.samples[i][2]
+            assert speaker_id == self.samples[i][3]
+            assert utterance_number == self.samples[i][4]
+            num_samples += 1
+
+        assert num_samples == len(self.samples)
+
+    def testSpeechCommandsSubsetTrain(self):
+        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")
+
+        num_samples = 0
+        for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
+            dataset
+        ):
+            self.assertEqual(data, self.train_samples[i][0], atol=5e-5, rtol=1e-8)
+            assert sample_rate == self.train_samples[i][1]
+            assert label == self.train_samples[i][2]
+            assert speaker_id == self.train_samples[i][3]
+            assert utterance_number == self.train_samples[i][4]
+            num_samples += 1
+
+        assert num_samples == len(self.train_samples)
+
+    def testSpeechCommandsSubsetValid(self):
+        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
+
+        num_samples = 0
+        for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
+            dataset
+        ):
+            self.assertEqual(data, self.valid_samples[i][0], atol=5e-5, rtol=1e-8)
+            assert sample_rate == self.valid_samples[i][1]
+            assert label == self.valid_samples[i][2]
+            assert speaker_id == self.valid_samples[i][3]
+            assert utterance_number == self.valid_samples[i][4]
+            num_samples += 1
+
+        assert num_samples == len(self.valid_samples)
+
+    def testSpeechCommandsSubsetTest(self):
+        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")
+
+        num_samples = 0
+        for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
+            dataset
+        ):
+            self.assertEqual(data, self.test_samples[i][0], atol=5e-5, rtol=1e-8)
+            assert sample_rate == self.test_samples[i][1]
+            assert label == self.test_samples[i][2]
+            assert speaker_id == self.test_samples[i][3]
+            assert utterance_number == self.test_samples[i][4]
+            num_samples += 1
+
+        assert num_samples == len(self.test_samples)
+
+    def testSpeechCommandsSum(self):
+        dataset_all = speechcommands.SPEECHCOMMANDS(self.root_dir)
+        dataset_train = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")
+        dataset_valid = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
+        dataset_test = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")
+
+        assert len(dataset_train) + len(dataset_valid) + len(dataset_test) == len(dataset_all)
torchaudio/datasets/speechcommands.py (39 additions, 5 deletions)

@@ -1,5 +1,5 @@
import os
-from typing import Tuple
+from typing import Tuple, Optional

import torchaudio
from torch.utils.data import Dataset
@@ -22,6 +22,15 @@
}


+def _load_list(root, *filenames):
+    output = []
+    for filename in filenames:
+        filepath = os.path.join(root, filename)
+        with open(filepath) as fileobj:
+            output += [os.path.normpath(os.path.join(root, line.strip())) for line in fileobj]
+    return output

mthrok (Contributor), Oct 27, 2020:
    Can you add a comment on why normpath is necessary here? Even though
    normpath's official documentation says "On Windows, it converts forward
    slashes to backward slashes. To normalize case, use normcase().", this is
    not well known to the many of us who are Linux users. (I learned it only
    today.)
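For readers unfamiliar with the Windows behavior referenced above, a minimal
illustration (the filename is hypothetical, not taken from the dataset):

    import os

    # Entries in validation_list.txt / testing_list.txt use forward slashes,
    # while os.path.join produces backslash-separated paths on Windows, so the
    # two spellings of the same file would not compare equal without normpath.
    line = "bed/0a7c2a8d_nohash_0.wav"
    joined = os.path.join("SpeechCommands", line)
    print(os.path.normpath(joined))
    # POSIX:   SpeechCommands/bed/0a7c2a8d_nohash_0.wav
    # Windows: SpeechCommands\bed\0a7c2a8d_nohash_0.wav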


def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath)
@@ -48,13 +57,25 @@ class SPEECHCOMMANDS(Dataset):
The top-level directory of the dataset. (default: ``"SpeechCommands"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
+        subset (Optional[str]):
+            Select a subset of the dataset [None, "training", "validation", "testing"]. None means
+            the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and
+            "testing_list.txt", respectively, and "training" is the rest. (default: ``None``)
"""

def __init__(self,
root: str,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
-                 download: bool = False) -> None:
+                 download: bool = False,
+                 subset: Optional[str] = None,
+                 ) -> None:

+        assert subset is None or subset in ["training", "validation", "testing"], (
+            "When `subset` is not None, it must take a value from "
+            + "{'training', 'validation', 'testing'}."
+        )
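As a usage sketch (the "./data" root is illustrative and assumed to already
contain the extracted dataset; download is omitted):

    from torchaudio.datasets import SPEECHCOMMANDS

    train_set = SPEECHCOMMANDS("./data", subset="training")
    valid_set = SPEECHCOMMANDS("./data", subset="validation")
    test_set = SPEECHCOMMANDS("./data", subset="testing")
    full_set = SPEECHCOMMANDS("./data", subset=None)  # the whole dataset

    # Each item is (waveform, sample_rate, label, speaker_id, utterance_number).
    waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]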

if url in [
"speech_commands_v0.01",
"speech_commands_v0.02",
@@ -79,9 +100,22 @@ def __init__(self,
download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive, self._path)

-        walker = walk_files(self._path, suffix=".wav", prefix=True)
-        walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
-        self._walker = list(walker)
+        if subset == "validation":
+            self._walker = _load_list(self._path, "validation_list.txt")
+        elif subset == "testing":
+            self._walker = _load_list(self._path, "testing_list.txt")
+        elif subset == "training":
+            excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
+            walker = walk_files(self._path, suffix=".wav", prefix=True)
+            self._walker = [
+                w for w in walker
+                if HASH_DIVIDER in w
+                and EXCEPT_FOLDER not in w
+                and os.path.normpath(w) not in excludes
+            ]
+        else:
+            walker = walk_files(self._path, suffix=".wav", prefix=True)
+            self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]

mthrok (Contributor), on the "training" branch:
    I see a behavior inconsistency between "training"/None and
    "validation"/"testing". If certain valid files are removed from the dataset,
    this implementation will keep working for the "training"/None case with
    fewer files, but for "validation"/"testing" it will raise an exception.
    What's your take on that?

vincentqb (Contributor, Author):
    "testing", "validation", and "training" are defined by the dataset through
    the validation/testing list files; outside of that, behavior is technically
    undefined by the dataset. Given how the three are defined in those files, as
    a user I'd expect changes to those files to propagate. I'd say it'd be fair
    to add a quick note in the docstring.

vincentqb (Contributor, Author):
    Added.

mthrok (Contributor):
    > It is technically undefined by the dataset outside of that.
    I agree with that. However, IIRC, one of the goals of torchaudio's dataset
    implementations is to make it easy to modify the dataset, something along
    the lines of points 3 & 4 of #852 (comment). With this rule, we have to take
    the extra step to think through which kinds of modification are valid or
    invalid and what the expected behavior is, and put that in the
    implementation. I am trying to raise awareness of it.

mthrok (Contributor), Oct 27, 2020:
    And of course, if we are giving up easy modification of the dataset, we do
    not need to think about any modification to the dataset, and we can leave it
    as undefined behavior.

vincentqb (Contributor, Author):
    Actually, SpeechCommands does explain how those two files are generated, and
    how to generalize their approach; see the README.

mthrok (Contributor), on the HASH_DIVIDER / EXCEPT_FOLDER filter:
    `if HASH_DIVIDER in w and EXCEPT_FOLDER not in w` is redundant because files
    in the background-noise folder do not have "_nohash_" in their filenames.

vincentqb (Contributor, Author), Oct 27, 2020:
    I'll avoid changing the logic here so as not to break any code, say if a
    user moved a file into the background_noise folder, etc. This can be looked
    at in a later PR.
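Regarding the README mentioned above: it describes the split as a stable hash
of the speaker portion of the filename (everything before "_nohash_"),
bucketed by percentage, so a given speaker's recordings always land in the
same subset even as files are added or removed. A paraphrased sketch of that
approach (not the verbatim README code; the default percentages are
assumptions):

    import hashlib
    import os
    import re

    MAX_NUM_WAVS_PER_CLASS = 2 ** 27 - 1  # large constant to keep bucketing stable

    def which_set(filename, validation_percentage=10.0, testing_percentage=10.0):
        # Hash the speaker id so that adding or removing files never moves an
        # existing speaker across subsets.
        base_name = os.path.basename(filename)
        hash_name = re.sub(r"_nohash_.*$", "", base_name)
        hash_hex = hashlib.sha1(hash_name.encode("utf-8")).hexdigest()
        percentage_hash = (int(hash_hex, 16) % (MAX_NUM_WAVS_PER_CLASS + 1)) * (
            100.0 / MAX_NUM_WAVS_PER_CLASS
        )
        if percentage_hash < validation_percentage:
            return "validation"
        if percentage_hash < validation_percentage + testing_percentage:
            return "testing"
        return "training"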

def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
"""Load the n-th sample from the dataset.