Skip to content
Merged
7 changes: 7 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`GriffinLim`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GriffinLim

.. automethod:: forward

:hidden:`AmplitudeToDB`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
41 changes: 41 additions & 0 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,47 @@ def test_torchscript_spectrogram(self):
F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize
)

def test_torchscript_griffinlim(self):
    """Exercise ``F.griffinlim`` through the ``_test_torchscript_functional`` helper."""
    # (channel, freq, frames) with freq == n_fft // 2 + 1
    tensor = torch.rand((1, 201, 6))
    n_fft = 400
    ws = 400
    hop = 200
    window = torch.hann_window(ws)
    power = 2
    normalize = False
    momentum = 0.99
    n_iter = 32
    length = 1000
    # ``rand_init`` is declared as a bool in the griffinlim signature, so pass
    # a real bool (zero-phase initialisation) instead of the integer 0.
    rand_init = False

    _test_torchscript_functional(
        F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init
    )

@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa not available')
def test_griffinlim(self):
    """Compare ``F.griffinlim`` against ``librosa.griffinlim`` on a random waveform."""
    waveform = torch.rand((1, 1000))

    n_fft = 400
    win_length = 400
    hop_length = 100
    window = torch.hann_window(win_length)
    normalized = False
    momentum = 0.99
    n_iter = 8
    length = 1000
    rand_init = False
    # librosa selects the phase initialisation with 'random' / None rather than a bool
    if rand_init:
        librosa_init = 'random'
    else:
        librosa_init = None

    # Magnitude spectrogram: square root of the exponent-2 spectrogram
    squared = F.spectrogram(waveform, 0, window, n_fft, hop_length, win_length, 2, normalized)
    magnitude = squared.sqrt()

    actual = F.griffinlim(magnitude, window, n_fft, hop_length, win_length, 1, normalized,
                          n_iter, momentum, length, rand_init)

    reference = librosa.griffinlim(magnitude.squeeze(0).numpy(), n_iter=n_iter, hop_length=hop_length,
                                   momentum=momentum, init=librosa_init, length=length)
    # Restore the channel dimension so both outputs have the same shape
    reference = torch.from_numpy(reference).unsqueeze(0)

    self.assertTrue(torch.allclose(actual, reference, atol=5e-5))

def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
Expand Down
4 changes: 4 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def test_scriptmodule_Spectrogram(self):
tensor = torch.rand((1, 1000))
_test_script_module(transforms.Spectrogram, tensor)

def test_scriptmodule_GriffinLim(self):
    """Hand ``transforms.GriffinLim`` and a random input to ``_test_script_module``."""
    # (channel, freq, frames) with freq == 400 // 2 + 1
    spec = torch.rand((1, 201, 6))
    module_kwargs = {'length': 1000, 'rand_init': False}
    _test_script_module(transforms.GriffinLim, spec, **module_kwargs)

def test_mu_law_companding(self):

quantization_channels = 256
Expand Down
101 changes: 95 additions & 6 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
__all__ = [
"istft",
"spectrogram",
"griffinlim",
"amplitude_to_DB",
"create_fb_matrix",
"create_dct",
Expand Down Expand Up @@ -118,15 +119,15 @@ def istft(
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert stft_matrix.nelement() > 0
assert stft_matrix.numel() > 0

if stft_matrix_dim == 3:
# add a channel dimension
stft_matrix = stft_matrix.unsqueeze(0)

# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.reshape(-1, *shape[-3:])
stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])

dtype = stft_matrix.dtype
device = stft_matrix.device
Expand All @@ -151,7 +152,8 @@ def istft(
assert 0 < win_length <= n_fft

if window is None:
window = torch.ones(win_length, requires_grad=False, device=device, dtype=dtype)
# Tensor.to() is not in-place: rebind the result so the default window really
# ends up on the same device and dtype as the input.
window = torch.ones(win_length)
window = window.to(device=device, dtype=dtype)

assert window.dim() == 1 and window.size(0) == win_length

Expand All @@ -174,9 +176,8 @@ def istft(
# each column of a channel is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame)

eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze(
1
) # size (n_fft, 1, n_fft)
eye = torch.eye(n_fft)
eye = eye.to(device=device, dtype=dtype).unsqueeze(1) # size (n_fft, 1, n_fft)

# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
Expand Down Expand Up @@ -270,6 +271,94 @@ def spectrogram(
return spec_f


def griffinlim(
        spectrogram, window, n_fft, hop_length, win_length, power, normalized, n_iter, momentum, length, rand_init
):
    # type: (Tensor, Tensor, int, int, int, int, bool, int, float, Optional[int], bool) -> Tensor
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
    Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    All arguments are required; this function defines no default values.

    Args:
        spectrogram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames)
            where freq is ``n_fft // 2 + 1``.
        window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows.
        win_length (int): Window size.
        power (int): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
        normalized (bool): Whether to normalize by magnitude after stft.
            NOTE: accepted for signature compatibility but not read by this implementation.
        n_iter (int): Number of iteration for phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim, in ``[0, 1)``.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge.
        length (Optional[int]): Array length of the expected output, or ``None``.
        rand_init (bool): Initializes phase randomly if True, to zero otherwise.

    Returns:
        torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given.

    Raises:
        AssertionError: If ``momentum`` is negative or not strictly below 1.
    """
    assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
    # momentum == 0 is valid: it disables the acceleration term entirely.
    assert momentum >= 0, 'momentum=%s < 0' % momentum

    # Undo the spectrogram exponent to get back to magnitudes
    spectrogram = spectrogram.pow(1 / power)

    # Initialize the phase, either uniformly at random in [0, 2*pi) or to zero
    batch, freq, frames = spectrogram.size()
    if rand_init:
        angles = 2 * math.pi * torch.rand(batch, freq, frames)
    else:
        angles = torch.zeros(batch, freq, frames)
    # Represent the unit-modulus phase as (real, imag) pairs in a trailing dimension of size 2
    angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
        .to(dtype=spectrogram.dtype, device=spectrogram.device)
    spectrogram = spectrogram.unsqueeze(-1).expand_as(angles)

    # And initialize the previous iterate to 0
    rebuilt = torch.tensor(0.)

    # Weight applied to the previous iterate; it does not change between iterations
    tprev_weight = momentum / (1 + momentum)

    for _ in range(n_iter):
        # Store the previous iterate
        tprev = rebuilt

        # Invert with our current estimate of the phases
        inverse = istft(spectrogram * angles,
                        n_fft=n_fft,
                        hop_length=hop_length,
                        win_length=win_length,
                        window=window,
                        length=length).float()

        # Rebuild the spectrogram
        rebuilt = _stft(inverse, n_fft, hop_length, win_length, window,
                        True, 'reflect', False, True)

        # Update our phase estimates. The in-place scaling of ``tprev`` is safe:
        # ``rebuilt`` was just rebound to a new tensor and ``tprev`` is not reused.
        angles = rebuilt - tprev.mul_(tprev_weight)
        # Normalize to unit modulus; the epsilon guards against division by zero
        angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles))

    # Synthesize the waveform from the magnitudes and the final phase estimates
    return istft(spectrogram * angles,
                 n_fft=n_fft,
                 hop_length=hop_length,
                 win_length=win_length,
                 window=window,
                 length=length)


def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
r"""Turn a tensor from the power/amplitude scale to the decibel scale.
Expand Down
64 changes: 64 additions & 0 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

__all__ = [
'Spectrogram',
'GriffinLim',
'AmplitudeToDB',
'MelScale',
'MelSpectrogram',
Expand Down Expand Up @@ -70,6 +71,69 @@ def forward(self, waveform):
self.win_length, self.power, self.normalized)


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
    Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
        win_length (int): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows. (
            Default: ``win_length // 2``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (int): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
        momentum (float): The momentum parameter for fast Griffin-Lim, in ``[0, 1)``.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self, n_fft=400, n_iter=32, win_length=None, hop_length=None,
                 window_fn=torch.hann_window, power=2, normalized=False, wkwargs=None,
                 momentum=0.99, length=None, rand_init=True):
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
        # momentum == 0 is valid: it selects the original (non-accelerated) method.
        assert momentum >= 0, 'momentum=%s < 0' % momentum

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        # Registered as a buffer so the window follows the module across devices
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        # Store the raw momentum: F.griffinlim itself maps it to
        # momentum / (1 + momentum), so pre-scaling here would apply it twice.
        self.momentum = momentum
        self.rand_init = rand_init

    def forward(self, S):
        r"""
        Args:
            S (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.jit.ScriptModule):
r"""Turn a tensor from the power/amplitude scale to the decibel scale.

Expand Down