16 changes: 4 additions & 12 deletions test/test_transforms.py
@@ -10,9 +10,8 @@ class Tester(unittest.TestCase):

sr = 16000
freq = 440
volume = .3
volume = 0.3
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
# sig = (torch.cos((1+torch.arange(0, 4 * sr) * 2) / sr * 2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
sig.unsqueeze_(1)
sig = (sig * volume * 2**31).long()

@@ -75,11 +74,11 @@ def test_mel(self):

audio = self.sig.clone()
audio = transforms.Scale()(audio)
self.assertTrue(audio.dim() == 2)
self.assertTrue(len(audio.size()) == 2)
result = transforms.MEL()(audio)
self.assertTrue(result.dim() == 3)
self.assertTrue(len(result.size()) == 3)
result = transforms.BLC2CBL()(result)
self.assertTrue(result.dim() == 3)
self.assertTrue(len(result.size()) == 3)

def test_compose(self):

@@ -122,13 +121,6 @@ def test_mu_law_companding(self):
sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
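Note: the round trip checked above follows the standard mu-law companding formulas; a minimal NumPy sketch for reference (mu = quantization_channels - 1 is an assumption matching the usual convention):

import numpy as np

mu = 255.0                                                     # quantization_channels - 1
x = np.linspace(-1.0, 1.0, 5)                                  # signal in [-1, 1]
y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)       # compand
x_back = np.sign(y) * np.expm1(np.abs(y) * np.log1p(mu)) / mu  # expand
assert np.allclose(x, x_back)                                  # lossless before quantization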

def test_mel2(self):
audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
spectrogram_torch = transforms.MEL2()(audio_scaled) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.max() <= 0.)

if __name__ == '__main__':
unittest.main()
210 changes: 3 additions & 207 deletions torchaudio/transforms.py
@@ -1,31 +1,12 @@
from __future__ import division, print_function
import torch
from torch.autograd import Variable
import numpy as np
try:
import librosa
except ImportError:
librosa = None


def _check_is_variable(tensor):
if isinstance(tensor, torch.Tensor):
is_variable = False
tensor = Variable(tensor, requires_grad=False)
elif isinstance(tensor, Variable):
is_variable = True
else:
raise TypeError("tensor should be a Variable or Tensor, but is {}".format(type(tensor)))

return tensor, is_variable


def _tlog10(x):
"""Pytorch Log10
"""
return torch.log(x) / torch.log(x.new([10]))


class Compose(object):
"""Composes several transforms together.

@@ -139,200 +120,16 @@ def __call__(self, tensor):
"""

Args:
tensor (Tensor): Tensor of audio signal with shape (LxC)
tensor (Tensor): Tensor of spectrogram with shape (BxLxC)

Returns:
tensor (Tensor): Tensor of audio signal with shape (CxL)
tensor (Tensor): Tensor of spectrogram with shape (CxBxL)

"""

return tensor.transpose(0, 1).contiguous()


class SPECTROGRAM(object):
"""Create a spectrogram from a raw audio signal

Args:
sr (int): sample rate of audio signal
ws (int): window size, often called the fft size as well
hop (int, optional): length of hop between STFT windows. default: ws // 2
n_fft (int, optional): number of fft bins. default: ws // 2 + 1
pad (int): two sided padding of signal
window (torch windowing function): default: torch.hann_window
wkwargs (dict, optional): arguments for window function

"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, window=torch.hann_window, wkwargs=None):
if isinstance(window, Variable):
self.window = window
else:
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.window = Variable(self.window, volatile=True)
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
self.n_fft = n_fft # number of fft bins
self.pad = pad
self.wkwargs = wkwargs

def __call__(self, sig):
"""
Args:
sig (Tensor or Variable): Tensor of audio of size (c, n)

Returns:
spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
is unchanged, hops is the number of hops, and n_fft is the
number of fourier bins, which should be the window size divided
by 2 plus 1.

"""
sig, is_variable = _check_is_variable(sig)

assert sig.dim() == 2

spec_f = torch.stft(sig, self.ws, self.hop, self.n_fft,
True, self.window, self.pad) # (c, l, n_fft, 2)
spec_f /= self.window.pow(2).sum().sqrt()
spec_f = spec_f.pow(2).sum(-1) # get power of "complex" tensor (c, l, n_fft)
return spec_f if is_variable else spec_f.data
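For reference, a minimal sketch of the same power-spectrogram computation against the current torch.stft keyword API (return_complex and the (c, freq, frames) output layout are assumptions beyond the code above, which used the older positional signature and a (c, l, f) layout):

import torch

sig = torch.randn(1, 16000)                      # (channels, samples)
ws, hop = 400, 200
window = torch.hann_window(ws)
spec = torch.stft(sig, n_fft=ws, hop_length=hop, window=window,
                  return_complex=True)           # (c, ws // 2 + 1, frames)
power = spec.abs().pow(2) / window.pow(2).sum()  # window-normalized power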


class F2M(object):
"""This turns a normal STFT into a MEL Frequency STFT, using a conversion
matrix. This uses triangular filter banks.

Args:
n_mels (int): number of MEL bins
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2
f_min (float): minimum frequency. default: 0
"""
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
self.n_mels = n_mels
self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min

def __call__(self, spec_f):

spec_f, is_variable = _check_is_variable(spec_f)
n_fft = spec_f.size(2)

m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
m_max = 2595 * np.log10(1. + (self.f_max / 700))

m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
f_pts = (700 * (10**(m_pts / 2595) - 1))

bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()

fb = torch.zeros(n_fft, self.n_mels)
for m in range(1, self.n_mels + 1):
f_m_minus = bins[m - 1]
f_m = bins[m]
f_m_plus = bins[m + 1]

if f_m_minus != f_m:
fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
if f_m != f_m_plus:
fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)

fb = Variable(fb)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m if is_variable else spec_m.data
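The hard-coded constants above are the HTK mel formula, m = 2595 * log10(1 + f / 700); a minimal sketch of the conversion pair for reference:

import numpy as np

def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

print(hz_to_mel(1000.0))            # ~1000 mel
print(mel_to_hz(hz_to_mel(440.0)))  # round trip -> 440.0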


class SPEC2DB(object):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.

Args:
stype (str): scale of input spectrogram ("power" or "magnitude"). The
power being the elementwise square of the magnitude. default: "power"
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is -80.
"""
def __init__(self, stype="power", top_db=None):
self.stype = stype
self.top_db = -top_db if top_db > 0 else top_db
self.multiplier = 10. if stype == "power" else 20.

def __call__(self, spec):

spec, is_variable = _check_is_variable(spec)
spec_db = self.multiplier * _tlog10(spec / spec.max()) # power -> dB
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
return spec_db if is_variable else spec_db.data
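A standalone sketch of the mapping implemented above, 10 * log10(S / S.max()) floored at top_db below the peak; torch.clamp stands in for the torch.max-against-a-new-tensor trick:

import torch

spec = torch.rand(1, 100, 201) + 1e-10           # stand-in power spectrogram
top_db = -80.0
spec_db = 10.0 * torch.log10(spec / spec.max())  # peak lands at 0 dB
spec_db = torch.clamp(spec_db, min=top_db)       # floor at -80 dB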


class MEL2(object):
"""Create MEL Spectrograms from a raw audio signal using the stft
function in PyTorch. Hopefully this solves the speed issue of using
librosa.

Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
* https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
* http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

Args:
sr (int): sample rate of audio signal
ws (int): window size, often called the fft size as well
hop (int, optional): length of hop between STFT windows. default: ws // 2
n_fft (int, optional): number of fft bins. default: ws // 2 + 1
pad (int): two sided padding of signal
n_mels (int): number of MEL bins
window (torch windowing function): default: torch.hann_window
wkwargs (dict, optional): arguments for window function

Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> sig = transforms.LC2CL()(sig) # (n, c) -> (c, n)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.window = Variable(self.window, requires_grad=False)
self.sr = sr
self.ws = ws
self.hop = hop if hop is not None else ws // 2
self.n_fft = n_fft # number of fourier bins (ws // 2 + 1 by default)
self.pad = pad
self.n_mels = n_mels # number of mel frequency bins
self.wkwargs = wkwargs
self.top_db = -80.
self.f_max = None
self.f_min = 0.

def __call__(self, sig):
"""
Args:
sig (Tensor): Tensor of audio of size (channels [c], samples [n])

Returns:
spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
is unchanged, hops is the number of hops, and n_mels is the
number of mel bins.

"""

sig, is_variable = _check_is_variable(sig)

transforms = Compose([
SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window),
F2M(self.n_mels, self.sr, self.f_max, self.f_min),
SPEC2DB("power", self.top_db),
])

spec_mel_db = transforms(sig)

return spec_mel_db if is_variable else spec_mel_db.data
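Spelled out, MEL2 is just the three transforms defined above composed in order; a sketch of the equivalent call, assuming sig is a (c, n) float signal and the classes from this file are in scope:

pipeline = Compose([
    SPECTROGRAM(sr=16000, ws=400),  # (c, n) -> (c, l, n_fft), power
    F2M(n_mels=40, sr=16000),       # (c, l, n_fft) -> (c, l, 40)
    SPEC2DB("power", top_db=-80.),  # power -> dB, floored at -80
])
spec_mel_db = pipeline(sig)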


class MEL(object):
"""Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.

@@ -347,15 +144,14 @@ def __call__(self, tensor):
"""

Args:
tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])
tensor (Tensor): Tensor of audio of size (samples x channels)

Returns:
tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
the number of mel bins, hops is the number of hops, and channels
is unchanged.

"""

if librosa is None:
print("librosa not installed, cannot create spectrograms")
return tensor