Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed test/assets/waves_yesno/0_1_0_1_0_1_1_0.wav
Binary file not shown.
4 changes: 2 additions & 2 deletions test/common_utils/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def supports_mp3(backend):
def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
Comment on lines -32 to -33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this is removing "sox" from "set_audio_backend"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, it's changing the default backend to be sox_io when it's available.

if 'sox_io' in BACKENDS:
be = 'sox_io'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
Expand Down
12 changes: 6 additions & 6 deletions test/common_utils/test_case_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None

@property
def base_temp_dir(self):
@classmethod
def get_base_temp_dir(cls):
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
return os.environ[key]
if self.__class__.temp_dir_ is None:
self.__class__.temp_dir_ = tempfile.TemporaryDirectory()
return self.__class__.temp_dir_.name
if cls.temp_dir_ is None:
cls.temp_dir_ = tempfile.TemporaryDirectory()
return cls.temp_dir_.name

@classmethod
def tearDownClass(cls):
Expand All @@ -34,7 +34,7 @@ def tearDownClass(cls):
cls.temp_dir_ = None

def get_temp_path(self, *paths):
temp_dir = os.path.join(self.base_temp_dir, self.id())
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
path = os.path.join(temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
Expand Down
59 changes: 50 additions & 9 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest

from torchaudio.datasets.commonvoice import COMMONVOICE
Expand All @@ -10,16 +11,19 @@
from torchaudio.datasets.gtzan import GTZAN
from torchaudio.datasets.cmuarctic import CMUARCTIC

from . import common_utils
from .common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_asset_path,
get_whitenoise,
save_wav,
normalize_wav,
)


class TestDatasets(common_utils.TorchaudioTestCase):
class TestDatasets(TorchaudioTestCase):
backend = 'default'
path = common_utils.get_asset_path()

def test_yesno(self):
data = YESNO(self.path)
data[0]
path = get_asset_path()

def test_vctk(self):
data = VCTK(self.path)
Expand All @@ -46,9 +50,9 @@ def test_cmuarctic(self):
data[0]


class TestCommonVoice(common_utils.TorchaudioTestCase):
class TestCommonVoice(TorchaudioTestCase):
backend = 'default'
path = common_utils.get_asset_path()
path = get_asset_path()

def test_commonvoice(self):
data = COMMONVOICE(self.path, url="tatar")
Expand All @@ -69,5 +73,42 @@ def test_commonvoice_bg(self):
pass


class TestYesNo(TempDirMixin, TorchaudioTestCase):
backend = 'default'

root_dir = None
data = []
labels = [
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
]

@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
base_dir = os.path.join(cls.root_dir, 'waves_yesno')
os.makedirs(base_dir, exist_ok=True)
for label in cls.labels:
filename = f'{"_".join(str(l) for l in label)}.wav'
path = os.path.join(base_dir, filename)
data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16')
save_wav(path, data, 8000)
cls.data.append(normalize_wav(data))

def test_yesno(self):
dataset = YESNO(self.root_dir)
samples = list(dataset)
samples.sort(key=lambda s: s[2])
for i, (waveform, sample_rate, label) in enumerate(samples):
expected_label = self.labels[i]
expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == 8000
assert label == expected_label


if __name__ == "__main__":
unittest.main()