Skip to content

Commit f75b458

Browse files
committed
Add Conv-TasNet model
1 parent 52a18a9 commit f75b458

File tree

1 file changed

+318
-0
lines changed
  • examples/source_separation/conv_tasnet

1 file changed

+318
-0
lines changed
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
"""Implements Conv-TasNet with building blocks of it."""
2+
3+
from typing import Tuple, Optional
4+
5+
import torch
6+
7+
8+
class ConvBlock(torch.nn.Module):
9+
"""1D Convolutional block.
10+
11+
Args:
12+
channels (int): The number of input/output channels, <B, Sc>
13+
hidden_channels (int): The number of channels in the internal layers, <H>.
14+
kernel_size (int): The convolution kernel size of the middle layer, <P>.
15+
padding (int): Padding value of the convolution in the middle layer.
16+
dilation (int): Dilation value of the convolution in the middle layer.
17+
causal (bool): Switch causal/non-causal implementation.
18+
no_redisual (bool): Disable residual block/output.
19+
20+
References:
21+
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
22+
Luo, Yi and Mesgarani, Nima
23+
https://arxiv.org/abs/1809.07454
24+
"""
25+
def __init__(
26+
self,
27+
io_channels: int,
28+
hidden_channels: int,
29+
kernel_size: int,
30+
padding: int,
31+
dilation: int = 1,
32+
causal: bool = False,
33+
no_residual: bool = False,
34+
):
35+
super().__init__()
36+
37+
if causal:
38+
raise NotImplementedError("causal=True is not implemented")
39+
40+
self.conv_layers = torch.nn.Sequential(
41+
torch.nn.Conv1d(
42+
in_channels=io_channels, out_channels=hidden_channels, kernel_size=1
43+
),
44+
torch.nn.PReLU(),
45+
torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
46+
torch.nn.Conv1d(
47+
in_channels=hidden_channels,
48+
out_channels=hidden_channels,
49+
kernel_size=kernel_size,
50+
padding=padding,
51+
dilation=dilation,
52+
groups=hidden_channels,
53+
),
54+
torch.nn.PReLU(),
55+
torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
56+
)
57+
58+
self.res_out = (
59+
None
60+
if no_residual
61+
else torch.nn.Conv1d(
62+
in_channels=hidden_channels, out_channels=io_channels, kernel_size=1
63+
)
64+
)
65+
self.skip_out = torch.nn.Conv1d(
66+
in_channels=hidden_channels, out_channels=io_channels, kernel_size=1
67+
)
68+
69+
def forward(
70+
self, input: torch.Tensor
71+
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
72+
feature = self.conv_layers(input)
73+
if self.res_out is None:
74+
residual = None
75+
else:
76+
residual = self.res_out(feature)
77+
skip_out = self.skip_out(feature)
78+
return residual, skip_out
79+
80+
81+
class MaskGenerator(torch.nn.Module):
82+
"""TCN (Temporal Convolution Network) Separation Module
83+
84+
Generates masks for separation.
85+
86+
Args:
87+
input_dim (int): Input feature dimension, <N>.
88+
num_sources (int): The number of sources to separate.
89+
kernel_size (int): The convolution kernel size of conv blocks, <P>.
90+
num_featrs (int): Input/output feature dimenstion of conv blocks, <B, Sc>.
91+
num_hidden (int): Intermediate feature dimention of conv blocks, <H>
92+
num_layers (int): The number of conv blocks in one stack, <X>.
93+
num_stacks (int): The number of conv block stacks, <R>.
94+
causal (bool): Switch causal/non-causal implementation.
95+
96+
References:
97+
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
98+
Luo, Yi and Mesgarani, Nima
99+
https://arxiv.org/abs/1809.07454
100+
"""
101+
102+
def __init__(
103+
self,
104+
input_dim: int,
105+
num_sources: int,
106+
kernel_size: int,
107+
num_feats: int,
108+
num_hidden: int,
109+
num_layers: int,
110+
num_stacks: int,
111+
causal: bool = False,
112+
):
113+
if causal:
114+
raise NotImplementedError("causal=True is not implemented")
115+
116+
super().__init__()
117+
118+
self.input_dim = input_dim
119+
self.num_sources = num_sources
120+
121+
self.norm_layers = torch.nn.Sequential(
122+
torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8),
123+
torch.nn.Conv1d(
124+
in_channels=input_dim, out_channels=num_feats, kernel_size=1
125+
),
126+
)
127+
self.receptive_field = 0
128+
self.conv_layers = torch.nn.ModuleList([])
129+
for s in range(num_stacks):
130+
for l in range(num_layers):
131+
multi = 2 ** l
132+
self.conv_layers.append(
133+
ConvBlock(
134+
io_channels=num_feats,
135+
hidden_channels=num_hidden,
136+
kernel_size=kernel_size,
137+
dilation=multi,
138+
padding=multi,
139+
causal=causal,
140+
# The last ConvBlock does not need residual
141+
no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
142+
)
143+
)
144+
self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
145+
self.output_layer = torch.nn.Sequential(
146+
torch.nn.PReLU(),
147+
torch.nn.Conv1d(
148+
in_channels=num_feats, out_channels=input_dim * num_sources, kernel_size=1),
149+
torch.nn.Sigmoid(),
150+
)
151+
152+
def forward(self, input: torch.Tensor) -> torch.Tensor:
153+
"""Generate separation mask.
154+
155+
Args:
156+
input (torch.Tensor): 3D Tensor with shape [batch, features, frames]
157+
158+
Returns:
159+
torch.Tensor: shape [batch, num_sources, features, frames]
160+
"""
161+
batch_size = input.shape[0]
162+
feats = self.norm_layers(input)
163+
output = 0.0
164+
for layer in self.conv_layers:
165+
residual, skip = layer(feats)
166+
if residual is not None: # the last conv layer does not produce residual
167+
feats = feats + residual
168+
output = output + skip
169+
output = self.output_layer(output)
170+
return output.view(batch_size, self.num_sources, self.input_dim, -1)
171+
172+
173+
class ConvTasNet(torch.nn.Module):
174+
"""Conv-TasNet: a fully-convolutional time-domain audio separation network
175+
176+
Args:
177+
num_sources (int): The number of sources to split.
178+
enc_kernel_size (int): The convolution kernel size of the encoder/decoder, <L>.
179+
enc_num_feats (int): The feature dimensions passed to mask generator, <N>.
180+
msk_kernel_size (int): The convolution kernel size of the mask generator, <P>.
181+
msk_num_feats (int): The input/output feature dimension of conv block in the mask generator, <B, Sc>.
182+
msk_num_hidden_feats (int): The internal feature dimension of conv block of the mask generator, <H>.
183+
msk_num_layers (int): The number of layers in one conv block of the mask generator, <X>.
184+
msk_num_stacks (int): The numbr of conv blocks of the mask generator, <R>.
185+
causal (bool): Switch causal/non-causal implementation.
186+
187+
References:
188+
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
189+
Luo, Yi and Mesgarani, Nima
190+
https://arxiv.org/abs/1809.07454
191+
"""
192+
def __init__(
193+
self,
194+
num_sources: int = 2,
195+
# encoder/decoder parameters
196+
enc_kernel_size: int = 16,
197+
enc_num_feats: int = 512,
198+
# mask generator parameters
199+
msk_kernel_size: int = 3,
200+
msk_num_feats: int = 128,
201+
msk_num_hidden_feats: int = 512,
202+
msk_num_layers: int = 8,
203+
msk_num_stacks: int = 3,
204+
causal: bool = False,
205+
):
206+
super().__init__()
207+
208+
if causal:
209+
raise NotImplementedError("causal=True is not implemented")
210+
211+
self.num_sources = num_sources
212+
self.enc_num_feats = enc_num_feats
213+
self.enc_kernel_size = enc_kernel_size
214+
self.enc_stride = enc_kernel_size // 2
215+
216+
self.encoder = torch.nn.Conv1d(
217+
in_channels=1,
218+
out_channels=enc_num_feats,
219+
kernel_size=enc_kernel_size,
220+
stride=self.enc_stride,
221+
padding=self.enc_stride,
222+
bias=False,
223+
)
224+
self.mask_generator = MaskGenerator(
225+
input_dim=enc_num_feats,
226+
num_sources=num_sources,
227+
kernel_size=msk_kernel_size,
228+
num_feats=msk_num_feats,
229+
num_hidden=msk_num_hidden_feats,
230+
num_layers=msk_num_layers,
231+
num_stacks=msk_num_stacks,
232+
)
233+
self.decoder = torch.nn.ConvTranspose1d(
234+
in_channels=enc_num_feats,
235+
out_channels=1,
236+
kernel_size=enc_kernel_size,
237+
stride=self.enc_stride,
238+
padding=self.enc_stride,
239+
bias=False,
240+
)
241+
242+
def _pad_input(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
243+
"""Pad input Tensor so that the end of the input tensor corresponds with
244+
245+
1. (if kernel size is odd) the center of the last convolution kernel
246+
or 2. (if kernel size is even) the end of the first half of the last convolution kernel
247+
248+
Assuming that the resulting Tensor will be zero-padded with the size of stride
249+
on the both ends in Conv1D
250+
251+
|<--- k_1 --->|
252+
| | |<-- k_n-1 -->|
253+
| | | |<--- k_n --->|
254+
| | | | |
255+
| | | | |
256+
| v v v |
257+
|<---->|<--- input signal --->|<--->|<---->|
258+
stride PAD stride
259+
260+
Args:
261+
input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)
262+
263+
Returns:
264+
torch.Tensor: Padded Tensor
265+
int: Number of paddings performed
266+
"""
267+
batch_size, num_channels, num_frames = input.shape
268+
is_odd = self.enc_kernel_size % 2
269+
num_strides = (num_frames - is_odd) // self.enc_stride
270+
num_remainings = num_frames - (is_odd + num_strides * self.enc_stride)
271+
if num_remainings == 0:
272+
return input, 0
273+
274+
num_paddings = self.enc_stride - num_remainings
275+
pad = torch.zeros(
276+
batch_size,
277+
num_channels,
278+
num_paddings,
279+
dtype=input.dtype,
280+
device=input.device,
281+
)
282+
return torch.cat([input, pad], 2), num_paddings
283+
284+
def forward(self, input: torch.Tensor) -> torch.Tensor:
285+
"""Perform source separation. Generate audio source waveforms.
286+
287+
Args:
288+
input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]
289+
290+
Returns:
291+
torch.Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
292+
"""
293+
if input.ndim != 3 or input.shape[1] != 1:
294+
raise ValueError(
295+
f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}"
296+
)
297+
298+
# B: batch size
299+
# L: input frame length
300+
# L': padded input frame length
301+
# F: feature dimension
302+
# M: feature frame length
303+
# S: number of sources
304+
305+
padded, num_pads = self._pad_input(input) # B, 1, L'
306+
batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
307+
feats = self.encoder(padded) # B, F, M
308+
masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M
309+
masked = masked.view(
310+
batch_size * self.num_sources, self.enc_num_feats, -1
311+
) # B*S, F, M
312+
decoded = self.decoder(masked) # B*S, 1, L'
313+
output = decoded.view(
314+
batch_size, self.num_sources, num_padded_frames
315+
) # B, S, L'
316+
if num_pads > 0:
317+
output = output[..., :-num_pads] # B, S, L
318+
return output

0 commit comments

Comments
 (0)