From 227ff6db2cfae278e262fa83956d99d6dddf933b Mon Sep 17 00:00:00 2001 From: AzizCode92 Date: Mon, 4 Jan 2021 08:16:08 +0100 Subject: [PATCH] refactor gtzan unittest --- .../datasets/gtzan_test.py | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/test/torchaudio_unittest/datasets/gtzan_test.py b/test/torchaudio_unittest/datasets/gtzan_test.py index 763e0a866f..74ed7eb8ce 100644 --- a/test/torchaudio_unittest/datasets/gtzan_test.py +++ b/test/torchaudio_unittest/datasets/gtzan_test.py @@ -12,6 +12,33 @@ ) +def get_mock_dataset(root_dir): + """ + prepares mocked data + """ + mocked_samples, mocked_training, mocked_testing, mocked_validation = [], [], [], [] + sample_rate = 22050 + seed = 0 + for genre in gtzan.gtzan_genres: + base_dir = os.path.join(root_dir, 'genres', genre) + os.makedirs(base_dir, exist_ok=True) + for i in range(100): + filename = f'{genre}.{i:05d}' + path = os.path.join(base_dir, f'{filename}.wav') + data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype='int16', seed=seed) + save_wav(path, data, sample_rate) + sample = (normalize_wav(data), sample_rate, genre) + mocked_samples.append(sample) + if filename in gtzan.filtered_test: + mocked_testing.append(sample) + if filename in gtzan.filtered_train: + mocked_training.append(sample) + if filename in gtzan.filtered_valid: + mocked_validation.append(sample) + seed += 1 + return mocked_samples, mocked_training, mocked_testing, mocked_validation + + class TestGTZAN(TempDirMixin, TorchaudioTestCase): backend = 'default' @@ -24,25 +51,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase): @classmethod def setUpClass(cls): cls.root_dir = cls.get_base_temp_dir() - sample_rate = 22050 - seed = 0 - for genre in gtzan.gtzan_genres: - base_dir = os.path.join(cls.root_dir, 'genres', genre) - os.makedirs(base_dir, exist_ok=True) - for i in range(100): - filename = f'{genre}.{i:05d}' - path = os.path.join(base_dir, f'{filename}.wav') - data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype='int16', seed=seed) - save_wav(path, data, sample_rate) - sample = (normalize_wav(data), sample_rate, genre) - cls.samples.append(sample) - if filename in gtzan.filtered_test: - cls.testing.append(sample) - if filename in gtzan.filtered_train: - cls.training.append(sample) - if filename in gtzan.filtered_valid: - cls.validation.append(sample) - seed += 1 + cls.samples, cls.training, cls.testing, cls.validation = get_mock_dataset(cls.root_dir) def test_no_subset(self): dataset = gtzan.GTZAN(self.root_dir)