From c9a0c1f5f975903cb92801cf176fe9200c1ad6c3 Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sat, 14 Dec 2019 02:02:59 +0900 Subject: [PATCH 01/11] Griffin-Lim Transformation Implementation --- torchaudio/transforms.py | 112 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index ed0fcece25..dabc00c22f 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -9,6 +9,7 @@ __all__ = [ 'Spectrogram', + 'GriffinLim', 'AmplitudeToDB', 'MelScale', 'MelSpectrogram', @@ -70,6 +71,117 @@ def forward(self, waveform): self.win_length, self.power, self.normalized) +class GriffinLim(torch.nn.Module): + r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. + Implementation ported from `librosa`. + + .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto. + "librosa: Audio and music signal analysis in python." + In Proceedings of the 14th python in science conference, pp. 18-25. 2015. + + .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L. + "A fast Griffin-Lim algorithm," + IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4), + Oct. 2013. + + .. [3] D. W. Griffin and J. S. Lim, + "Signal estimation from modified short-time Fourier transform," + IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984. + + Args: + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins + win_length (int): Window size. (Default: ``n_fft``) + hop_length (int, optional): Length of hop between STFT windows. ( + Default: ``win_length // 2``) + pad (int): Two sided padding of signal. (Default: ``0``) + window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + power (int): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) + normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``) + wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``) + length (int): Array length of the expected output. (Default: ``None``) + momentum (float): The momentum parameter for fast Griffin-Lim. + Setting this to 0 recovers the original Griffin-Lim method. + Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) + """ + __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized', + 'length', 'momentum'] + + def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, + window_fn=torch.hann_window, wkwargs=None, normalized=False, + power=2., length=None, momentum=0.99): + super(GriffinLim, self).__init__() + + assert momentum < 1, f'momentum={momentum} > 1 can be unstable' + assert momentum > 0, f'momentum={momentum} < 0' + + self.n_fft = n_fft + self.n_iter = n_iter + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + self.window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.normalized = normalized + self.length = length + self.power = power + self.momentum = momentum / (1 + momentum) + + def forward(self, S): + r""" + Args: + S (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames), + where freq is ``n_fft // 2 + 1``. + + Returns: + torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given. + """ + self.window = self.window.to(dtype=S.dtype, device=S.device) + + S = S.pow(1/self.power) + if self.normalized: + S *= self.window.pow(2).sum().sqrt() + + # randomly initialize the phase + batch, freq, frames = S.size() + angles = 2 * math.pi * torch.rand(batch, freq, frames) + angles = torch.stack([angles.cos(), angles.sin()], dim=-1).to(dtype=S.dtype, device=S.device) + S = S.unsqueeze(-1).expand_as(angles) + + # And initialize the previous iterate to 0 + rebuilt = 0. + + for _ in range(self.n_iter): + # Store the previous iterate + tprev = rebuilt + + # Invert with our current estimate of the phases + inverse = F.istft(S * angles, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + length=self.length).float() + + # Rebuild the spectrogram + rebuilt = inverse.stft(n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + pad_mode='reflect') + + # Update our phase estimates + angles = rebuilt.sub(self.momentum).mul_(tprev) + angles = angles.div_(F.complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles)) + + # Return the final phase estimates + return F.istft(S * angles, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + length=self.length) + + class AmplitudeToDB(torch.jit.ScriptModule): r"""Turn a tensor from the power/amplitude scale to the decibel scale. From aa350155d44e8bfc4bf45a3c54735803acfd2d28 Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sat, 14 Dec 2019 03:07:50 +0900 Subject: [PATCH 02/11] Griffin-Lim Docs --- docs/source/transforms.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3ad461a6b6..b460115a88 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -16,6 +16,13 @@ Transforms are common audio transforms. They can be chained together using :clas .. automethod:: forward +:hidden:`GriffinLim` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: GriffinLim + + .. automethod:: forward + :hidden:`AmplitudeToDB` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 3efca154847f28faa67dbe47689e022c63ff857b Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sat, 14 Dec 2019 04:36:22 +0900 Subject: [PATCH 03/11] Remove f-string from backwards compatibility --- torchaudio/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index dabc00c22f..67169769a7 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -113,8 +113,8 @@ def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, power=2., length=None, momentum=0.99): super(GriffinLim, self).__init__() - assert momentum < 1, f'momentum={momentum} > 1 can be unstable' - assert momentum > 0, f'momentum={momentum} < 0' + assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum + assert momentum > 0, 'momentum=%s < 0' % momentum self.n_fft = n_fft self.n_iter = n_iter From 778d6d5b39def74b5b054442c4d6dbe5e96857e6 Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sun, 15 Dec 2019 01:08:17 +0900 Subject: [PATCH 04/11] iSTFT is now jit-able. --- torchaudio/functional.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 77908ee927..6190956d09 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -118,7 +118,7 @@ def istft( """ stft_matrix_dim = stft_matrix.dim() assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) - assert stft_matrix.nelement() > 0 + assert stft_matrix.numel() > 0 if stft_matrix_dim == 3: # add a channel dimension @@ -126,7 +126,7 @@ def istft( # pack batch shape = stft_matrix.size() - stft_matrix = stft_matrix.reshape(-1, *shape[-3:]) + stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1]) # *shape[-3:]) dtype = stft_matrix.dtype device = stft_matrix.device @@ -151,7 +151,9 @@ def istft( assert 0 < win_length <= n_fft if window is None: - window = torch.ones(win_length, requires_grad=False, device=device, dtype=dtype) + window = torch.ones(win_length) + window.to(device=device, dtype=dtype) + # window.requires_grad_(False) assert window.dim() == 1 and window.size(0) == win_length @@ -174,9 +176,8 @@ def istft( # each column of a channel is a frame which needs to be overlap added at the right place ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame) - eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze( - 1 - ) # size (n_fft, 1, n_fft) + eye = torch.eye(n_fft) # requires_grad=False + eye = eye.to(device=device, dtype=dtype).unsqueeze(1) # size (n_fft, 1, n_fft) # this does overlap add where the frames of ytmp are added such that the i'th frame of # ytmp is added starting at i*hop_length in the output From d5c8b5d7198ad7766850f9924494d35c3554d6d2 Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sun, 15 Dec 2019 01:09:14 +0900 Subject: [PATCH 05/11] Comment changes --- torchaudio/functional.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 6190956d09..2658a73c92 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -153,7 +153,6 @@ def istft( if window is None: window = torch.ones(win_length) window.to(device=device, dtype=dtype) - # window.requires_grad_(False) assert window.dim() == 1 and window.size(0) == win_length @@ -176,7 +175,7 @@ def istft( # each column of a channel is a frame which needs to be overlap added at the right place ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame) - eye = torch.eye(n_fft) # requires_grad=False + eye = torch.eye(n_fft) eye = eye.to(device=device, dtype=dtype).unsqueeze(1) # size (n_fft, 1, n_fft) # this does overlap add where the frames of ytmp are added such that the i'th frame of From 14a37a2e5f0456de837a2ca8a3c6ac1752d4c8a9 Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sun, 15 Dec 2019 02:31:31 +0900 Subject: [PATCH 06/11] Functional Implementation & now jitable --- test/test_functional.py | 17 ++++++++ test/test_transforms.py | 4 ++ torchaudio/functional.py | 91 ++++++++++++++++++++++++++++++++++++++++ torchaudio/transforms.py | 67 ++++------------------------- 4 files changed, 121 insertions(+), 58 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 0867600489..ff4f98c982 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -52,6 +52,23 @@ def test_torchscript_spectrogram(self): F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize ) + def test_torchscript_griffinlim(self): + tensor = torch.rand((1, 201, 6)) + n_fft = 400 + ws = 400 + hop = 200 + window = torch.hann_window(ws) + power = 2 + normalize = False + momentum = 0.99 + n_iter = 32 + length = 1000 + rand_init = False + + _test_torchscript_functional( + F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init + ) + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) diff --git a/test/test_transforms.py b/test/test_transforms.py index c22169741f..7a00526490 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -67,6 +67,10 @@ def test_scriptmodule_Spectrogram(self): tensor = torch.rand((1, 1000)) _test_script_module(transforms.Spectrogram, tensor) + def test_scriptmodule_GriffinLim(self): + tensor = torch.rand((1, 201, 6)) + _test_script_module(transforms.GriffinLim, tensor, length=1000, rand_init=False) + def test_mu_law_companding(self): quantization_channels = 256 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 2658a73c92..66f442f1b0 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -7,6 +7,7 @@ __all__ = [ "istft", "spectrogram", + "griffinlim", "amplitude_to_DB", "create_fb_matrix", "create_dct", @@ -270,6 +271,96 @@ def spectrogram( return spec_f +def griffinlim( + spectrogram, window, n_fft, hop_length, win_length, power, normalized, n_iter, momentum, length, rand_init +): + # type: (Tensor, Tensor, int, int, int, int, bool, int, float, Optional[int], bool) -> Tensor + r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. + Implementation ported from `librosa`. + + .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto. + "librosa: Audio and music signal analysis in python." + In Proceedings of the 14th python in science conference, pp. 18-25. 2015. + + .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L. + "A fast Griffin-Lim algorithm," + IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4), + Oct. 2013. + + .. [3] D. W. Griffin and J. S. Lim, + "Signal estimation from modified short-time Fourier transform," + IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984. + + Args: + spectrogram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames) + where freq is ``n_fft // 2 + 1``. + window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window + n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins + hop_length (int): Length of hop between STFT windows. ( + Default: ``win_length // 2``) + win_length (int): Window size. (Default: ``n_fft``) + power (int): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) + normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``) + n_iter (int): Number of iteration for phase recovery process. + momentum (float): The momentum parameter for fast Griffin-Lim. + Setting this to 0 recovers the original Griffin-Lim method. + Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) + length (Optional[int]): Array length of the expected output. (Default: ``None``) + rand_init (bool): Initializes phase randomly if true and to zero otherwise. + + Returns: + torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given. + """ + assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum + assert momentum > 0, 'momentum=%s < 0' % momentum + + spectrogram = spectrogram.pow(1/power) + if normalized: + spectrogram *= window.pow(2).sum().sqrt() + + # randomly initialize the phase + batch, freq, frames = spectrogram.size() + if rand_init: + angles = 2 * math.pi * torch.rand(batch, freq, frames) + else: + angles = torch.zeros(batch, freq, frames) + angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \ + .to(dtype=spectrogram.dtype, device=spectrogram.device) + spectrogram = spectrogram.unsqueeze(-1).expand_as(angles) + + # And initialize the previous iterate to 0 + rebuilt = torch.tensor(0.) + + for _ in range(n_iter): + # Store the previous iterate + tprev = rebuilt + + # Invert with our current estimate of the phases + inverse = istft(spectrogram * angles, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=length).float() + + # Rebuild the spectrogram + rebuilt = _stft(inverse, n_fft, hop_length, win_length, window, + True, 'reflect', False, True) + + # Update our phase estimates + angles = rebuilt.sub(momentum).mul_(tprev) + angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles)) + + # Return the final phase estimates + return istft(spectrogram * angles, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=length) + + def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): # type: (Tensor, float, float, float, Optional[float]) -> Tensor r"""Turn a tensor from the power/amplitude scale to the decibel scale. diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 67169769a7..f7d1e4e8eb 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -90,27 +90,28 @@ class GriffinLim(torch.nn.Module): Args: n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins + n_iter (int, optional): Number of iteration for phase recovery process. win_length (int): Window size. (Default: ``n_fft``) hop_length (int, optional): Length of hop between STFT windows. ( Default: ``win_length // 2``) - pad (int): Two sided padding of signal. (Default: ``0``) window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) power (int): Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``) wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``) - length (int): Array length of the expected output. (Default: ``None``) momentum (float): The momentum parameter for fast Griffin-Lim. - Setting this to 0 recovers the original Griffin-Lim method. - Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) + Setting this to 0 recovers the original Griffin-Lim method. + Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) + length (int, optional): Array length of the expected output. (Default: ``None``) + rand_init(bool): """ __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized', 'length', 'momentum'] def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, window_fn=torch.hann_window, wkwargs=None, normalized=False, - power=2., length=None, momentum=0.99): + power=2, length=None, momentum=0.99, rand_init=True): super(GriffinLim, self).__init__() assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum @@ -125,61 +126,11 @@ def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, self.length = length self.power = power self.momentum = momentum / (1 + momentum) + self.rand_init = rand_init def forward(self, S): - r""" - Args: - S (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames), - where freq is ``n_fft // 2 + 1``. - - Returns: - torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given. - """ - self.window = self.window.to(dtype=S.dtype, device=S.device) - - S = S.pow(1/self.power) - if self.normalized: - S *= self.window.pow(2).sum().sqrt() - - # randomly initialize the phase - batch, freq, frames = S.size() - angles = 2 * math.pi * torch.rand(batch, freq, frames) - angles = torch.stack([angles.cos(), angles.sin()], dim=-1).to(dtype=S.dtype, device=S.device) - S = S.unsqueeze(-1).expand_as(angles) - - # And initialize the previous iterate to 0 - rebuilt = 0. - - for _ in range(self.n_iter): - # Store the previous iterate - tprev = rebuilt - - # Invert with our current estimate of the phases - inverse = F.istft(S * angles, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - length=self.length).float() - - # Rebuild the spectrogram - rebuilt = inverse.stft(n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - pad_mode='reflect') - - # Update our phase estimates - angles = rebuilt.sub(self.momentum).mul_(tprev) - angles = angles.div_(F.complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles)) - - # Return the final phase estimates - return F.istft(S * angles, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - length=self.length) + return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, + self.power, self.normalized, self.n_iter, self.momentum, self.length, self.rand_init) class AmplitudeToDB(torch.jit.ScriptModule): From 2d03b1172945571c0d8a177737c2e4ab2ff7c38e Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sun, 15 Dec 2019 13:30:33 +0900 Subject: [PATCH 07/11] flake8 --- torchaudio/functional.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 66f442f1b0..edd0971e3b 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -127,7 +127,7 @@ def istft( # pack batch shape = stft_matrix.size() - stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1]) # *shape[-3:]) + stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1]) dtype = stft_matrix.dtype device = stft_matrix.device @@ -152,8 +152,8 @@ def istft( assert 0 < win_length <= n_fft if window is None: - window = torch.ones(win_length) - window.to(device=device, dtype=dtype) + window = torch.ones(win_length) + window.to(device=device, dtype=dtype) assert window.dim() == 1 and window.size(0) == win_length @@ -273,8 +273,8 @@ def spectrogram( def griffinlim( spectrogram, window, n_fft, hop_length, win_length, power, normalized, n_iter, momentum, length, rand_init -): - # type: (Tensor, Tensor, int, int, int, int, bool, int, float, Optional[int], bool) -> Tensor +): + # type: (Tensor, Tensor, int, int, int, int, bool, int, float, Optional[int], bool) -> Tensor r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. Implementation ported from `librosa`. @@ -305,7 +305,7 @@ def griffinlim( n_iter (int): Number of iteration for phase recovery process. momentum (float): The momentum parameter for fast Griffin-Lim. Setting this to 0 recovers the original Griffin-Lim method. - Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) + Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) length (Optional[int]): Array length of the expected output. (Default: ``None``) rand_init (bool): Initializes phase randomly if true and to zero otherwise. @@ -315,7 +315,7 @@ def griffinlim( assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum assert momentum > 0, 'momentum=%s < 0' % momentum - spectrogram = spectrogram.pow(1/power) + spectrogram = spectrogram.pow(1 / power) if normalized: spectrogram *= window.pow(2).sum().sqrt() From b34e36f0b62ee0f0da5b7c14474b23d6ab8149c3 Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Sun, 15 Dec 2019 20:49:39 +0900 Subject: [PATCH 08/11] Doc & GPU Fix --- torchaudio/transforms.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index f7d1e4e8eb..48557f2cf4 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -104,7 +104,7 @@ class GriffinLim(torch.nn.Module): Setting this to 0 recovers the original Griffin-Lim method. Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) length (int, optional): Array length of the expected output. (Default: ``None``) - rand_init(bool): + rand_init(bool): Initializes phase randomly if true and to zero otherwise. """ __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized', 'length', 'momentum'] @@ -121,7 +121,8 @@ def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, self.n_iter = n_iter self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 - self.window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.register_buffer('window', window) self.normalized = normalized self.length = length self.power = power @@ -129,8 +130,8 @@ def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, self.rand_init = rand_init def forward(self, S): - return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, - self.power, self.normalized, self.n_iter, self.momentum, self.length, self.rand_init) + return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, self.power, + self.normalized, self.n_iter, self.momentum, self.length, self.rand_init) class AmplitudeToDB(torch.jit.ScriptModule): From cff6cfc5dcbcc63528b721dcef632d9e017997b5 Mon Sep 17 00:00:00 2001 From: "Charles J.Y. Yoon" <1242029+jaeyeun97@users.noreply.github.com> Date: Tue, 24 Dec 2019 19:24:00 +0900 Subject: [PATCH 09/11] Librosa comparison test --- test/test_functional.py | 31 +++++++++++++++++++++++++++++-- torchaudio/functional.py | 6 ++---- torchaudio/transforms.py | 4 ++-- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index ff4f98c982..1999883815 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -63,12 +63,39 @@ def test_torchscript_griffinlim(self): momentum = 0.99 n_iter = 32 length = 1000 - rand_init = False + init = 0 _test_torchscript_functional( - F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init + F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0 ) + @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa not available') + def test_griffinlim(self): + tensor = torch.rand((1, 1000)) + + n_fft = 400 + ws = 400 + hop = 100 + window = torch.hann_window(ws) + normalize = False + momentum = 0.99 + n_iter = 8 + length = 1000 + rand_init = False + init = 'random' if rand_init else None + + specgram = F.spectrogram(tensor, 0, window, n_fft, hop, ws, 2, normalize).sqrt() + ta_out = F.griffinlim(specgram, window, n_fft, hop, ws, 1, normalize, + n_iter, momentum, length, rand_init) + lr_out = librosa.griffinlim(specgram.squeeze(0).numpy(), n_iter=n_iter, hop_length=hop, + momentum=momentum, init=init, length=length) + lr_out = torch.from_numpy(lr_out).unsqueeze(0) + + ta_spec = F.spectrogram(ta_out, 0, window, n_fft, hop, ws, 2, normalize).sqrt() + lr_spec = F.spectrogram(lr_out, 0, window, n_fft, hop, ws, 2, normalize).sqrt() + + self.assertTrue(torch.allclose(ta_spec, lr_spec, atol=5e-3)) + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 67cf1e32b2..fe9ef1a7df 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -307,7 +307,7 @@ def griffinlim( Setting this to 0 recovers the original Griffin-Lim method. Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) length (Optional[int]): Array length of the expected output. (Default: ``None``) - rand_init (bool): Initializes phase randomly if true and to zero otherwise. + rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``) Returns: torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given. @@ -316,8 +316,6 @@ def griffinlim( assert momentum > 0, 'momentum=%s < 0' % momentum spectrogram = spectrogram.pow(1 / power) - if normalized: - spectrogram *= window.pow(2).sum().sqrt() # randomly initialize the phase batch, freq, frames = spectrogram.size() @@ -349,7 +347,7 @@ def griffinlim( True, 'reflect', False, True) # Update our phase estimates - angles = rebuilt.sub(momentum).mul_(tprev) + angles = rebuilt - tprev.mul_(momentum / (1 + momentum)) angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles)) # Return the final phase estimates diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 48557f2cf4..904df175a3 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -104,10 +104,10 @@ class GriffinLim(torch.nn.Module): Setting this to 0 recovers the original Griffin-Lim method. Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) length (int, optional): Array length of the expected output. (Default: ``None``) - rand_init(bool): Initializes phase randomly if true and to zero otherwise. + rand_init (bool): Initializes phase randomly if True and to zero otherwise. (Default: ``True``) """ __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized', - 'length', 'momentum'] + 'length', 'momentum', 'rand_init'] def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, window_fn=torch.hann_window, wkwargs=None, normalized=False, From 158f05adcdc2c799ed45efaa53b1c87d2e2a91d1 Mon Sep 17 00:00:00 2001 From: Vincent QB Date: Thu, 26 Dec 2019 10:21:33 -0500 Subject: [PATCH 10/11] test directly griffinlim's output. tighter atol. --- test/test_functional.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 1999883815..f461f90aec 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -91,10 +91,7 @@ def test_griffinlim(self): momentum=momentum, init=init, length=length) lr_out = torch.from_numpy(lr_out).unsqueeze(0) - ta_spec = F.spectrogram(ta_out, 0, window, n_fft, hop, ws, 2, normalize).sqrt() - lr_spec = F.spectrogram(lr_out, 0, window, n_fft, hop, ws, 2, normalize).sqrt() - - self.assertTrue(torch.allclose(ta_spec, lr_spec, atol=5e-3)) + self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5)) def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) From 9ee64dcc6cf33dff42a01475186abb49d2518c7c Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 26 Dec 2019 10:57:57 -0500 Subject: [PATCH 11/11] matching signature to docstring. --- torchaudio/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index afa6f0343d..0c5d2c645c 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -109,9 +109,9 @@ class GriffinLim(torch.nn.Module): __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized', 'length', 'momentum', 'rand_init'] - def __init__(self, n_fft=400, n_iter=32, hop_length=None, win_length=None, - window_fn=torch.hann_window, wkwargs=None, normalized=False, - power=2, length=None, momentum=0.99, rand_init=True): + def __init__(self, n_fft=400, n_iter=32, win_length=None, hop_length=None, + window_fn=torch.hann_window, power=2, normalized=False, wkwargs=None, + momentum=0.99, length=None, rand_init=True): super(GriffinLim, self).__init__() assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum