
Commit 89c1c67

Add wsj0mix dataset
1 parent 725f8b0 commit 89c1c67

5 files changed: 266 additions, 0 deletions

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from . import (
+    dataset,
     metrics,
 )
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from . import utils, wsj0mix
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
from typing import List
from functools import partial
from collections import namedtuple

import torch

from . import wsj0mix

Batch = namedtuple("Batch", ["mix", "src", "mask"])


def get_dataset(dataset_type, root_dir, num_speakers, sample_rate):
    if dataset_type == "wsj0mix":
        train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
        validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
        evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
    else:
        raise ValueError(f"Unexpected dataset: {dataset_type}")
    return train, validation, evaluation


def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, random_start=False):
    """Ensure waveform has exact number of frames by slicing or padding"""
    mix = sample[1]  # [1, num_frames]
    src = torch.cat(sample[2], 0)  # [num_sources, num_frames]

    num_channels, num_frames = src.shape
    if num_frames >= target_num_frames:
        if random_start and num_frames > target_num_frames:
            start_frame = torch.randint(num_frames - target_num_frames, [1])
            mix = mix[:, start_frame:]
            src = src[:, start_frame:]
        mix = mix[:, :target_num_frames]
        src = src[:, :target_num_frames]
        mask = torch.ones_like(mix)
    else:
        num_padding = target_num_frames - num_frames
        pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
        mix = torch.cat([mix, pad], 1)
        src = torch.cat([src, pad.expand(num_channels, -1)], 1)
        mask = torch.ones_like(mix)
        mask[..., num_frames:] = 0
    return mix, src, mask


def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
    target_num_frames = int(duration * sample_rate)

    mixes, srcs, masks = [], [], []
    for sample in samples:
        mix, src, mask = _fix_num_frames(sample, target_num_frames, random_start=True)

        mixes.append(mix)
        srcs.append(src)
        masks.append(mask)

    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType]):
    max_num_frames = max(s[1].shape[-1] for s in samples)

    mixes, srcs, masks = [], [], []
    for sample in samples:
        mix, src, mask = _fix_num_frames(sample, max_num_frames, random_start=False)

        mixes.append(mix)
        srcs.append(src)
        masks.append(mask)

    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
    assert mode in ["train", "test"]
    if dataset_type == "wsj0mix":
        if mode == "train":
            if sample_rate is None:
                raise ValueError("sample_rate is not given.")
            return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
        return collate_fn_wsj0mix_test
    raise ValueError(f"Unexpected dataset: {dataset_type}")
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
from pathlib import Path
from typing import Union, Tuple, List

import torch
from torch.utils.data import Dataset

import torchaudio

SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]


class WSJ0Mix(Dataset):
    """Create a Dataset for wsj0-mix.

    Args:
        root (str or Path): Path to the directory where the dataset is found.
        num_speakers (int): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios.
        sample_rate (int): Expected sample rate of audio files. If any of the audio files
            has a different sample rate, ``ValueError`` is raised.
        audio_ext (str): The extension of audio files to find. (default: ".wav")
    """
    def __init__(
        self,
        root: Union[str, Path],
        num_speakers: int,
        sample_rate: int,
        audio_ext: str = ".wav",
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.mix_dir = (self.root / "mix").resolve()
        self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]

        self.files = [p.name for p in self.mix_dir.glob(f"*{audio_ext}")]
        self.files.sort()

    def _load_audio(self, path) -> torch.Tensor:
        waveform, sample_rate = torchaudio.load(path)
        if sample_rate != self.sample_rate:
            raise ValueError(
                f"The dataset contains an audio file with sample rate {sample_rate}, "
                f"but the requested sample rate is {self.sample_rate}."
            )
        return waveform

    def _load_sample(self, filename) -> SampleType:
        mixed = self._load_audio(str(self.mix_dir / filename))
        srcs = []
        for i, dir_ in enumerate(self.src_dirs):
            src = self._load_audio(str(dir_ / filename))
            if mixed.shape != src.shape:
                raise ValueError(
                    f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
                )
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, key: int) -> SampleType:
        """Load the ``key``-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded.

        Returns:
            tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
        """
        return self._load_sample(self.files[key])
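As a quick illustration (not part of the commit), a minimal sketch of using the WSJ0Mix class above directly. The dataset path is a hypothetical placeholder; the expected directory layout (mix, s1, s2, ... subfolders with identically named .wav files) is the one the class itself assumes.

from pathlib import Path

# WSJ0Mix is the class defined in the file above.
# Hypothetical path to one split of a 2-speaker, 8 kHz wsj0-mix dataset.
root = Path("/data/wsj0-mix/2speakers/wav8k/min/tr")

dataset = WSJ0Mix(root, num_speakers=2, sample_rate=8000)

sample_rate, mix, sources = dataset[0]
print(sample_rate)   # 8000
print(mix.shape)     # [1, num_frames] mixture waveform
print(len(sources))  # 2 source waveforms, each with the same shape as the mixture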
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
import os

from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_whitenoise,
    save_wav,
    normalize_wav,
)

from source_separation.utils.dataset import wsj0mix


_FILENAMES = [
    "012c0207_1.9952_01cc0202_-1.9952.wav",
    "01co0302_1.63_014c020q_-1.63.wav",
    "01do0316_0.24011_205a0104_-0.24011.wav",
    "01lc020x_1.1301_027o030r_-1.1301.wav",
    "01mc0202_0.34056_205o0106_-0.34056.wav",
    "01nc020t_0.53821_018o030w_-0.53821.wav",
    "01po030f_2.2136_40ko031a_-2.2136.wav",
    "01ra010o_2.4098_403a010f_-2.4098.wav",
    "01xo030b_0.22377_016o031a_-0.22377.wav",
    "02ac020x_0.68566_01ec020b_-0.68566.wav",
    "20co010m_0.82801_019c0212_-0.82801.wav",
    "20da010u_1.2483_017c0211_-1.2483.wav",
    "20oo010d_1.0631_01ic020s_-1.0631.wav",
    "20sc0107_2.0222_20fo010h_-2.0222.wav",
    "20tc010f_0.051456_404a0110_-0.051456.wav",
    "407c0214_1.1712_02ca0113_-1.1712.wav",
    "40ao030w_2.4697_20vc010a_-2.4697.wav",
    "40pa0101_1.1087_40ea0107_-1.1087.wav",
]


def _mock_dataset(root_dir, num_speaker):
    dirnames = ["mix"] + [f"s{i+1}" for i in range(num_speaker)]
    for dirname in dirnames:
        os.makedirs(os.path.join(root_dir, dirname), exist_ok=True)

    seed = 0
    sample_rate = 8000
    expected = []
    for filename in _FILENAMES:
        mix = None
        src = []
        for dirname in dirnames:
            waveform = get_whitenoise(
                sample_rate=8000, duration=1, n_channels=1, dtype="int16", seed=seed
            )
            seed += 1

            path = os.path.join(root_dir, dirname, filename)
            save_wav(path, waveform, sample_rate)
            waveform = normalize_wav(waveform)

            if dirname == "mix":
                mix = waveform
            else:
                src.append(waveform)
        expected.append((sample_rate, mix, src))
    return expected


class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase):
    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 2)

    def test_wsj0mix(self):
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000)

        n_ite = 0
        for i, sample in enumerate(dataset):
            (_, sample_mix, sample_src) = sample
            (_, expected_mix, expected_src) = self.expected[i]
            self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
            self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8)
            self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8)
            n_ite += 1
        assert n_ite == len(self.expected)


class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase):
    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 3)

    def test_wsj0mix(self):
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000)

        n_ite = 0
        for i, sample in enumerate(dataset):
            (_, sample_mix, sample_src) = sample
            (_, expected_mix, expected_src) = self.expected[i]
            self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
            self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8)
            self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8)
            self.assertEqual(sample_src[2], expected_src[2], atol=5e-5, rtol=1e-8)
            n_ite += 1
        assert n_ite == len(self.expected)
