|
| 1 | +import os |
| 2 | + |
| 3 | +from torchaudio_unittest.common_utils import ( |
| 4 | + TempDirMixin, |
| 5 | + TorchaudioTestCase, |
| 6 | + get_whitenoise, |
| 7 | + save_wav, |
| 8 | + normalize_wav, |
| 9 | +) |
| 10 | + |
| 11 | +from utils.dataset import wsj0mix |
| 12 | + |
| 13 | + |
| 14 | +_FILENAMES = [ |
| 15 | + "012c0207_1.9952_01cc0202_-1.9952.wav", |
| 16 | + "01co0302_1.63_014c020q_-1.63.wav", |
| 17 | + "01do0316_0.24011_205a0104_-0.24011.wav", |
| 18 | + "01lc020x_1.1301_027o030r_-1.1301.wav", |
| 19 | + "01mc0202_0.34056_205o0106_-0.34056.wav", |
| 20 | + "01nc020t_0.53821_018o030w_-0.53821.wav", |
| 21 | + "01po030f_2.2136_40ko031a_-2.2136.wav", |
| 22 | + "01ra010o_2.4098_403a010f_-2.4098.wav", |
| 23 | + "01xo030b_0.22377_016o031a_-0.22377.wav", |
| 24 | + "02ac020x_0.68566_01ec020b_-0.68566.wav", |
| 25 | + "20co010m_0.82801_019c0212_-0.82801.wav", |
| 26 | + "20da010u_1.2483_017c0211_-1.2483.wav", |
| 27 | + "20oo010d_1.0631_01ic020s_-1.0631.wav", |
| 28 | + "20sc0107_2.0222_20fo010h_-2.0222.wav", |
| 29 | + "20tc010f_0.051456_404a0110_-0.051456.wav", |
| 30 | + "407c0214_1.1712_02ca0113_-1.1712.wav", |
| 31 | + "40ao030w_2.4697_20vc010a_-2.4697.wav", |
| 32 | + "40pa0101_1.1087_40ea0107_-1.1087.wav", |
| 33 | +] |
| 34 | + |
| 35 | + |
| 36 | +def _mock_dataset(root_dir, num_speaker): |
| 37 | + dirnames = ["mix"] + [f"s{i+1}" for i in range(num_speaker)] |
| 38 | + for dirname in dirnames: |
| 39 | + os.makedirs(os.path.join(root_dir, dirname), exist_ok=True) |
| 40 | + |
| 41 | + seed = 0 |
| 42 | + sample_rate = 8000 |
| 43 | + expected = [] |
| 44 | + for filename in _FILENAMES: |
| 45 | + mix = None |
| 46 | + src = [] |
| 47 | + for dirname in dirnames: |
| 48 | + waveform = get_whitenoise( |
| 49 | + sample_rate=8000, duration=1, n_channels=1, dtype="int16", seed=seed |
| 50 | + ) |
| 51 | + seed += 1 |
| 52 | + |
| 53 | + path = os.path.join(root_dir, dirname, filename) |
| 54 | + save_wav(path, waveform, sample_rate) |
| 55 | + waveform = normalize_wav(waveform) |
| 56 | + |
| 57 | + if dirname == "mix": |
| 58 | + mix = waveform |
| 59 | + else: |
| 60 | + src.append(waveform) |
| 61 | + expected.append(wsj0mix.Sample(sample_rate, mix, src)) |
| 62 | + return expected |
| 63 | + |
| 64 | + |
| 65 | +class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase): |
| 66 | + backend = "default" |
| 67 | + root_dir = None |
| 68 | + expected = None |
| 69 | + |
| 70 | + @classmethod |
| 71 | + def setUpClass(cls): |
| 72 | + cls.root_dir = cls.get_base_temp_dir() |
| 73 | + cls.expected = _mock_dataset(cls.root_dir, 2) |
| 74 | + |
| 75 | + def test_wsj0mix(self): |
| 76 | + dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000) |
| 77 | + |
| 78 | + n_ite = 0 |
| 79 | + for i, sample in enumerate(dataset): |
| 80 | + expected = self.expected[i] |
| 81 | + self.assertEqual(sample.mix, expected.mix, atol=5e-5, rtol=1e-8) |
| 82 | + self.assertEqual(sample.src[0], expected.src[0], atol=5e-5, rtol=1e-8) |
| 83 | + self.assertEqual(sample.src[1], expected.src[1], atol=5e-5, rtol=1e-8) |
| 84 | + n_ite += 1 |
| 85 | + assert n_ite == len(self.expected) |
| 86 | + |
| 87 | + |
| 88 | +class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase): |
| 89 | + backend = "default" |
| 90 | + root_dir = None |
| 91 | + expected = None |
| 92 | + |
| 93 | + @classmethod |
| 94 | + def setUpClass(cls): |
| 95 | + cls.root_dir = cls.get_base_temp_dir() |
| 96 | + cls.expected = _mock_dataset(cls.root_dir, 3) |
| 97 | + |
| 98 | + def test_wsj0mix(self): |
| 99 | + dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000) |
| 100 | + |
| 101 | + n_ite = 0 |
| 102 | + for i, sample in enumerate(dataset): |
| 103 | + expected = self.expected[i] |
| 104 | + self.assertEqual(sample.mix, expected.mix, atol=5e-5, rtol=1e-8) |
| 105 | + self.assertEqual(sample.src[0], expected.src[0], atol=5e-5, rtol=1e-8) |
| 106 | + self.assertEqual(sample.src[1], expected.src[1], atol=5e-5, rtol=1e-8) |
| 107 | + self.assertEqual(sample.src[2], expected.src[2], atol=5e-5, rtol=1e-8) |
| 108 | + n_ite += 1 |
| 109 | + assert n_ite == len(self.expected) |
0 commit comments