-
Notifications
You must be signed in to change notification settings - Fork 736
Add Tacotron2 loss function #1625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
b217afd
3365afa
1996fec
7c522f0
b6288bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .loss_function import Tacotron2Loss | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # ***************************************************************************** | ||
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions are met: | ||
| # * Redistributions of source code must retain the above copyright | ||
| # notice, this list of conditions and the following disclaimer. | ||
| # * Redistributions in binary form must reproduce the above copyright | ||
| # notice, this list of conditions and the following disclaimer in the | ||
| # documentation and/or other materials provided with the distribution. | ||
| # * Neither the name of the NVIDIA CORPORATION nor the | ||
| # names of its contributors may be used to endorse or promote products | ||
| # derived from this software without specific prior written permission. | ||
| # | ||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
| # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
| # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
| # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||
| # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
| # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
| # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
| # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
| # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
| # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
| # | ||
| # ***************************************************************************** | ||
|
|
||
| from typing import Tuple | ||
|
|
||
| from torch import nn, Tensor | ||
|
|
||
|
|
||
| class Tacotron2Loss(nn.Module): | ||
| """Tacotron2 loss function adapted from: | ||
| https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| def forward( | ||
| self, | ||
| model_outputs: Tuple[Tensor, Tensor, Tensor], | ||
| targets: Tuple[Tensor, Tensor], | ||
| ) -> Tuple[Tensor, Tensor, Tensor]: | ||
| r"""Pass the input through the Tacotron2 loss. | ||
|
|
||
| The original implementation was introduced in | ||
| *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions* | ||
| [:footcite:`shen2018natural`]. | ||
|
|
||
| Args: | ||
| model_outputs (tuple of three Tensors): The outputs of the | ||
| Tacotron2. These outputs should include three items: | ||
| (1) the predicted mel spectrogram before the postnet (``mel_specgram``) | ||
| with shape (batch, mel, time). | ||
| (2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``) | ||
| with shape (batch, mel, time), and | ||
| (3) the stop token prediction (``gate_out``) with shape (batch). | ||
| targets (tuple of two Tensors): The ground truth mel spectrogram (batch, mel, time) and | ||
| stop token with shape (batch). | ||
|
|
||
| Returns: | ||
| mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram with shape (batch, ). | ||
| mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and | ||
| ground truth mel spectrogram with shape (batch, ). | ||
| gate_loss (Tensor): The mean binary cross entropy loss of | ||
| the prediction on the stop token with shape (batch, ). | ||
| """ | ||
| mel_target, gate_target = targets[0], targets[1] | ||
| mel_target.requires_grad = False | ||
| gate_target.requires_grad = False | ||
| gate_target = gate_target.view(-1, 1) | ||
|
|
||
| mel_specgram, mel_specgram_postnet, gate_out = model_outputs | ||
| gate_out = gate_out.view(-1, 1) | ||
| mel_loss = nn.MSELoss(reduction="mean")(mel_specgram, mel_target) | ||
| mel_postnet_loss = nn.MSELoss(reduction="mean")( | ||
| mel_specgram_postnet, mel_target | ||
| ) | ||
| gate_loss = nn.BCEWithLogitsLoss(reduction="mean")(gate_out, gate_target) | ||
| return mel_loss, mel_postnet_loss, gate_loss |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| import os | ||
| import unittest | ||
| import tempfile | ||
|
|
||
| import torch | ||
| from torch.autograd import gradcheck, gradgradcheck | ||
|
|
||
| from loss_function import Tacotron2Loss | ||
|
|
||
|
|
||
| def skipIfNoCuda(test_item): | ||
| if torch.cuda.is_available(): | ||
| return test_item | ||
| force_cuda_test = os.environ.get("TORCHAUDIO_TEST_FORCE_CUDA", "0") | ||
| if force_cuda_test not in ["0", "1"]: | ||
| raise ValueError('"TORCHAUDIO_TEST_FORCE_CUDA" must be either "0" or "1".') | ||
| if force_cuda_test == "1": | ||
| raise RuntimeError( | ||
| '"TORCHAUDIO_TEST_FORCE_CUDA" is set but CUDA is not available.' | ||
| ) | ||
| return unittest.skip("CUDA is not available.")(test_item) | ||
|
|
||
|
|
||
| class TempDirMixin: | ||
| """Mixin to provide easy access to temp dir""" | ||
|
|
||
| temp_dir_ = None | ||
|
|
||
| @classmethod | ||
| def get_base_temp_dir(cls): | ||
| # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory. | ||
| # this is handy for debugging. | ||
| key = "TORCHAUDIO_TEST_TEMP_DIR" | ||
| if key in os.environ: | ||
| return os.environ[key] | ||
| if cls.temp_dir_ is None: | ||
| cls.temp_dir_ = tempfile.TemporaryDirectory() | ||
| return cls.temp_dir_.name | ||
|
|
||
| @classmethod | ||
| def tearDownClass(cls): | ||
| super().tearDownClass() | ||
| if cls.temp_dir_ is not None: | ||
| cls.temp_dir_.cleanup() | ||
| cls.temp_dir_ = None | ||
|
|
||
| def get_temp_path(self, *paths): | ||
| temp_dir = os.path.join(self.get_base_temp_dir(), self.id()) | ||
| path = os.path.join(temp_dir, *paths) | ||
| os.makedirs(os.path.dirname(path), exist_ok=True) | ||
| return path | ||
|
|
||
|
|
||
| def _get_inputs(dtype, device): | ||
| n_mel, n_batch, max_mel_specgram_length = 3, 2, 4 | ||
|
||
| mel_specgram = torch.rand( | ||
| n_batch, n_mel, max_mel_specgram_length, dtype=dtype, device=device | ||
| ) | ||
| mel_specgram_postnet = torch.rand( | ||
| n_batch, n_mel, max_mel_specgram_length, dtype=dtype, device=device | ||
| ) | ||
| gate_out = torch.rand(n_batch, dtype=dtype, device=device) | ||
| truth_mel_specgram = torch.rand( | ||
| n_batch, n_mel, max_mel_specgram_length, dtype=dtype, device=device | ||
| ) | ||
| truth_gate_out = torch.rand(n_batch, dtype=dtype, device=device) | ||
|
|
||
| return ( | ||
| mel_specgram, | ||
| mel_specgram_postnet, | ||
| gate_out, | ||
| truth_mel_specgram, | ||
| truth_gate_out, | ||
| ) | ||
|
|
||
|
|
||
| class Tacotron2LossTest(unittest.TestCase, TempDirMixin): | ||
|
||
|
|
||
| dtype = torch.float64 | ||
| device = "cpu" | ||
|
|
||
| def _assert_torchscript_consistency(self, fn, tensors): | ||
| path = self.get_temp_path("func.zip") | ||
| torch.jit.script(fn).save(path) | ||
| ts_func = torch.jit.load(path) | ||
|
|
||
| torch.random.manual_seed(40) | ||
| output = fn(*tensors) | ||
|
|
||
| torch.random.manual_seed(40) | ||
| ts_output = ts_func(*tensors) | ||
|
|
||
| self.assertEqual(ts_output, output) | ||
|
|
||
| def test_cpu_torchscript_consistency(self): | ||
| f"""Validate the torchscript consistency of Tacotron2Loss.""" | ||
| dtype = torch.float32 | ||
| device = torch.device("cpu") | ||
|
|
||
| def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out): | ||
| loss_fn = Tacotron2Loss() | ||
| return loss_fn( | ||
| (mel_specgram, mel_specgram_postnet, gate_out), | ||
| (truth_mel_specgram, truth_gate_out), | ||
| ) | ||
|
|
||
| self._assert_torchscript_consistency(_fn, _get_inputs(dtype, device)) | ||
|
|
||
| @skipIfNoCuda | ||
| def test_gpu_torchscript_consistency(self): | ||
| f"""Validate the torchscript consistency of Tacotron2Loss.""" | ||
| dtype = torch.float32 | ||
| device = torch.device("cuda") | ||
|
|
||
| def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out): | ||
| loss_fn = Tacotron2Loss() | ||
| return loss_fn( | ||
| (mel_specgram, mel_specgram_postnet, gate_out), | ||
| (truth_mel_specgram, truth_gate_out), | ||
| ) | ||
|
|
||
| self._assert_torchscript_consistency(_fn, self._get_inputs(dtype, device)) | ||
|
|
||
| def test_cpu_gradcheck(self): | ||
| f"""Performing gradient check on Tacotron2Loss.""" | ||
| dtype = torch.float64 # gradcheck needs a higher numerical accuracy | ||
| device = torch.device("cuda") | ||
|
|
||
| ( | ||
| mel_specgram, | ||
| mel_specgram_postnet, | ||
| gate_out, | ||
| truth_mel_specgram, | ||
| truth_gate_out, | ||
| ) = _get_inputs(dtype, device) | ||
|
|
||
| mel_specgram.requires_grad_(True) | ||
| mel_specgram_postnet.requires_grad_(True) | ||
| gate_out.requires_grad_(True) | ||
|
|
||
| def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out): | ||
| loss_fn = Tacotron2Loss() | ||
| return loss_fn( | ||
| (mel_specgram, mel_specgram_postnet, gate_out), | ||
| (truth_mel_specgram, truth_gate_out), | ||
| ) | ||
|
|
||
| gradcheck( | ||
| _fn, | ||
| (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), | ||
| fast_mode=True, | ||
| ) | ||
| gradgradcheck( | ||
| _fn, | ||
| (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), | ||
| fast_mode=True, | ||
| ) | ||
|
|
||
| @skipIfNoCuda | ||
| def test_gpu_gradcheck(self): | ||
| f"""Performing gradient check on Tacotron2Loss.""" | ||
| dtype = torch.float64 # gradcheck needs a higher numerical accuracy | ||
| device = torch.device("cuda") | ||
|
|
||
| ( | ||
| mel_specgram, | ||
| mel_specgram_postnet, | ||
| gate_out, | ||
| truth_mel_specgram, | ||
| truth_gate_out, | ||
| ) = _get_inputs(dtype, device) | ||
|
|
||
| mel_specgram.requires_grad_(True) | ||
| mel_specgram_postnet.requires_grad_(True) | ||
| gate_out.requires_grad_(True) | ||
|
|
||
| def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out): | ||
| loss_fn = Tacotron2Loss() | ||
| return loss_fn( | ||
| (mel_specgram, mel_specgram_postnet, gate_out), | ||
| (truth_mel_specgram, truth_gate_out), | ||
| ) | ||
|
|
||
| gradcheck( | ||
| _fn, | ||
| (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), | ||
| fast_mode=True, | ||
| ) | ||
| gradgradcheck( | ||
| _fn, | ||
| (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), | ||
| fast_mode=True, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
|
||
Uh oh!
There was an error while loading. Please reload this page.