diff --git a/test/assets/genres/blues/blues.00000.wav b/test/assets/genres/blues/blues.00000.wav
deleted file mode 100644
index cf1cbcde22..0000000000
Binary files a/test/assets/genres/blues/blues.00000.wav and /dev/null differ
diff --git a/test/assets/genres/noise/noise.0000.wav b/test/assets/genres/noise/noise.0000.wav
deleted file mode 100644
index cf1cbcde22..0000000000
Binary files a/test/assets/genres/noise/noise.0000.wav and /dev/null differ
diff --git a/test/datasets/datasets_test.py b/test/datasets/datasets_test.py
index 5026b270a4..b8a980f346 100644
--- a/test/datasets/datasets_test.py
+++ b/test/datasets/datasets_test.py
@@ -4,7 +4,6 @@
 from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
 from torchaudio.datasets.vctk import VCTK
 from torchaudio.datasets.ljspeech import LJSPEECH
-from torchaudio.datasets.gtzan import GTZAN
 from torchaudio.datasets.cmuarctic import CMUARCTIC
 
 from ..common_utils import (
@@ -33,10 +32,6 @@ def test_speechcommands(self):
         data = SPEECHCOMMANDS(self.path)
         data[0]
 
-    def test_gtzan(self):
-        data = GTZAN(self.path)
-        data[0]
-
     def test_cmuarctic(self):
         data = CMUARCTIC(self.path)
         data[0]
diff --git a/test/datasets/gtzan_test.py b/test/datasets/gtzan_test.py
new file mode 100644
--- /dev/null
+++ b/test/datasets/gtzan_test.py
@@ -0,0 +1,76 @@
+import os
+
+from torchaudio.datasets import gtzan
+
+from ..common_utils import (
+    TempDirMixin,
+    TorchaudioTestCase,
+    get_whitenoise,
+    save_wav,
+    normalize_wav,
+)
+
+
+class TestGTZAN(TempDirMixin, TorchaudioTestCase):
+    """Exercise GTZAN on a temporary directory tree of generated white-noise WAV files."""
+
+    backend = 'default'
+
+    root_dir = None
+    samples = []
+    training = []
+    validation = []
+    testing = []
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.root_dir = cls.get_base_temp_dir()
+        # Rebind fresh lists rather than appending to the class-level defaults,
+        # so that running setUpClass more than once cannot accumulate duplicates.
+        cls.samples = []
+        cls.training = []
+        cls.validation = []
+        cls.testing = []
+        sample_rate = 22050
+        seed = 0
+        for genre in gtzan.gtzan_genres:
+            base_dir = os.path.join(cls.root_dir, 'genres', genre)
+            os.makedirs(base_dir, exist_ok=True)
+            for i in range(100):
+                filename = f'{genre}.{i:05d}'
+                path = os.path.join(base_dir, f'{filename}.wav')
+                data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype='int16', seed=seed)
+                save_wav(path, data, sample_rate)
+                sample = (normalize_wav(data), sample_rate, genre)
+                cls.samples.append(sample)
+                if filename in gtzan.filtered_test:
+                    cls.testing.append(sample)
+                if filename in gtzan.filtered_train:
+                    cls.training.append(sample)
+                if filename in gtzan.filtered_valid:
+                    cls.validation.append(sample)
+                seed += 1
+
+    def _assert_dataset(self, dataset, expected):
+        """Assert that iterating ``dataset`` yields exactly the ``expected`` samples, in order."""
+        n_ite = 0
+        for waveform, sample_rate, label in dataset:
+            expected_waveform, expected_sample_rate, expected_label = expected[n_ite]
+            self.assertEqual(waveform, expected_waveform, atol=5e-5, rtol=1e-8)
+            assert sample_rate == expected_sample_rate
+            assert label == expected_label
+            n_ite += 1
+        assert n_ite == len(expected)
+
+    def test_no_subset(self):
+        self._assert_dataset(gtzan.GTZAN(self.root_dir), self.samples)
+
+    def test_training(self):
+        self._assert_dataset(gtzan.GTZAN(self.root_dir, subset='training'), self.training)
+
+    def test_validation(self):
+        self._assert_dataset(gtzan.GTZAN(self.root_dir, subset='validation'), self.validation)
+
+    def test_testing(self):
+        self._assert_dataset(gtzan.GTZAN(self.root_dir, subset='testing'), self.testing)
diff --git a/torchaudio/datasets/gtzan.py b/torchaudio/datasets/gtzan.py
index 9098cf1fe0..e031801bc1 100644
--- a/torchaudio/datasets/gtzan.py
+++ b/torchaudio/datasets/gtzan.py
@@ -1064,6 +1064,7 @@ def __init__(
                     continue
 
                 songs_in_genre = os.listdir(fulldir)
+                songs_in_genre.sort()
                 for fname in songs_in_genre:
                     name, ext = os.path.splitext(fname)
                     if ext.lower() == ".wav" and "." in name: