Skip to content

Commit 51f83e3

Browse files
committed
Update docstring
1 parent 16de3ea commit 51f83e3

File tree

4 files changed

+34
-11
lines changed

4 files changed

+34
-11
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from . import (
2+
dataset,
23
metrics,
34
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from . import wsj0mix
1+
from . import utils, wsj0mix

examples/source_separation/utils/dataset_utils.py renamed to examples/source_separation/utils/dataset/utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66

7-
from utils.dataset import wsj0mix
7+
from . import wsj0mix
88

99
Batch = namedtuple("Batch", ["mix", "src"])
1010

@@ -35,13 +35,12 @@ def _fix_num_frames(waveform: torch.Tensor, target_num_frames: int):
3535
return torch.cat([waveform, pad], 1)
3636

3737

38-
def collate_fn_wsj0mix_train(samples: List[wsj0mix.Sample], sample_rate, duration):
38+
def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
3939
target_num_frames = int(duration * sample_rate)
4040

4141
mixes, srcs = [], []
42-
for sample in samples:
43-
mix = sample.mix
44-
src = torch.cat(sample.src, 0)
42+
for (_, mix, src) in samples:
43+
src = torch.cat(src, 0)
4544

4645
num_frames = mix.shape[-1]
4746
if num_frames > target_num_frames:
@@ -58,7 +57,7 @@ def collate_fn_wsj0mix_train(samples: List[wsj0mix.Sample], sample_rate, duratio
5857
return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0))
5958

6059

61-
def collate_fn_wsj0mix_test(samples: List[wsj0mix.Sample]):
60+
def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType]) -> Batch:
6261
return [Batch(
6362
sample.mix.unsqueeze(0),
6463
torch.cat(sample.src, 0).unsqueeze(0),

examples/source_separation/utils/dataset/wsj0mix.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,34 @@
66

77
import torchaudio
88

9+
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
10+
911

1012
class WSJ0Mix(Dataset):
13+
"""Create a Dataset for wsj0-mix.
14+
15+
Args:
16+
root (str or Path): Path to the directory where the dataset is found.
17+
num_speakers (int): The number of speakers, which determines the directories
18+
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
19+
N source audios.
20+
sample_rate (int): Expected sample rate of audio files. If any audio file has a
21+
different sample rate, raises ``ValueError``.
22+
audio_ext (str): The extension of audio files to find. (default: ".wav")
23+
"""
1124
def __init__(
12-
self, root: Union[str, Path], num_speakers, sample_rate, audio_ext="wav"
25+
self,
26+
root: Union[str, Path],
27+
num_speakers: int,
28+
sample_rate: int,
29+
audio_ext: str = ".wav",
1330
):
1431
self.root = Path(root)
1532
self.sample_rate = sample_rate
1633
self.mix_dir = (self.root / "mix").resolve()
1734
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
1835

19-
self.files = [p.name for p in self.mix_dir.glob(f"*.{audio_ext}")]
36+
self.files = [p.name for p in self.mix_dir.glob(f"*{audio_ext}")]
2037
self.files.sort()
2138

2239
def _load_audio(self, path) -> torch.Tensor:
@@ -28,7 +45,7 @@ def _load_audio(self, path) -> torch.Tensor:
2845
)
2946
return waveform
3047

31-
def _load_sample(self, filename) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
48+
def _load_sample(self, filename) -> SampleType:
3249
mixed = self._load_audio(str(self.mix_dir / filename))
3350
srcs = []
3451
for i, dir_ in enumerate(self.src_dirs):
@@ -43,5 +60,11 @@ def _load_sample(self, filename) -> Tuple[int, torch.Tensor, List[torch.Tensor]]
4360
def __len__(self) -> int:
4461
return len(self.files)
4562

46-
def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
63+
def __getitem__(self, key: int) -> SampleType:
64+
"""Load the n-th sample from the dataset.
65+
Args:
66+
key (int): The index of the sample to be loaded
67+
Returns:
68+
tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
69+
"""
4770
return self._load_sample(self.files[key])

0 commit comments

Comments
 (0)