
Commit f0140df

Add Conv-TasNet model
1 parent 52a18a9 commit f0140df

1 file changed: examples/source_separation/conv_tasnet (301 additions, 0 deletions)
@@ -0,0 +1,301 @@
"""Implements Conv-TasNet with its building blocks."""

from typing import Tuple, Optional

import torch


class ConvBlock(torch.nn.Module):
    """1D convolutional block.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int): The number of channels in the internal layers.
        kernel_size (int): The convolution kernel size of the middle layer.
        padding (int): Padding value of the convolution in the middle layer.
        dilation (int): Dilation value of the convolution in the middle layer.
        causal (bool): Switch between causal and non-causal implementations.
        no_residual (bool): Disable the residual block/output.

    References:
        - Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: int,
        padding: int,
        dilation: int = 1,
        causal: bool = False,
        no_residual: bool = False,
    ):
        super().__init__()

        if causal:
            raise NotImplementedError("causal=True is not implemented")

        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=in_channels, out_channels=hidden_channels, kernel_size=1
            ),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
            torch.nn.Conv1d(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                padding=padding,
                dilation=dilation,
                groups=hidden_channels,
            ),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
        )

        self.res_out = (
            None
            if no_residual
            else torch.nn.Conv1d(
                in_channels=hidden_channels, out_channels=in_channels, kernel_size=1
            )
        )
        self.skip_out = torch.nn.Conv1d(
            in_channels=hidden_channels, out_channels=in_channels, kernel_size=1
        )

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        feature = self.conv_layers(input)
        if self.res_out is None:
            residual = None
        else:
            residual = self.res_out(feature)
        skip_out = self.skip_out(feature)
        return residual, skip_out
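
# Illustrative shape check for ConvBlock (assumed example sizes). With
# kernel_size=3 and padding == dilation, the depthwise middle layer preserves
# the frame length, so both outputs match the input shape:
#
#     >>> block = ConvBlock(in_channels=128, hidden_channels=512,
#     ...                   kernel_size=3, padding=1, dilation=1)
#     >>> x = torch.rand(4, 128, 1000)
#     >>> residual, skip = block(x)
#     >>> residual.shape, skip.shape
#     (torch.Size([4, 128, 1000]), torch.Size([4, 128, 1000]))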


class MaskGenerator(torch.nn.Module):
    """TCN (Temporal Convolution Network) separation module.

    Generates masks for separation.

    Args:
        input_dim (int): Input feature dimension.
        num_sources (int): The number of sources to separate.
        kernel_size (int): The convolution kernel size of conv blocks.
        num_feats (int): Unit feature dimension of conv blocks.
        num_layers (int): The number of conv blocks in one stack.
        num_stacks (int): The number of conv block stacks.
        causal (bool): Switch between causal and non-causal implementations.

    References:
        - Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """

    def __init__(
        self,
        input_dim: int,
        num_sources: int,
        kernel_size: int,
        num_feats: int,
        num_layers: int,
        num_stacks: int,
        causal: bool = False,
    ):
        if causal:
            raise NotImplementedError("causal=True is not implemented")

        super().__init__()

        self.input_dim = input_dim
        self.num_sources = num_sources

        self.norm_layers = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8),
            torch.nn.Conv1d(
                in_channels=input_dim, out_channels=num_feats, kernel_size=1
            ),
        )
        self.conv_layers = torch.nn.ModuleList([])
        for s in range(num_stacks):
            for l in range(num_layers):
                self.conv_layers.append(
                    ConvBlock(
                        in_channels=num_feats,
                        hidden_channels=4 * num_feats,
                        kernel_size=kernel_size,
                        dilation=2 ** l,
                        padding=2 ** l,
                        causal=causal,
                        # The last ConvBlock does not need a residual output
                        no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
                    )
                )
        self.output_layer = torch.nn.Sequential(
            torch.nn.PReLU(),
            torch.nn.Conv1d(num_feats, input_dim * num_sources, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        batch_size = input.shape[0]
        feats = self.norm_layers(input)
        output = 0.0
        for layer in self.conv_layers:
            residual, skip = layer(feats)
            if residual is not None:  # the last conv layer does not produce residual
                feats = feats + residual
            output = output + skip
        output = self.output_layer(output)
        return output.view(batch_size, self.num_sources, self.input_dim, -1)
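
# Illustrative shape check for MaskGenerator (assumed example sizes). With
# kernel_size=3, dilation 2**l and padding 2**l keep the frame length while the
# receptive field doubles per layer; the sigmoid output is reshaped into one
# mask per source:
#
#     >>> mask_gen = MaskGenerator(input_dim=512, num_sources=2, kernel_size=3,
#     ...                          num_feats=128, num_layers=8, num_stacks=3)
#     >>> feats = torch.rand(4, 512, 100)
#     >>> mask_gen(feats).shape
#     torch.Size([4, 2, 512, 100])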


class ConvTasNet(torch.nn.Module):
    """Conv-TasNet: a fully-convolutional time-domain audio separation network.

    Args:
        num_sources (int): The number of sources to split.
        enc_kernel_size (int): The convolution kernel size of the encoder/decoder.
        enc_num_feats (int): The feature dimension passed to the mask generator.
        msk_kernel_size (int): The convolution kernel size of the mask generator.
        msk_num_feats (int): The internal feature dimension of the mask generator.
        msk_num_layers (int): The number of layers in one conv block of the mask generator.
        msk_num_stacks (int): The number of conv blocks of the mask generator.
        causal (bool): Switch between causal and non-causal implementations.

    References:
        - Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    def __init__(
        self,
        num_sources: int = 2,
        # encoder/decoder parameters
        enc_kernel_size: int = 32,
        enc_num_feats: int = 512,
        # mask generator parameters
        msk_kernel_size: int = 3,
        msk_num_feats: int = 128,
        msk_num_layers: int = 8,
        msk_num_stacks: int = 3,
        causal: bool = False,
    ):
        super().__init__()

        if causal:
            raise NotImplementedError("causal=True is not implemented")

        self.num_sources = num_sources
        self.enc_num_feats = enc_num_feats
        self.enc_kernel_size = enc_kernel_size
        self.enc_stride = enc_kernel_size // 2

        self.encoder = torch.nn.Conv1d(
            in_channels=1,
            out_channels=enc_num_feats,
            kernel_size=enc_kernel_size,
            stride=self.enc_stride,
            padding=self.enc_stride,
            bias=False,
        )
        self.mask_generator = MaskGenerator(
            input_dim=enc_num_feats,
            num_sources=num_sources,
            kernel_size=msk_kernel_size,
            num_feats=msk_num_feats,
            num_layers=msk_num_layers,
            num_stacks=msk_num_stacks,
        )
        self.decoder = torch.nn.ConvTranspose1d(
            in_channels=enc_num_feats,
            out_channels=1,
            kernel_size=enc_kernel_size,
            stride=self.enc_stride,
            padding=self.enc_stride,
            bias=False,
        )
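
    # Length arithmetic for the defaults above (enc_kernel_size=32, hence
    # enc_stride=16): the encoder maps a padded input of length L' (a multiple
    # of 16) to M frames,
    #     M = (L' + 2*16 - 32) // 16 + 1 = L' // 16 + 1
    # and the transposed-convolution decoder inverts this exactly,
    #     (M - 1) * 16 - 2*16 + 32 = L'
    # so the encoder/decoder pair round-trips the padded length.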

    def _pad_input(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Pad the input Tensor so that the end of the input tensor corresponds with

        1. (if kernel size is odd) the center of the last convolution kernel,
        or 2. (if kernel size is even) the end of the first half of the last convolution kernel,

        assuming that the resulting Tensor will be zero-padded with the size of stride
        on both ends in Conv1d:

            |<--- k_1 --->|
            |      |              |<-- k_n-1 -->|
            |      |              |      |<--- k_n --->|
            |      |              |      |      |
            |      |              |      |      |
            |      v              v      v      |
            |<---->|<--- input signal --->|<--->|<---->|
             stride                         PAD  stride

        Args:
            input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)

        Returns:
            torch.Tensor: Padded Tensor
            int: The number of padded frames
        """
        batch_size, num_channels, num_frames = input.shape
        is_odd = self.enc_kernel_size % 2
        num_strides = (num_frames - is_odd) // self.enc_stride
        num_remainings = num_frames - (is_odd + num_strides * self.enc_stride)
        if num_remainings == 0:
            return input, 0

        num_paddings = self.enc_stride - num_remainings
        pad = torch.zeros(
            batch_size,
            num_channels,
            num_paddings,
            dtype=input.dtype,
            device=input.device,
        )
        return torch.cat([input, pad], 2), num_paddings
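
    # Worked example of the padding arithmetic (assumed input length), with the
    # default enc_kernel_size=32 (even, so is_odd == 0) and enc_stride=16:
    #     num_frames     = 1000
    #     num_strides    = 1000 // 16     = 62
    #     num_remainings = 1000 - 62 * 16 = 8
    #     num_paddings   = 16 - 8         = 8
    # The input grows to 1008 frames (a multiple of the stride), and 8 is
    # returned so that forward() can trim the tail after decoding.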

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Perform source separation. Generate audio source waveforms.

        Args:
            input (torch.Tensor): 3D Tensor with shape (batch, channel==1, frames)

        Returns:
            torch.Tensor: 3D Tensor with shape (batch, channel==num_sources, frames)
        """
        if input.ndim != 3 or input.shape[1] != 1:
            raise ValueError(
                f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}"
            )

        # B: batch size
        # L: input frame length
        # L': padded input frame length
        # F: feature dimension
        # M: feature frame length
        # S: number of sources

        padded, num_pads = self._pad_input(input)  # B, 1, L'
        batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
        feats = self.encoder(padded)  # B, F, M
        masked = self.mask_generator(feats) * feats.unsqueeze(1)  # B, S, F, M
        masked = masked.view(
            batch_size * self.num_sources, self.enc_num_feats, -1
        )  # B*S, F, M
        decoded = self.decoder(masked)  # B*S, 1, L'
        output = decoded.view(
            batch_size, self.num_sources, num_padded_frames
        )  # B, S, L'
        if num_pads > 0:
            output = output[..., :-num_pads]  # B, S, L
        return output
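
A minimal end-to-end usage sketch (batch size and signal length are assumed example values; the model and its defaults come from the file above):

    import torch

    model = ConvTasNet(num_sources=2)
    mixture = torch.rand(3, 1, 16000)  # (batch, channel==1, frames), e.g. 1 s at 16 kHz
    separated = model(mixture)         # (batch, num_sources, frames)
    assert separated.shape == (3, 2, 16000)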
