def get_mock_dataset(root_dir):
    """Write a mocked GTZAN dataset under ``root_dir`` and return the expected samples.

    For each genre in ``gtzan.gtzan_genres`` the directory
    ``<root_dir>/genres/<genre>`` is created and 100 single-channel int16 clips
    (0.01 s at 22050 Hz, produced by ``get_whitenoise``) are saved as
    ``<genre>.<index:05d>.wav``.

    Args:
        root_dir: directory to the mocked dataset

    Returns:
        A 4-tuple of lists ``(samples, training, validation, testing)``. Every
        entry is ``(normalize_wav(data), sample_rate, genre)``. ``samples``
        holds every generated clip; the other three hold the clips whose file
        stem is listed in ``gtzan.filtered_train``, ``gtzan.filtered_valid``
        and ``gtzan.filtered_test`` respectively.
    """
    sample_rate = 22050
    clips_per_genre = 100

    all_samples, training, validation, testing = [], [], [], []

    for genre_index, genre in enumerate(gtzan.gtzan_genres):
        genre_dir = os.path.join(root_dir, 'genres', genre)
        os.makedirs(genre_dir, exist_ok=True)

        for clip_index in range(clips_per_genre):
            stem = f'{genre}.{clip_index:05d}'
            # One distinct seed per file: 0, 1, 2, ... across all genres,
            # so every generated clip has its own seed.
            seed = genre_index * clips_per_genre + clip_index
            data = get_whitenoise(
                sample_rate=sample_rate,
                duration=0.01,
                n_channels=1,
                dtype='int16',
                seed=seed,
            )
            save_wav(os.path.join(genre_dir, f'{stem}.wav'), data, sample_rate)

            sample = (normalize_wav(data), sample_rate, genre)
            all_samples.append(sample)

            # The three membership checks are independent of one another, so a
            # stem is added to every subset whose list contains it.
            for members, bucket in (
                (gtzan.filtered_test, testing),
                (gtzan.filtered_train, training),
                (gtzan.filtered_valid, validation),
            ):
                if stem in members:
                    bucket.append(sample)

    return (all_samples, training, validation, testing)
gtzan.filtered_test: - cls.testing.append(sample) - if filename in gtzan.filtered_train: - cls.training.append(sample) - if filename in gtzan.filtered_valid: - cls.validation.append(sample) - seed += 1 + mocked_data = get_mock_dataset(cls.root_dir) + cls.samples = mocked_data[0] + cls.training = mocked_data[1] + cls.validation = mocked_data[2] + cls.testing = mocked_data[3] def test_no_subset(self): dataset = gtzan.GTZAN(self.root_dir)