 from __future__ import division, print_function
 import torch
+from torch.autograd import Variable
 import numpy as np
 try:
     import librosa
 except ImportError:
     librosa = None


+def _check_is_variable(tensor):
+    if isinstance(tensor, torch.Tensor):
+        is_variable = False
+        tensor = Variable(tensor, requires_grad=False)
+    elif isinstance(tensor, Variable):
+        is_variable = True
+    else:
+        raise TypeError("tensor should be a Variable or Tensor, but is {}".format(type(tensor)))
+
+    return tensor, is_variable
+
+
+def _tlog10(x):
+    """Pytorch Log10
+    """
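+    # change of base: log10(x) = ln(x) / ln(10); x.new([10]) keeps the constant
+    # on the same tensor type as x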
+    return torch.log(x) / torch.log(x.new([10]))
+
+
 class Compose(object):
     """Composes several transforms together.

@@ -120,16 +139,200 @@ def __call__(self, tensor):
         """

         Args:
-            tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
+            tensor (Tensor): Tensor of audio signal with shape (LxC)

         Returns:
-            tensor (Tensor): Tensor of spectrogram with shape (CxBxL)
+            tensor (Tensor): Tensor of audio signal with shape (CxL)

         """

         return tensor.transpose(0, 1).contiguous()


+class SPECTROGRAM(object):
+    """Create a spectrogram from a raw audio signal
+
+    Args:
+        sr (int): sample rate of audio signal
+        ws (int): window size, often called the fft size as well
+        hop (int, optional): length of hop between STFT windows. default: ws // 2
+        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
+        pad (int): two sided padding of signal
+        window (torch windowing function): default: torch.hann_window
+        wkwargs (dict, optional): arguments for window function
+
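+    Example (illustrative sketch, mirroring the MEL2 example below; assumes a
+    wav file loadable with torchaudio.load):
+        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
+        >>> sig = transforms.LC2CL()(sig)  # (n, c) -> (c, n)
+        >>> spec_f = transforms.SPECTROGRAM(sr=sr, ws=400)(sig)  # (c, l, f) power spectrogram
+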
+    """
+    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
+                 pad=0, window=torch.hann_window, wkwargs=None):
+        if isinstance(window, Variable):
+            self.window = window
+        else:
+            self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
+            self.window = Variable(self.window, volatile=True)
+        self.sr = sr
+        self.ws = ws
+        self.hop = hop if hop is not None else ws // 2
+        self.n_fft = n_fft  # number of fft bins
+        self.pad = pad
+        self.wkwargs = wkwargs
+
+    def __call__(self, sig):
+        """
+        Args:
+            sig (Tensor or Variable): Tensor of audio of size (c, n)
+
+        Returns:
+            spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
+                is unchanged, hops is the number of hops, and n_fft is the
+                number of fourier bins, which should be the window size divided
+                by 2 plus 1.
+
+        """
+        sig, is_variable = _check_is_variable(sig)
+
+        assert sig.dim() == 2
+
+        spec_f = torch.stft(sig, self.ws, self.hop, self.n_fft,
+                            True, self.window, self.pad)  # (c, l, n_fft, 2)
+        spec_f /= self.window.pow(2).sum().sqrt()
+        spec_f = spec_f.pow(2).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
+        return spec_f if is_variable else spec_f.data
+
+
+class F2M(object):
+    """This turns a normal STFT into a MEL Frequency STFT, using a conversion
+    matrix.  This uses triangular filter banks.
+
+    Args:
+        n_mels (int): number of MEL bins
+        sr (int): sample rate of audio signal
+        f_max (float, optional): maximum frequency. default: sr // 2
+        f_min (float): minimum frequency. default: 0
+    """
+    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
+        self.n_mels = n_mels
+        self.sr = sr
+        self.f_max = f_max if f_max is not None else sr // 2
+        self.f_min = f_min
+
+    def __call__(self, spec_f):
+
+        spec_f, is_variable = _check_is_variable(spec_f)
+        n_fft = spec_f.size(2)
+
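+        # Hz -> mel uses m = 2595 * log10(1 + f / 700); the inverse below maps
+        # the evenly spaced mel points back to Hz for the filter bank edges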
+        m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
+        m_max = 2595 * np.log10(1. + (self.f_max / 700))
+
+        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
+        f_pts = (700 * (10 ** (m_pts / 2595) - 1))
+
+        bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()
+
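+        # each mel bin m gets a triangular filter: rising from bins[m - 1] to
+        # bins[m], falling from bins[m] to bins[m + 1]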
+        fb = torch.zeros(n_fft, self.n_mels)
+        for m in range(1, self.n_mels + 1):
+            f_m_minus = bins[m - 1]
+            f_m = bins[m]
+            f_m_plus = bins[m + 1]
+
+            if f_m_minus != f_m:
+                fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
+            if f_m != f_m_plus:
+                fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
+
+        fb = Variable(fb)
+        spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
+        return spec_m if is_variable else spec_m.data
+
+
+class SPEC2DB(object):
+    """Turns a spectrogram from the power/amplitude scale to the decibel scale.
+
+    Args:
+        stype (str): scale of input spectrogram ("power" or "magnitude").  The
+            power being the elementwise square of the magnitude. default: "power"
+        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
+            is -80.
+    """
+    def __init__(self, stype="power", top_db=None):
+        self.stype = stype
+        # guard against the documented default of None before comparing;
+        # a positive top_db is negated so it acts as a lower cut-off
+        self.top_db = -top_db if top_db is not None and top_db > 0 else top_db
+        self.multiplier = 10. if stype == "power" else 20.
+
+    def __call__(self, spec):
+
+        spec, is_variable = _check_is_variable(spec)
+        spec_db = self.multiplier * _tlog10(spec / spec.max())  # power -> dB
+        if self.top_db is not None:
+            spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
+        return spec_db if is_variable else spec_db.data
+
+
+class MEL2(object):
+    """Create MEL Spectrograms from a raw audio signal using the stft
+    function in PyTorch.  Hopefully this solves the speed issue of using
+    librosa.
+
+    Sources:
+        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
+        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
+        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
+
+    Args:
+        sr (int): sample rate of audio signal
+        ws (int): window size, often called the fft size as well
+        hop (int, optional): length of hop between STFT windows. default: ws // 2
+        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
+        pad (int): two sided padding of signal
+        n_mels (int): number of MEL bins
+        window (torch windowing function): default: torch.hann_window
+        wkwargs (dict, optional): arguments for window function
+
+    Example:
+        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
+        >>> sig = transforms.LC2CL()(sig)  # (n, c) -> (c, n)
+        >>> spec_mel = transforms.MEL2(sr)(sig)  # (c, l, m)
+    """
+    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
+                 pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
+        self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
+        self.window = Variable(self.window, requires_grad=False)
+        self.sr = sr
+        self.ws = ws
+        self.hop = hop if hop is not None else ws // 2
+        self.n_fft = n_fft  # number of fourier bins (ws // 2 + 1 by default)
+        self.pad = pad
+        self.n_mels = n_mels  # number of mel frequency bins
+        self.wkwargs = wkwargs
+        self.top_db = -80.
+        self.f_max = None
+        self.f_min = 0.
+
+    def __call__(self, sig):
+        """
+        Args:
+            sig (Tensor): Tensor of audio of size (channels [c], samples [n])
+
+        Returns:
+            spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
+                is unchanged, hops is the number of hops, and n_mels is the
+                number of mel bins.
+
+        """
+
+        sig, is_variable = _check_is_variable(sig)
+
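+        # full pipeline: power spectrogram -> mel filter banks -> decibel scale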
+        transforms = Compose([
+            SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
+                        self.pad, self.window),
+            F2M(self.n_mels, self.sr, self.f_max, self.f_min),
+            SPEC2DB("power", self.top_db),
+        ])
+
+        spec_mel_db = transforms(sig)
+
+        return spec_mel_db if is_variable else spec_mel_db.data
+
+
 class MEL(object):
     """Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.

@@ -144,14 +347,15 @@ def __call__(self, tensor):
         """

         Args:
-            tensor (Tensor): Tensor of audio of size (samples x channels)
+            tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])

         Returns:
             tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
                 the number of mel bins, hops is the number of hops, and channels
                 is unchanged.

         """
+
         if librosa is None:
             print("librosa not installed, cannot create spectrograms")
             return tensor