From 76d6789706f721c41e28e480fe9a37edcf4a55c2 Mon Sep 17 00:00:00 2001
From: krishnakalyan3
Date: Fri, 1 Jan 2021 18:35:04 +0530
Subject: [PATCH 1/2] refactor TestGTZAN unittest

---
Reviewer notes (commentary only; not part of the commit message or the diff):

  * Moves the mock GTZAN fixture generation out of TestGTZAN.setUpClass into
    a new module-level helper, get_mock_dataset(root_dir).
  * For every genre in gtzan.gtzan_genres the helper creates
    <root_dir>/genres/<genre>/ and writes 100 files named
    <genre>.<5-digit index>.wav via get_whitenoise() + save_wav(); the seed
    passed to get_whitenoise() is incremented after every file, so no two
    files share a seed value.
  * The helper returns the 4-tuple
    (mocked_samples, mocked_training, mocked_validation, mocked_testing);
    setUpClass assigns these to cls.samples, cls.training, cls.validation
    and cls.testing by index 0..3 (note the tuple order is
    samples/training/validation/testing).
  * Behavioral difference to be aware of: the removed code appended to the
    existing cls.samples / cls.testing / cls.training / cls.validation lists,
    whereas the new code rebinds those attributes to the lists returned by
    the helper.
  * The new def is preceded by only one blank line in this patch; that is
    what PATCH 2/2 corrects.

 .../datasets/gtzan_test.py                    | 54 ++++++++++++-------
 1 file changed, 35 insertions(+), 19 deletions(-)

diff --git a/test/torchaudio_unittest/datasets/gtzan_test.py b/test/torchaudio_unittest/datasets/gtzan_test.py
index 763e0a866f..c0dee3a21c 100644
--- a/test/torchaudio_unittest/datasets/gtzan_test.py
+++ b/test/torchaudio_unittest/datasets/gtzan_test.py
@@ -11,6 +11,36 @@
     normalize_wav,
 )
 
+def get_mock_dataset(root_dir):
+    """
+    root_dir: directory to the mocked dataset
+    """
+    mocked_samples = []
+    mocked_training = []
+    mocked_validation = []
+    mocked_testing = []
+    sample_rate = 22050
+
+    seed = 0
+    for genre in gtzan.gtzan_genres:
+        base_dir = os.path.join(root_dir, 'genres', genre)
+        os.makedirs(base_dir, exist_ok=True)
+        for i in range(100):
+            filename = f'{genre}.{i:05d}'
+            path = os.path.join(base_dir, f'{filename}.wav')
+            data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype='int16', seed=seed)
+            save_wav(path, data, sample_rate)
+            sample = (normalize_wav(data), sample_rate, genre)
+            mocked_samples.append(sample)
+            if filename in gtzan.filtered_test:
+                mocked_testing.append(sample)
+            if filename in gtzan.filtered_train:
+                mocked_training.append(sample)
+            if filename in gtzan.filtered_valid:
+                mocked_validation.append(sample)
+            seed += 1
+    return (mocked_samples, mocked_training, mocked_validation, mocked_testing)
+
 
 class TestGTZAN(TempDirMixin, TorchaudioTestCase):
     backend = 'default'
@@ -24,25 +54,11 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
     @classmethod
     def setUpClass(cls):
         cls.root_dir = cls.get_base_temp_dir()
-        sample_rate = 22050
-        seed = 0
-        for genre in gtzan.gtzan_genres:
-            base_dir = os.path.join(cls.root_dir, 'genres', genre)
-            os.makedirs(base_dir, exist_ok=True)
-            for i in range(100):
-                filename = f'{genre}.{i:05d}'
-                path = os.path.join(base_dir, f'{filename}.wav')
-                data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype='int16', seed=seed)
-                save_wav(path, data, sample_rate)
-                sample = (normalize_wav(data), sample_rate, genre)
-                cls.samples.append(sample)
-                if filename in gtzan.filtered_test:
-                    cls.testing.append(sample)
-                if filename in gtzan.filtered_train:
-                    cls.training.append(sample)
-                if filename in gtzan.filtered_valid:
-                    cls.validation.append(sample)
-                seed += 1
+        mocked_data = get_mock_dataset(cls.root_dir)
+        cls.samples = mocked_data[0]
+        cls.training = mocked_data[1]
+        cls.validation = mocked_data[2]
+        cls.testing = mocked_data[3]
 
     def test_no_subset(self):
         dataset = gtzan.GTZAN(self.root_dir)

From a5b6861de45dde2302d1cc8143f673ba3ebdefa5 Mon Sep 17 00:00:00 2001
From: krishnakalyan3
Date: Sat, 2 Jan 2021 00:54:00 +0530
Subject: [PATCH 2/2] fix style

---
Reviewer notes (commentary only; not part of the commit message or the diff):

  * Adds one blank line before get_mock_dataset so that the top-level def is
    separated from the closing parenthesis of the import block by two blank
    lines. Presumably this is to satisfy the project's lint check for blank
    lines before top-level definitions (pycodestyle E302) -- confirm against
    the CI configuration.
  * No functional change.

 test/torchaudio_unittest/datasets/gtzan_test.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/torchaudio_unittest/datasets/gtzan_test.py b/test/torchaudio_unittest/datasets/gtzan_test.py
index c0dee3a21c..838292f55d 100644
--- a/test/torchaudio_unittest/datasets/gtzan_test.py
+++ b/test/torchaudio_unittest/datasets/gtzan_test.py
@@ -11,6 +11,7 @@
     normalize_wav,
 )
 
+
 def get_mock_dataset(root_dir):
     """
     root_dir: directory to the mocked dataset