Skip to content

Commit a7f0e2a

Browse files
committed
Add wsj0mix dataset
1 parent 52a18a9 commit a7f0e2a

File tree

5 files changed

+219
-0
lines changed

5 files changed

+219
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import os
import sys

# Absolute path of the directory that contains this file.
_THIS_DIR = os.path.abspath(os.path.dirname(__file__))

# Append the ``test`` directory found three levels above this file to the
# module search path (presumably so test helpers can be imported -- confirm).
sys.path.append(
    os.path.join(_THIS_DIR, os.pardir, os.pardir, os.pardir, "test")
)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import os
2+
3+
from torchaudio_unittest.common_utils import (
4+
TempDirMixin,
5+
TorchaudioTestCase,
6+
get_whitenoise,
7+
save_wav,
8+
normalize_wav,
9+
)
10+
11+
from utils.dataset import wsj0mix
12+
13+
14+
_FILENAMES = [
15+
"012c0207_1.9952_01cc0202_-1.9952.wav",
16+
"01co0302_1.63_014c020q_-1.63.wav",
17+
"01do0316_0.24011_205a0104_-0.24011.wav",
18+
"01lc020x_1.1301_027o030r_-1.1301.wav",
19+
"01mc0202_0.34056_205o0106_-0.34056.wav",
20+
"01nc020t_0.53821_018o030w_-0.53821.wav",
21+
"01po030f_2.2136_40ko031a_-2.2136.wav",
22+
"01ra010o_2.4098_403a010f_-2.4098.wav",
23+
"01xo030b_0.22377_016o031a_-0.22377.wav",
24+
"02ac020x_0.68566_01ec020b_-0.68566.wav",
25+
"20co010m_0.82801_019c0212_-0.82801.wav",
26+
"20da010u_1.2483_017c0211_-1.2483.wav",
27+
"20oo010d_1.0631_01ic020s_-1.0631.wav",
28+
"20sc0107_2.0222_20fo010h_-2.0222.wav",
29+
"20tc010f_0.051456_404a0110_-0.051456.wav",
30+
"407c0214_1.1712_02ca0113_-1.1712.wav",
31+
"40ao030w_2.4697_20vc010a_-2.4697.wav",
32+
"40pa0101_1.1087_40ea0107_-1.1087.wav",
33+
]
34+
35+
36+
def _mock_dataset(root_dir, num_speaker):
    """Write a synthetic dataset under ``root_dir`` and return the expected samples.

    For every name in ``_FILENAMES`` one white-noise WAV file is saved into
    ``mix`` and into each of ``s1`` ... ``s<num_speaker>``. Each file is
    generated with its own seed, incremented in write order.

    Args:
        root_dir: Directory in which the sub-directories are created.
        num_speaker: Number of source directories to create.

    Returns:
        A list with one ``wsj0mix.Sample`` per file name. Its waveforms are
        the results of ``normalize_wav`` applied to the saved waveforms.
    """
    sample_rate = 8000
    speaker_dirs = [f"s{idx}" for idx in range(1, num_speaker + 1)]
    subdirs = ["mix", *speaker_dirs]
    for subdir in subdirs:
        os.makedirs(os.path.join(root_dir, subdir), exist_ok=True)

    samples = []
    seed = 0
    for filename in _FILENAMES:
        waveforms = []
        for subdir in subdirs:
            noise = get_whitenoise(
                sample_rate=sample_rate,
                duration=1,
                n_channels=1,
                dtype="int16",
                seed=seed,
            )
            seed += 1

            save_wav(os.path.join(root_dir, subdir, filename), noise, sample_rate)
            waveforms.append(normalize_wav(noise))

        # "mix" is always the first sub-directory; the rest are the sources.
        mix, *src = waveforms
        samples.append(wsj0mix.Sample(sample_rate, mix, src))
    return samples
63+
64+
65+
class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase):
    """Check ``wsj0mix.WSJ0Mix`` against a mocked two-speaker dataset."""

    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        # Generate the mock dataset once for the whole class.
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 2)

    def test_wsj0mix(self):
        # Every loaded sample must match what the mock wrote, in order.
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000)

        n_ite = 0
        for n_ite, sample in enumerate(dataset, start=1):
            expected = self.expected[n_ite - 1]
            self.assertEqual(sample.mix, expected.mix, atol=5e-5, rtol=1e-8)
            for spk in range(2):
                self.assertEqual(
                    sample.src[spk], expected.src[spk], atol=5e-5, rtol=1e-8
                )
        # The dataset must yield exactly as many samples as were generated.
        assert n_ite == len(self.expected)
86+
87+
88+
class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase):
    """Check ``wsj0mix.WSJ0Mix`` against a mocked three-speaker dataset."""

    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        # Generate the mock dataset once for the whole class.
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 3)

    def test_wsj0mix(self):
        # Every loaded sample must match what the mock wrote, in order.
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000)

        n_ite = 0
        for n_ite, sample in enumerate(dataset, start=1):
            expected = self.expected[n_ite - 1]
            self.assertEqual(sample.mix, expected.mix, atol=5e-5, rtol=1e-8)
            for spk in range(3):
                self.assertEqual(
                    sample.src[spk], expected.src[spk], atol=5e-5, rtol=1e-8
                )
        # The dataset must yield exactly as many samples as were generated.
        assert n_ite == len(self.expected)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import wsj0mix
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Union
2+
from pathlib import Path
3+
from collections import namedtuple
4+
5+
import torch
6+
from torch.utils.data import Dataset
7+
8+
import torchaudio
9+
10+
# One dataset item: the sample rate, the mixture waveform and the list of
# per-speaker source waveforms.
Sample = namedtuple("Sample", ["sample_rate", "mix", "src"])


class WSJ0Mix(Dataset):
    """Dataset of mixture recordings paired with their per-speaker sources.

    Mixtures are read from ``<root>/mix`` and sources from ``<root>/s1`` ...
    ``<root>/s<num_speakers>``. A source file is looked up under the same
    file name as its mixture.

    Args:
        root: Directory containing the ``mix`` and ``s<N>`` sub-directories.
        num_speakers: Number of source directories (``s1`` ... ``s<N>``) to read.
        sample_rate: Sample rate that every loaded audio file must have.
        audio_ext: File extension (without the dot) used to find mixtures.
    """

    def __init__(
        self,
        root: Union[str, Path],
        num_speakers: int,
        sample_rate: int,
        audio_ext: str = "wav",
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.mix_dir = (self.root / "mix").resolve()
        self.src_dirs = [
            (self.root / f"s{i + 1}").resolve() for i in range(num_speakers)
        ]

        # Only the mixture directory is scanned; sorting makes the item
        # order independent of the file system's listing order.
        self.files = sorted(p.name for p in self.mix_dir.glob(f"*.{audio_ext}"))

    def _load_audio(self, path) -> torch.Tensor:
        """Load one audio file and verify its sample rate.

        Raises:
            ValueError: If the file's sample rate differs from ``self.sample_rate``.
        """
        waveform, sample_rate = torchaudio.load(path)
        if sample_rate != self.sample_rate:
            # Both halves must be f-strings; otherwise the requested rate is
            # printed as the literal text "{self.sample_rate}".
            raise ValueError(
                f"The dataset contains audio file of sample rate {sample_rate}. "
                f"Where the requested sample rate is {self.sample_rate}."
            )
        return waveform

    def _load_sample(self, filename) -> Sample:
        """Load the mixture ``filename`` together with all of its sources.

        Raises:
            ValueError: If a source waveform's shape differs from the mixture's.
        """
        mixed = self._load_audio(str(self.mix_dir / filename))
        srcs = []
        for i, dir_ in enumerate(self.src_dirs):
            src = self._load_audio(str(dir_ / filename))
            if mixed.shape != src.shape:
                raise ValueError(
                    f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
                )
            srcs.append(src)
        return Sample(self.sample_rate, mixed, srcs)

    def __len__(self) -> int:
        """Return the number of mixture files found."""
        return len(self.files)

    def __getitem__(self, key: int) -> Sample:
        """Load and return the sample at position ``key`` of the sorted file list."""
        return self._load_sample(self.files[key])
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import List
2+
from functools import partial
3+
from collections import namedtuple
4+
5+
import torch
6+
7+
from utils.dataset import wsj0mix
8+
9+
Batch = namedtuple("Batch", ["mix", "src"])
10+
11+
12+
def get_dataset(dataset_type, root_dir, num_speakers, sample_rate):
    """Build the train, validation and evaluation datasets.

    Args:
        dataset_type: Name of the dataset; only ``"wsj0mix"`` is handled.
        root_dir: Object supporting the ``/`` operator (e.g. ``pathlib.Path``);
            the splits are read from its ``tr``, ``cv`` and ``tt`` children.
        num_speakers: Forwarded to ``wsj0mix.WSJ0Mix``.
        sample_rate: Forwarded to ``wsj0mix.WSJ0Mix``.

    Returns:
        Tuple ``(train, validation, evaluation)``.

    Raises:
        ValueError: If ``dataset_type`` is not ``"wsj0mix"``.
    """
    if dataset_type != "wsj0mix":
        raise ValueError(f"Unexpected dataset: {dataset_type}")

    # Split directories in the order: train, validation, evaluation.
    train, validation, evaluation = (
        wsj0mix.WSJ0Mix(root_dir / split, num_speakers, sample_rate)
        for split in ("tr", "cv", "tt")
    )
    return train, validation, evaluation
20+
21+
22+
def _fix_num_frames(waveform: torch.Tensor, target_num_frames: int):
    """Return ``waveform`` with exactly ``target_num_frames`` frames.

    ``waveform`` must be 2-D, ``(num_channels, num_frames)``. A longer input
    is truncated along the last dim, a shorter one is extended with zeros of
    the same dtype and device, and an input of the right length is returned
    unchanged.
    """
    num_channels, num_frames = waveform.shape
    missing = target_num_frames - num_frames

    if missing == 0:
        return waveform
    if missing < 0:
        return waveform[..., :target_num_frames]

    # Too short: append a zero block along the frame dimension.
    padding = waveform.new_zeros((num_channels, missing))
    return torch.cat([waveform, padding], 1)
36+
37+
38+
def collate_fn_wsj0mix(samples: List[wsj0mix.Sample], sample_rate, duration):
    """Collate samples into a ``Batch`` of fixed-length tensors.

    Every waveform is passed through ``_fix_num_frames`` so that it has
    ``int(duration * sample_rate)`` frames. The sources of one sample are
    concatenated along dim 0 first; mixtures and sources are then each
    stacked along a new leading dim.
    """
    target_num_frames = int(duration * sample_rate)

    mix_list = []
    for sample in samples:
        mix_list.append(_fix_num_frames(sample.mix, target_num_frames))
    mix_batch = torch.stack(mix_list, 0)

    src_list = []
    for sample in samples:
        sources = torch.cat(sample.src, 0)
        src_list.append(_fix_num_frames(sources, target_num_frames))
    src_batch = torch.stack(src_list, 0)

    return Batch(mix_batch, src_batch)
47+
48+
49+
def get_collate_fn(dataset_type, sample_rate, duration=4):
    """Return the collate function for ``dataset_type``.

    For ``"wsj0mix"`` this is ``collate_fn_wsj0mix`` with ``sample_rate`` and
    ``duration`` already bound.

    Raises:
        ValueError: If ``dataset_type`` is not ``"wsj0mix"``.
    """
    if dataset_type != "wsj0mix":
        raise ValueError(f"Unexpected dataset: {dataset_type}")
    return partial(collate_fn_wsj0mix, sample_rate=sample_rate, duration=duration)

0 commit comments

Comments
 (0)