Skip to content

Commit 92ea776

Browse files
authored
Merge pull request #23 from dhpollack/torchspectrograms
pytorch implementation of MEL spectrograms (no librosa req'd)
2 parents 9e7e5fd + 41390a8 commit 92ea776

File tree

2 files changed

+219
-7
lines changed

2 files changed

+219
-7
lines changed

test/test_transforms.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ class Tester(unittest.TestCase):
1010

1111
sr = 16000
1212
freq = 440
13-
volume = 0.3
13+
volume = .3
1414
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
15+
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
1516
sig.unsqueeze_(1)
1617
sig = (sig * volume * 2**31).long()
1718

@@ -74,11 +75,11 @@ def test_mel(self):
7475

7576
audio = self.sig.clone()
7677
audio = transforms.Scale()(audio)
77-
self.assertTrue(len(audio.size()) == 2)
78+
self.assertTrue(audio.dim() == 2)
7879
result = transforms.MEL()(audio)
79-
self.assertTrue(len(result.size()) == 3)
80+
self.assertTrue(result.dim() == 3)
8081
result = transforms.BLC2CBL()(result)
81-
self.assertTrue(len(result.size()) == 3)
82+
self.assertTrue(result.dim() == 3)
8283

8384
def test_compose(self):
8485

@@ -121,6 +122,13 @@ def test_mu_law_companding(self):
121122
sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
122123
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
123124

125+
def test_mel2(self):
    """MEL2 should return a 3-D dB-scaled mel spectrogram whose peak is 0 dB."""
    # the fixture signal holds n = 4 * sr samples in a single channel
    audio_orig = self.sig.clone()  # (n, 1)
    audio_scaled = transforms.Scale()(audio_orig)  # (n, 1)
    audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, n)
    spectrogram_torch = transforms.MEL2()(audio_scaled)  # (1, hops, n_mels)
    # assertEqual/assertLessEqual report the offending values on failure,
    # unlike assertTrue(a == b)
    self.assertEqual(spectrogram_torch.dim(), 3)
    # dB values are taken relative to the spectrogram maximum, so none may
    # exceed 0
    self.assertLessEqual(spectrogram_torch.max(), 0.)
124132

125133
if __name__ == '__main__':
126134
unittest.main()

torchaudio/transforms.py

Lines changed: 207 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,31 @@
11
from __future__ import division, print_function
22
import torch
3+
from torch.autograd import Variable
34
import numpy as np
45
try:
56
import librosa
67
except ImportError:
78
librosa = None
89

910

11+
def _check_is_variable(tensor):
    """Normalize the input to a Variable and remember what the caller passed.

    Args:
        tensor (Tensor or Variable): input to normalize

    Returns:
        tuple: ``(variable, is_variable)`` where ``variable`` is the input
        (wrapped in a non-differentiable Variable when a plain Tensor was
        given) and ``is_variable`` tells whether the caller passed a Variable.

    Raises:
        TypeError: if the input is neither a Tensor nor a Variable.

    """
    # NOTE(review): the Tensor check runs first; on torch versions where
    # Variable and Tensor are the same type this branch presumably catches
    # every input -- confirm against the supported torch version.
    if isinstance(tensor, torch.Tensor):
        return Variable(tensor, requires_grad=False), False
    if isinstance(tensor, Variable):
        return tensor, True
    raise TypeError("tensor should be a Variable or Tensor, but is {}".format(type(tensor)))
21+
22+
23+
def _tlog10(x):
    """Element-wise base-10 logarithm built from the natural logarithm.

    Args:
        x (Tensor or Variable): input values

    Returns:
        Tensor or Variable: ``log(x) / log(10)``, same shape as ``x``

    """
    # build the constant from ``x`` itself so it shares the input's tensor type
    ten = x.new([10])
    natural_log = torch.log(x)
    return natural_log / torch.log(ten)
27+
28+
1029
class Compose(object):
1130
"""Composes several transforms together.
1231
@@ -120,16 +139,200 @@ def __call__(self, tensor):
120139
"""
121140
122141
Args:
123-
tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
142+
tensor (Tensor): Tensor of audio signal with shape (LxC)
124143
125144
Returns:
126-
tensor (Tensor): Tensor of spectrogram with shape (CxBxL)
145+
tensor (Tensor): Tensor of audio signal with shape (CxL)
127146
128147
"""
129148

130149
return tensor.transpose(0, 1).contiguous()
131150

132151

152+
class SPECTROGRAM(object):
    """Create a power spectrogram from a raw audio signal

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): fft size handed to ``torch.stft``. The value is
            stored and forwarded as-is (``None`` included); this class does
            not compute a default itself.
        pad (int): two sided padding of signal
        window (torch windowing function or Variable): either a callable
            invoked as ``window(ws, **wkwargs)`` or an already built window.
            default: torch.hann_window
        wkwargs (dict, optional): arguments for window function

    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
                 pad=0, window=torch.hann_window, wkwargs=None):
        if isinstance(window, Variable):
            # a ready-made window is used unchanged
            self.window = window
        else:
            self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
            # requires_grad=False matches how MEL2 wraps its window; the
            # previous volatile=True flag is not supported by later torch
            # releases.
            self.window = Variable(self.window, requires_grad=False)
        self.sr = sr
        self.ws = ws
        self.hop = hop if hop is not None else ws // 2
        self.n_fft = n_fft  # fft size forwarded to torch.stft
        self.pad = pad
        self.wkwargs = wkwargs

    def __call__(self, sig):
        """
        Args:
            sig (Tensor or Variable): Tensor of audio of size (c, n)

        Returns:
            spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
                is unchanged, hops is the number of hops, and n_fft is the
                number of fourier bins, which should be the window size divided
                by 2 plus 1. The return type mirrors the input type.

        """
        sig, is_variable = _check_is_variable(sig)

        assert sig.dim() == 2, "expected a 2-D (channels, samples) signal"

        # NOTE(review): the positional argument order follows the torch.stft
        # signature this file was written against -- confirm it against the
        # installed torch version.
        spec_f = torch.stft(sig, self.ws, self.hop, self.n_fft,
                            True, self.window, self.pad)  # (c, l, n_fft, 2)
        # normalize by the L2 norm of the window
        spec_f /= self.window.pow(2).sum().sqrt()
        spec_f = spec_f.pow(2).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
        return spec_f if is_variable else spec_f.data
200+
201+
202+
class F2M(object):
    """This turns a normal STFT into a MEL Frequency STFT, using a conversion
       matrix.  This uses triangular filter banks.

    Args:
        n_mels (int): number of MEL bins
        sr (int): sample rate of audio signal
        f_max (float, optional): maximum frequency. default: sr // 2
        f_min (float): minimum frequency. default: 0
    """
    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
        self.n_mels = n_mels
        self.sr = sr
        self.f_max = f_max if f_max is not None else sr // 2
        self.f_min = f_min

    def _create_fb(self, n_fft):
        """Build the (n_fft, n_mels) matrix of triangular mel filters."""
        # hertz -> mel for both band edges
        m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
        m_max = 2595 * np.log10(1. + (self.f_max / 700))

        # n_mels filters need n_mels + 2 edge points, evenly spaced in mel
        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
        f_pts = (700 * (10**(m_pts / 2595) - 1))  # mel -> hertz

        # Plain Python ints keep slicing and the slope arithmetic below
        # independent of how the installed torch divides integer tensors.
        bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long().tolist()

        fb = torch.zeros(n_fft, self.n_mels)
        for col, (lo, mid, hi) in enumerate(zip(bins, bins[1:], bins[2:])):
            if lo != mid:
                # rising edge: 0 at ``lo`` up to (but excluding) 1 at ``mid``
                fb[lo:mid, col] = (torch.arange(lo, mid).float() - lo) / float(mid - lo)
            if mid != hi:
                # falling edge: 1 at ``mid`` down to (but excluding) 0 at ``hi``
                fb[mid:hi, col] = (hi - torch.arange(mid, hi).float()) / float(hi - mid)
        return fb

    def __call__(self, spec_f):
        """
        Args:
            spec_f (Tensor or Variable): spectrogram of size (c, l, n_fft)

        Returns:
            spec_m (Tensor or Variable): mel spectrogram of size (c, l, n_mels).
                The return type mirrors the input type.

        """
        spec_f, is_variable = _check_is_variable(spec_f)
        n_fft = spec_f.size(2)

        # cast the filter bank to the input's tensor type so matmul does not
        # fail on a type mismatch
        fb = Variable(self._create_fb(n_fft)).type_as(spec_f)
        spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m if is_variable else spec_m.data
245+
246+
247+
class SPEC2DB(object):
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    Args:
        stype (str): scale of input spectrogram ("power" or "magnitude").  The
            power being the elementwise square of the magnitude. default: "power"
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is -80.  A positive value is negated, so 80 and -80 are equivalent.
            default: None (no cut-off)
    """
    def __init__(self, stype="power", top_db=None):
        self.stype = stype
        # Only compare against 0 when a cut-off was given: ``None > 0`` raises
        # TypeError on Python 3, which made the default unusable.
        if top_db is not None and top_db > 0:
            top_db = -top_db
        self.top_db = top_db
        self.multiplier = 10. if stype == "power" else 20.

    def __call__(self, spec):
        """
        Args:
            spec (Tensor or Variable): spectrogram on the power or magnitude scale

        Returns:
            spec_db (Tensor or Variable): spectrogram in decibels relative to
                its own maximum, floored at ``top_db`` when that is set.  The
                return type mirrors the input type.

        """
        spec, is_variable = _check_is_variable(spec)
        spec_db = self.multiplier * _tlog10(spec / spec.max())  # power -> dB
        if self.top_db is not None:
            # element-wise floor at the cut-off
            spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
        return spec_db if is_variable else spec_db.data
268+
269+
270+
class MEL2(object):
    """Compute dB-scaled MEL spectrograms from raw audio by chaining the
       SPECTROGRAM, F2M and SPEC2DB transforms, so librosa is not needed.

    Sources:
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): fft size, stored and forwarded to SPECTROGRAM
            unchanged (``None`` included)
        pad (int): two sided padding of signal
        n_mels (int): number of MEL bins
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function

    Example:
        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
        >>> sig = transforms.LC2CL()(sig)  # (n, c) -> (c, n)
        >>> spec_mel = transforms.MEL2(sr)(sig)  # (c, l, m)
    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
                 pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
        # build the analysis window once and wrap it without gradient tracking
        if wkwargs is None:
            raw_window = window(ws)
        else:
            raw_window = window(ws, **wkwargs)
        self.window = Variable(raw_window, requires_grad=False)

        if hop is None:
            hop = ws // 2

        self.sr = sr
        self.ws = ws
        self.hop = hop
        self.n_fft = n_fft  # fft size handed to SPECTROGRAM
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.wkwargs = wkwargs
        # fixed settings for the dB conversion and the mel frequency range
        self.top_db = -80.
        self.f_max = None
        self.f_min = 0.

    def __call__(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (channels [c], samples [n])

        Returns:
            spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
                is unchanged, hops is the number of hops, and n_mels is the
                number of mel bins.

        """
        sig, is_variable = _check_is_variable(sig)

        # power spectrogram -> mel filter bank -> decibel scale
        stages = [
            SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
                        self.pad, self.window),
            F2M(self.n_mels, self.sr, self.f_max, self.f_min),
            SPEC2DB("power", self.top_db),
        ]
        pipeline = Compose(stages)
        spec_mel_db = pipeline(sig)

        if is_variable:
            return spec_mel_db
        return spec_mel_db.data
334+
335+
133336
class MEL(object):
134337
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.
135338
@@ -144,14 +347,15 @@ def __call__(self, tensor):
144347
"""
145348
146349
Args:
147-
tensor (Tensor): Tensor of audio of size (samples x channels)
350+
tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])
148351
149352
Returns:
150353
tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
151354
the number of mel bins, hops is the number of hops, and channels
152355
is unchanged.
153356
154357
"""
358+
155359
if librosa is None:
156360
print("librosa not installed, cannot create spectrograms")
157361
return tensor

0 commit comments

Comments
 (0)