from typing import List
from functools import partial
from collections import namedtuple

import torch

from . import wsj0mix

# A padded minibatch: ``mix``/``src`` are stacked waveforms and ``mask`` is 1
# on valid frames of ``mix`` and 0 on frames added as padding.
Batch = namedtuple("Batch", ["mix", "src", "mask"])


def get_dataset(dataset_type, root_dir, num_speakers, sample_rate):
    """Instantiate the train / validation / evaluation datasets.

    Args:
        dataset_type (str): Dataset flavor. Only ``"wsj0mix"`` is supported.
        root_dir (Path): Directory that contains the ``tr``/``cv``/``tt`` subsets.
        num_speakers (int): Number of source speakers per mixture.
        sample_rate (int): Expected sample rate of the audio files.

    Returns:
        Tuple of (train, validation, evaluation) datasets.

    Raises:
        ValueError: If ``dataset_type`` is not recognized.
    """
    if dataset_type != "wsj0mix":
        raise ValueError(f"Unexpected dataset: {dataset_type}")
    train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
    validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
    evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
    return train, validation, evaluation


def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, random_start=False):
    """Ensure waveform has exact number of frames by slicing or padding.

    Args:
        sample: ``(sample_rate, mix_waveform, list_of_source_waveforms)``.
        target_num_frames (int): Desired number of frames.
        random_start (bool): When slicing, pick a random start offset instead
            of always taking the head of the waveform.

    Returns:
        Tuple of (mix, src, mask); ``mask`` is 1 on valid frames, 0 on padding.
    """
    mix = sample[1]  # [1, num_frames]
    src = torch.cat(sample[2], 0)  # [num_sources, num_frames]

    num_channels, num_frames = src.shape
    if num_frames >= target_num_frames:
        if random_start and num_frames > target_num_frames:
            start_frame = torch.randint(num_frames - target_num_frames, [1])
            mix = mix[:, start_frame:]
            src = src[:, start_frame:]
        mix = mix[:, :target_num_frames]
        src = src[:, :target_num_frames]
        mask = torch.ones_like(mix)
    else:
        num_padding = target_num_frames - num_frames
        pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
        mix = torch.cat([mix, pad], 1)
        src = torch.cat([src, pad.expand(num_channels, -1)], 1)
        mask = torch.ones_like(mix)
        mask[..., num_frames:] = 0  # zero out the frames added as padding
    return mix, src, mask


def _collate(samples: List[wsj0mix.SampleType], target_num_frames: int, random_start: bool) -> Batch:
    """Fix every sample to ``target_num_frames`` and stack into a Batch."""
    mixes, srcs, masks = [], [], []
    for sample in samples:
        mix, src, mask = _fix_num_frames(sample, target_num_frames, random_start=random_start)
        mixes.append(mix)
        srcs.append(src)
        masks.append(mask)
    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
    """Collate training samples, randomly cropping each to a fixed duration."""
    target_num_frames = int(duration * sample_rate)
    return _collate(samples, target_num_frames, random_start=True)


def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType]):
    """Collate test samples, padding all of them to the longest mixture."""
    max_num_frames = max(s[1].shape[-1] for s in samples)
    return _collate(samples, max_num_frames, random_start=False)


def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
    """Return the collate function for the given dataset type and mode.

    Args:
        dataset_type (str): Only ``"wsj0mix"`` is supported.
        mode (str): Either ``"train"`` or ``"test"``.
        sample_rate (int, optional): Required when ``mode == "train"``.
        duration (int): Crop duration in seconds for training. (default: 4)

    Raises:
        ValueError: If ``mode``, ``dataset_type`` or ``sample_rate`` is invalid.
    """
    # Explicit validation instead of ``assert`` — asserts are stripped under -O.
    if mode not in ["train", "test"]:
        raise ValueError(f"Unexpected mode: {mode}")
    if dataset_type == "wsj0mix":
        if mode == "train":
            if sample_rate is None:
                raise ValueError("sample_rate is not given.")
            return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
        return collate_fn_wsj0mix_test
    raise ValueError(f"Unexpected dataset: {dataset_type}")
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]


class WSJ0Mix(Dataset):
    """Create a Dataset for wsj0-mix.

    Args:
        root (str or Path): Path to the directory where the dataset is found.
        num_speakers (int): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios.
        sample_rate (int): Expected sample rate of audio files. If any of the audio has a
            different sample rate, raises ``ValueError``.
        audio_ext (str): The extension of audio files to find. (default: ".wav")
    """
    def __init__(
        self,
        root: Union[str, Path],
        num_speakers: int,
        sample_rate: int,
        audio_ext: str = ".wav",
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.mix_dir = (self.root / "mix").resolve()
        self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]

        # Sort for a deterministic sample order across file systems.
        self.files = [p.name for p in self.mix_dir.glob(f"*{audio_ext}")]
        self.files.sort()

    def _load_audio(self, path) -> torch.Tensor:
        """Load one audio file, validating its sample rate against the dataset's."""
        waveform, sample_rate = torchaudio.load(path)
        if sample_rate != self.sample_rate:
            # BUG FIX: the second literal was missing the ``f`` prefix, so
            # ``{self.sample_rate}`` was emitted verbatim in the message.
            raise ValueError(
                f"The dataset contains audio file of sample rate {sample_rate}. "
                f"The requested sample rate is {self.sample_rate}."
            )
        return waveform

    def _load_sample(self, filename) -> SampleType:
        """Load the mixture and all source waveforms sharing ``filename``."""
        mixed = self._load_audio(str(self.mix_dir / filename))
        srcs = []
        for i, dir_ in enumerate(self.src_dirs):
            src = self._load_audio(str(dir_ / filename))
            if mixed.shape != src.shape:
                raise ValueError(
                    f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
                )
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, key: int) -> SampleType:
        """Load the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
        """
        return self._load_sample(self.files[key])
import os

from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_whitenoise,
    save_wav,
    normalize_wav,
)

from source_separation.utils.dataset import wsj0mix


_FILENAMES = [
    "012c0207_1.9952_01cc0202_-1.9952.wav",
    "01co0302_1.63_014c020q_-1.63.wav",
    "01do0316_0.24011_205a0104_-0.24011.wav",
    "01lc020x_1.1301_027o030r_-1.1301.wav",
    "01mc0202_0.34056_205o0106_-0.34056.wav",
    "01nc020t_0.53821_018o030w_-0.53821.wav",
    "01po030f_2.2136_40ko031a_-2.2136.wav",
    "01ra010o_2.4098_403a010f_-2.4098.wav",
    "01xo030b_0.22377_016o031a_-0.22377.wav",
    "02ac020x_0.68566_01ec020b_-0.68566.wav",
    "20co010m_0.82801_019c0212_-0.82801.wav",
    "20da010u_1.2483_017c0211_-1.2483.wav",
    "20oo010d_1.0631_01ic020s_-1.0631.wav",
    "20sc0107_2.0222_20fo010h_-2.0222.wav",
    "20tc010f_0.051456_404a0110_-0.051456.wav",
    "407c0214_1.1712_02ca0113_-1.1712.wav",
    "40ao030w_2.4697_20vc010a_-2.4697.wav",
    "40pa0101_1.1087_40ea0107_-1.1087.wav",
]


def _mock_dataset(root_dir, num_speaker):
    """Populate ``root_dir`` with a fake wsj0-mix layout of whitenoise files.

    Returns the list of expected ``(sample_rate, mix, [srcs])`` tuples in
    filename order, with waveforms normalized the way the dataset loads them.
    """
    dirnames = ["mix"] + [f"s{i+1}" for i in range(num_speaker)]
    for dirname in dirnames:
        os.makedirs(os.path.join(root_dir, dirname), exist_ok=True)

    seed = 0
    sample_rate = 8000
    expected = []
    for filename in _FILENAMES:
        mix = None
        src = []
        for dirname in dirnames:
            # Use the ``sample_rate`` variable so the generated audio always
            # matches the rate recorded in ``expected`` (was hard-coded 8000).
            waveform = get_whitenoise(
                sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=seed
            )
            seed += 1

            path = os.path.join(root_dir, dirname, filename)
            save_wav(path, waveform, sample_rate)
            waveform = normalize_wav(waveform)

            if dirname == "mix":
                mix = waveform
            else:
                src.append(waveform)
        expected.append((sample_rate, mix, src))
    return expected


def _assert_dataset(test_case, dataset, expected):
    """Shared check: every sample of ``dataset`` matches ``expected`` in order."""
    num_samples = 0
    for i, sample in enumerate(dataset):
        (_, sample_mix, sample_src) = sample
        (_, expected_mix, expected_src) = expected[i]
        test_case.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
        # Compare every source, however many speakers the dataset was built with.
        for j in range(len(expected_src)):
            test_case.assertEqual(sample_src[j], expected_src[j], atol=5e-5, rtol=1e-8)
        num_samples += 1
    assert num_samples == len(expected)


class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase):
    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 2)

    def test_wsj0mix(self):
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000)
        _assert_dataset(self, dataset, self.expected)


class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase):
    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 3)

    def test_wsj0mix(self):
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000)
        _assert_dataset(self, dataset, self.expected)