66
77import torchaudio
88
9+ SampleType = Tuple [int , torch .Tensor , List [torch .Tensor ]]
10+
911
1012class WSJ0Mix (Dataset ):
13+ """Create a Dataset for wsj0-mix.
14+
15+ Args:
16+ root (str or Path): Path to the directory where the dataset is found.
17+ num_speakers (int): The number of speakers, which determines the directories
18+ to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
19+ N source audios.
20+ sample_rate (int): Expected sample rate of audio files. If any of the audio has a
21+ different sample rate, raises ``ValueError``.
22+ audio_ext (str): The extension of audio files to find. (default: ".wav")
23+ """
1124 def __init__ (
12- self , root : Union [str , Path ], num_speakers , sample_rate , audio_ext = "wav"
25+ self ,
26+ root : Union [str , Path ],
27+ num_speakers : int ,
28+ sample_rate : int ,
29+ audio_ext : str = ".wav" ,
1330 ):
1431 self .root = Path (root )
1532 self .sample_rate = sample_rate
1633 self .mix_dir = (self .root / "mix" ).resolve ()
1734 self .src_dirs = [(self .root / f"s{ i + 1 } " ).resolve () for i in range (num_speakers )]
1835
19- self .files = [p .name for p in self .mix_dir .glob (f"*. { audio_ext } " )]
36+ self .files = [p .name for p in self .mix_dir .glob (f"*{ audio_ext } " )]
2037 self .files .sort ()
2138
2239 def _load_audio (self , path ) -> torch .Tensor :
@@ -28,7 +45,7 @@ def _load_audio(self, path) -> torch.Tensor:
2845 )
2946 return waveform
3047
31- def _load_sample (self , filename ) -> Tuple [ int , torch . Tensor , List [ torch . Tensor ]] :
48+ def _load_sample (self , filename ) -> SampleType :
3249 mixed = self ._load_audio (str (self .mix_dir / filename ))
3350 srcs = []
3451 for i , dir_ in enumerate (self .src_dirs ):
@@ -43,5 +60,11 @@ def _load_sample(self, filename) -> Tuple[int, torch.Tensor, List[torch.Tensor]]
4360 def __len__ (self ) -> int :
4461 return len (self .files )
4562
46- def __getitem__ (self , key : int ) -> Tuple [int , torch .Tensor , List [torch .Tensor ]]:
63+ def __getitem__ (self , key : int ) -> SampleType :
64+ """Load the n-th sample from the dataset.
65+ Args:
66+ n (int): The index of the sample to be loaded
67+ Returns:
68+ tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
69+ """
4770 return self ._load_sample (self .files [key ])
0 commit comments