diff --git a/docs/source/refs.bib b/docs/source/refs.bib index c6ddb3b61a..614e3902f7 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -38,6 +38,14 @@ @misc{kalchbrenner2018efficient archivePrefix={arXiv}, primaryClass={cs.SD} } +@inproceedings{shen2018natural, + title={Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions}, + author={Shen, Jonathan and Pang, Ruoming and Weiss, Ron J and Schuster, Mike and Jaitly, Navdeep and Yang, Zongheng and Chen, Zhifeng and Zhang, Yu and Wang, Yuxuan and Skerrv-Ryan, Rj and others}, + year={2017}, + eprint={1712.05884}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} @article{Luo_2019, title={Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation}, volume={27}, diff --git a/test/torchaudio_unittest/models/tacotron2/__init__.py b/test/torchaudio_unittest/models/tacotron2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py b/test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py new file mode 100644 index 0000000000..612699b6e1 --- /dev/null +++ b/test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py @@ -0,0 +1,23 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .model_test_impl import ( + Tacotron2EncoderTests, + Tacotron2DecoderTests, + Tacotron2Tests, +) + + +class TestTacotron2EncoderFloat32CPU(Tacotron2EncoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2DecoderFloat32CPU(Tacotron2DecoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2Float32CPU(Tacotron2Tests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") diff --git a/test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py b/test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py new file mode 100644 index 0000000000..7a6fdd1142 --- /dev/null +++ b/test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py @@ -0,0 +1,26 @@ +import torch + +from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase +from .model_test_impl import ( + Tacotron2EncoderTests, + Tacotron2DecoderTests, + Tacotron2Tests, +) + + +@skipIfNoCuda +class TestTacotron2EncoderFloat32CUDA(Tacotron2EncoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2DecoderFloat32CUDA(Tacotron2DecoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2Float32CUDA(Tacotron2Tests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cuda") diff --git a/test/torchaudio_unittest/models/tacotron2/model_test_impl.py b/test/torchaudio_unittest/models/tacotron2/model_test_impl.py new file mode 100644 index 0000000000..8d0a4eed2b --- /dev/null +++ b/test/torchaudio_unittest/models/tacotron2/model_test_impl.py @@ -0,0 +1,238 @@ +import torch +from torchaudio.prototype.tacotron2 import Tacotron2, _Encoder, _Decoder +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + TempDirMixin, +) + + +class TorchscriptConsistencyMixin(TempDirMixin): + r"""Mixin to provide easy access assert torchscript consistency""" + + def _assert_torchscript_consistency(self, model, tensors): + path = self.get_temp_path("func.zip") + torch.jit.script(model).save(path) + ts_func = torch.jit.load(path) + + 
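+ # Re-seed before each run so stochastic layers (e.g. the Prenet, whose
+ # dropout stays active even in eval mode) see the same random numbers in
+ # the eager and the scripted pass, making the two outputs directly comparable.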
torch.random.manual_seed(40) + output = model(*tensors) + + torch.random.manual_seed(40) + ts_output = ts_func(*tensors) + + self.assertEqual(ts_output, output) + + +class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin): + def test_tacotron2_torchscript_consistency(self): + r"""Validate the torchscript consistency of a Encoder.""" + n_batch, n_seq, encoder_embedding_dim = 16, 64, 512 + model = _Encoder(encoder_embedding_dim=encoder_embedding_dim, + encoder_n_convolution=3, + encoder_kernel_size=5).to(self.device).eval() + + x = torch.rand( + n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype + ) + input_lengths = ( + torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq + ) + + self._assert_torchscript_consistency(model, (x, input_lengths)) + + def test_encoder_output_shape(self): + r"""Feed tensors with specific shape to Tacotron2 Decoder and validate + that it outputs with a tensor with expected shape. + """ + n_batch, n_seq, encoder_embedding_dim = 16, 64, 512 + model = _Encoder(encoder_embedding_dim=encoder_embedding_dim, + encoder_n_convolution=3, + encoder_kernel_size=5).to(self.device).eval() + + x = torch.rand( + n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype + ) + input_lengths = ( + torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq + ) + out = model(x, input_lengths) + + assert out.size() == (n_batch, n_seq, encoder_embedding_dim) + + +def _get_decoder_model(n_mels=80, encoder_embedding_dim=512): + model = _Decoder( + n_mels=n_mels, + n_frames_per_step=1, + encoder_embedding_dim=encoder_embedding_dim, + decoder_rnn_dim=1024, + decoder_max_step=2000, + decoder_dropout=0.1, + decoder_early_stopping=False, + attention_rnn_dim=1024, + attention_hidden_dim=128, + attention_location_n_filter=32, + attention_location_kernel_size=31, + attention_dropout=0.1, + prenet_dim=256, + gate_threshold=0.5, + ) + return model + + +class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin): + def test_decoder_torchscript_consistency(self): + r"""Validate the torchscript consistency of a Decoder.""" + n_batch = 16 + n_mels = 80 + n_seq = 200 + encoder_embedding_dim = 256 + n_time_steps = 150 + + model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim) + model = model.to(self.device).eval() + + memory = torch.rand( + n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device + ) + decoder_inputs = torch.rand( + n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device + ) + memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) + + self._assert_torchscript_consistency( + model, (memory, decoder_inputs, memory_lengths) + ) + + def test_decoder_output_shape(self): + r"""Feed tensors with specific shape to Tacotron2 Decoder and validate + that it outputs with a tensor with expected shape. 
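+ The expected shapes are (n_batch, n_mels, n_time_steps) for the mel output,
+ (n_batch, n_time_steps) for the gate output, and (n_batch, n_time_steps, n_seq)
+ for the alignments, matching the assertions below.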
+ """ + n_batch = 16 + n_mels = 80 + n_seq = 200 + encoder_embedding_dim = 256 + n_time_steps = 150 + + model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim) + model = model.to(self.device).eval() + + memory = torch.rand( + n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device + ) + decoder_inputs = torch.rand( + n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device + ) + memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) + + mel_outputs, gate_outputs, alignments = model( + memory, decoder_inputs, memory_lengths + ) + + assert mel_outputs.size() == (n_batch, n_mels, n_time_steps) + assert gate_outputs.size() == (n_batch, n_time_steps) + assert alignments.size() == (n_batch, n_time_steps, n_seq) + + +def _get_tacotron2_model(n_mels): + return Tacotron2( + mask_padding=False, + n_mels=n_mels, + n_symbol=148, + n_frames_per_step=1, + symbol_embedding_dim=512, + encoder_embedding_dim=512, + encoder_n_convolution=3, + encoder_kernel_size=5, + decoder_rnn_dim=1024, + decoder_max_step=2000, + decoder_dropout=0.1, + decoder_early_stopping=True, + attention_rnn_dim=1024, + attention_hidden_dim=128, + attention_location_n_filter=32, + attention_location_kernel_size=31, + attention_dropout=0.1, + prenet_dim=256, + postnet_n_convolution=5, + postnet_kernel_size=5, + postnet_embedding_dim=512, + gate_threshold=0.5, + ) + + +class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin): + def _get_inputs( + self, n_mels, n_batch: int, max_mel_specgram_length: int, max_text_length: int + ): + text = torch.randint( + 0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device + ) + text_lengths = max_text_length * torch.ones( + (n_batch,), dtype=torch.int32, device=self.device + ) + mel_specgram = torch.rand( + n_batch, + n_mels, + max_mel_specgram_length, + dtype=self.dtype, + device=self.device, + ) + mel_specgram_lengths = max_mel_specgram_length * torch.ones( + (n_batch,), dtype=torch.int32, device=self.device + ) + return text, text_lengths, mel_specgram, mel_specgram_lengths + + def test_tacotron2_torchscript_consistency(self): + r"""Validate the torchscript consistency of a Tacotron2.""" + n_batch = 16 + n_mels = 80 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels).to(self.device).eval() + inputs = self._get_inputs( + n_mels, n_batch, max_mel_specgram_length, max_text_length + ) + + self._assert_torchscript_consistency(model, inputs) + + def test_tacotron2_output_shape(self): + r"""Feed tensors with specific shape to Tacotron2 and validate + that it outputs with a tensor with expected shape. + """ + n_batch = 16 + n_mels = 80 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels).to(self.device).eval() + inputs = self._get_inputs( + n_mels, n_batch, max_mel_specgram_length, max_text_length + ) + mel_out, mel_out_postnet, gate_outputs, alignments = model(*inputs) + + assert mel_out.size() == (n_batch, n_mels, max_mel_specgram_length) + assert mel_out_postnet.size() == (n_batch, n_mels, max_mel_specgram_length) + assert gate_outputs.size() == (n_batch, max_mel_specgram_length) + assert alignments.size() == (n_batch, max_mel_specgram_length, max_text_length) + + def test_tacotron2_backward(self): + r"""Make sure calling the backward function on Tacotron2's outputs does + not error out. 
Following: + https://github.com/pytorch/vision/blob/23b8760374a5aaed53c6e5fc83a7e83dbe3b85df/test/test_models.py#L255 + """ + n_batch = 16 + n_mels = 80 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels).to(self.device) + inputs = self._get_inputs( + n_mels, n_batch, max_mel_specgram_length, max_text_length + ) + mel_out, mel_out_postnet, gate_outputs, _ = model(*inputs) + + mel_out.sum().backward(retain_graph=True) + mel_out_postnet.sum().backward(retain_graph=True) + gate_outputs.sum().backward() diff --git a/torchaudio/prototype/tacotron2.py b/torchaudio/prototype/tacotron2.py new file mode 100644 index 0000000000..6f3a4fe0de --- /dev/null +++ b/torchaudio/prototype/tacotron2.py @@ -0,0 +1,951 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ***************************************************************************** + +from math import sqrt +from typing import Tuple, List, Optional, Union + +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + + +__all__ = [ + "Tacotron2", +] + + +def _get_linear_layer( + in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear" +) -> torch.nn.Linear: + r"""Linear layer with xavier uniform initialization. + + Args: + in_dim (int): Size of each input sample. + out_dim (int): Size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``) + w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain`` + for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``) + + Returns: + (torch.nn.Linear): The corresponding linear layer. 
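+ Example (a minimal illustration of the initializer helper defined below):
+     >>> layer = _get_linear_layer(in_dim=256, out_dim=128, w_init_gain="tanh")
+     >>> x = torch.rand(10, 256)
+     >>> layer(x).shape
+     torch.Size([10, 128])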
+ """ + linear = torch.nn.Linear(in_dim, out_dim, bias=bias) + torch.nn.init.xavier_uniform_( + linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + return linear + + +def _get_conv1d_layer( + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: Optional[Union[str, int, Tuple[int]]] = None, + dilation: int = 1, + bias: bool = True, + w_init_gain: str = "linear", +) -> torch.nn.Conv1d: + r"""1D convolution with xavier uniform initialization. + + Args: + in_channels (int): Number of channels in the input image. + out_channels (int): Number of channels produced by the convolution. + kernel_size (int, optional): Number of channels in the input image. (Default: ``1``) + stride (int, optional): Number of channels in the input image. (Default: ``1``) + padding (str, int or tuple, optional): Padding added to both sides of the input. + (Default: dilation * (kernel_size - 1) / 2) + dilation (int, optional): Number of channels in the input image. (Default: ``1``) + w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain`` + for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``) + + Returns: + (torch.nn.Conv1d): The corresponding Conv1D layer. + """ + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + + conv1d = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + torch.nn.init.xavier_uniform_( + conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + + return conv1d + + +def _get_mask_from_lengths(lengths: Tensor) -> Tensor: + r"""Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask + is ``1`` if ``j`` is smaller than ``i``-th element of ``lengths. + + Args: + lengths (Tensor): The length of each element in the batch, with shape (n_batch, ). + + Returns: + mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``). + """ + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) + mask = (ids < lengths.unsqueeze(1)).byte() + mask = torch.le(mask, 0) + return mask + + +class _LocationLayer(nn.Module): + r"""Location layer used in the Attention model. + + Args: + attention_n_filter (int): Number of filters for attention model. + attention_kernel_size (int): Kernel size for attention model. + attention_hidden_dim (int): Dimension of attention hidden representation. + """ + + def __init__( + self, + attention_n_filter: int, + attention_kernel_size: int, + attention_hidden_dim: int, + ): + super().__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = _get_conv1d_layer( + 2, + attention_n_filter, + kernel_size=attention_kernel_size, + padding=padding, + bias=False, + stride=1, + dilation=1, + ) + self.location_dense = _get_linear_layer( + attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh" + ) + + def forward(self, attention_weights_cat: Tensor) -> Tensor: + r"""Location layer used in the Attention model. + + Args: + attention_weights_cat (Tensor): Cumulative and previous attention weights + with shape (n_batch, 2, max of ``text_lengths``). + + Returns: + processed_attention (Tensor): Cumulative and previous attention weights + with shape (n_batch, ``attention_hidden_dim``). 
+ """ + # (n_batch, attention_n_filter, text_lengths.max()) + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + # (n_batch, text_lengths.max(), attention_hidden_dim) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class _Attention(nn.Module): + r"""Locally sensitive attention model. + + Args: + attention_rnn_dim (int): Number of hidden units for RNN. + encoder_embedding_dim (int): Number of embedding dimensions in the Encoder. + attention_hidden_dim (int): Dimension of attention hidden representation. + attention_location_n_filter (int): Number of filters for Attention model. + attention_location_kernel_size (int): Kernel size for Attention model. + """ + + def __init__( + self, + attention_rnn_dim: int, + encoder_embedding_dim: int, + attention_hidden_dim: int, + attention_location_n_filter: int, + attention_location_kernel_size: int, + ) -> None: + super().__init__() + self.query_layer = _get_linear_layer( + attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh" + ) + self.memory_layer = _get_linear_layer( + encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh" + ) + self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False) + self.location_layer = _LocationLayer( + attention_location_n_filter, + attention_location_kernel_size, + attention_hidden_dim, + ) + self.score_mask_value = -float("inf") + + def _get_alignment_energies( + self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor + ) -> Tensor: + r"""Get the alignment vector. + + Args: + query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step). + processed_memory (Tensor): Processed Encoder outputs + with shape (n_batch, max of ``text_lengths``, attention_hidden_dim). + attention_weights_cat (Tensor): Cumulative and previous attention weights + with shape (n_batch, 2, max of ``text_lengths``). + + Returns: + alignment (Tensor): attention weights, it is a tensor with shape (batch, max of ``text_lengths``). + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + processed_memory) + ) + + alignment = energies.squeeze(2) + return alignment + + def forward( + self, + attention_hidden_state: Tensor, + memory: Tensor, + processed_memory: Tensor, + attention_weights_cat: Tensor, + mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the Attention model. + + Args: + attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``). + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + processed_memory (Tensor): Processed Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``). + attention_weights_cat (Tensor): Previous and cumulative attention weights + with shape (n_batch, current_num_frames * 2, max of ``text_lengths``). + mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames). + + Returns: + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). 
+ """ + alignment = self._get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat + ) + + alignment = alignment.masked_fill(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class _Prenet(nn.Module): + r"""Prenet Module. It is consists of ``len(output_size)`` linear layers. + + Args: + in_dim (int): The size of each input sample. + output_sizes (list): The output dimension of each linear layers. + """ + + def __init__(self, in_dim: int, out_sizes: List[int]) -> None: + super().__init__() + in_sizes = [in_dim] + out_sizes[:-1] + self.layers = nn.ModuleList( + [ + _get_linear_layer(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, out_sizes) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + r"""Pass the input through Prenet. + + Args: + x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim). + + Return: + x (Tensor): Tensor with shape (n_batch, sizes[-1]) + """ + + for linear in self.layers: + x = F.dropout(F.relu(linear(x)), p=0.5, training=True) + return x + + +class _Postnet(nn.Module): + r"""Postnet Module. + + Args: + n_mels (int): Number of mel bins. + postnet_embedding_dim (int): Postnet embedding dimension. + postnet_kernel_size (int): Postnet kernel size. + postnet_n_convolution (int): Number of postnet convolutions. + """ + + def __init__( + self, + n_mels: int, + postnet_embedding_dim: int, + postnet_kernel_size: int, + postnet_n_convolution: int, + ): + super().__init__() + self.convolutions = nn.ModuleList() + + for i in range(postnet_n_convolution): + in_channels = n_mels if i == 0 else postnet_embedding_dim + out_channels = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim + init_gain = "linear" if i == (postnet_n_convolution - 1) else "tanh" + num_features = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim + self.convolutions.append( + nn.Sequential( + _get_conv1d_layer( + in_channels, + out_channels, + kernel_size=postnet_kernel_size, + stride=1, + padding=int((postnet_kernel_size - 1) / 2), + dilation=1, + w_init_gain=init_gain, + ), + nn.BatchNorm1d(num_features), + ) + ) + + self.n_convs = len(self.convolutions) + + def forward(self, x: Tensor) -> Tensor: + r"""Pass the input through Postnet. + + Args: + x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + + Return: + x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + """ + + i = 0 + for conv in self.convolutions: + if i < self.n_convs - 1: + x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training) + else: + x = F.dropout(conv(x), 0.5, training=self.training) + i += 1 + + return x + + +class _Encoder(nn.Module): + r"""Encoder Module. + + Args: + encoder_embedding_dim (int): Number of embedding dimensions in the encoder. + encoder_n_convolution (int): Number of convolution layers in the encoder. + encoder_kernel_size (int): The kernel size in the encoder. 
+ + Examples + >>> encoder = _Encoder(3, 512, 5) + >>> input = torch.rand(10, 20, 30) + >>> output = encoder(input) # shape: (10, 30, 512) + """ + + def __init__( + self, + encoder_embedding_dim: int, + encoder_n_convolution: int, + encoder_kernel_size: int, + ) -> None: + super().__init__() + + self.convolutions = nn.ModuleList() + for _ in range(encoder_n_convolution): + conv_layer = nn.Sequential( + _get_conv1d_layer( + encoder_embedding_dim, + encoder_embedding_dim, + kernel_size=encoder_kernel_size, + stride=1, + padding=int((encoder_kernel_size - 1) / 2), + dilation=1, + w_init_gain="relu", + ), + nn.BatchNorm1d(encoder_embedding_dim), + ) + self.convolutions.append(conv_layer) + + self.lstm = nn.LSTM( + encoder_embedding_dim, + int(encoder_embedding_dim / 2), + 1, + batch_first=True, + bidirectional=True, + ) + self.lstm.flatten_parameters() + + def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor: + r"""Pass the input through the Encoder. + + Args: + x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq). + input_lengths (Tensor): The length of each input sequence with shape (n_batch, ). + + Return: + x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim). + """ + + for conv in self.convolutions: + x = F.dropout(F.relu(conv(x)), 0.5, self.training) + + x = x.transpose(1, 2) + + input_lengths = input_lengths.cpu() + x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True) + + outputs, _ = self.lstm(x) + outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) + + return outputs + + +class _Decoder(nn.Module): + r"""Decoder with Attention model. + + Args: + n_mels (int): number of mel bins + n_frames_per_step (int): number of frames processed per step, only 1 is supported + encoder_embedding_dim (int): the number of embedding dimensions in the encoder. 
+ decoder_rnn_dim (int): number of units in decoder LSTM + decoder_max_step (int): maximum number of output mel spectrograms + decoder_dropout (float): dropout probability for decoder LSTM + decoder_early_stopping (bool): stop decoding when all samples are finished + attention_rnn_dim (int): number of units in attention LSTM + attention_hidden_dim (int): dimension of attention hidden representation + attention_location_n_filter (int): number of filters for attention model + attention_location_kernel_size (int): kernel size for attention model + attention_dropout (float): dropout probability for attention LSTM + prenet_dim (int): number of ReLU units in prenet layers + gate_threshold (float): probability threshold for stop token + """ + + def __init__( + self, + n_mels: int, + n_frames_per_step: int, + encoder_embedding_dim: int, + decoder_rnn_dim: int, + decoder_max_step: int, + decoder_dropout: float, + decoder_early_stopping: bool, + attention_rnn_dim: int, + attention_hidden_dim: int, + attention_location_n_filter: int, + attention_location_kernel_size: int, + attention_dropout: float, + prenet_dim: int, + gate_threshold: float, + ) -> None: + + super().__init__() + self.n_mels = n_mels + self.n_frames_per_step = n_frames_per_step + self.encoder_embedding_dim = encoder_embedding_dim + self.attention_rnn_dim = attention_rnn_dim + self.decoder_rnn_dim = decoder_rnn_dim + self.prenet_dim = prenet_dim + self.decoder_max_step = decoder_max_step + self.gate_threshold = gate_threshold + self.attention_dropout = attention_dropout + self.decoder_dropout = decoder_dropout + self.decoder_early_stopping = decoder_early_stopping + + self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim]) + + self.attention_rnn = nn.LSTMCell( + prenet_dim + encoder_embedding_dim, attention_rnn_dim + ) + + self.attention_layer = _Attention( + attention_rnn_dim, + encoder_embedding_dim, + attention_hidden_dim, + attention_location_n_filter, + attention_location_kernel_size, + ) + + self.decoder_rnn = nn.LSTMCell( + attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True + ) + + self.linear_projection = _get_linear_layer( + decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step + ) + + self.gate_layer = _get_linear_layer( + decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid" + ) + + def _get_initial_frame(self, memory: Tensor) -> Tensor: + r"""Gets all zeros frames to use as the first decoder input. + + Args: + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + + Returns: + decoder_input (Tensor): all zeros frames with shape + (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``). + """ + + n_batch = memory.size(0) + dtype = memory.dtype + device = memory.device + decoder_input = torch.zeros( + n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device + ) + return decoder_input + + def _initialize_decoder_states( + self, memory: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Initializes attention rnn states, decoder rnn states, attention + weights, attention cumulative weights, attention context, stores memory + and stores processed memory. + + Args: + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + + Returns: + attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). 
+ attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). + attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``). + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + processed_memory (Tensor): Processed encoder outputs + with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``). + """ + n_batch = memory.size(0) + max_time = memory.size(1) + dtype = memory.dtype + device = memory.device + + attention_hidden = torch.zeros( + n_batch, self.attention_rnn_dim, dtype=dtype, device=device + ) + attention_cell = torch.zeros( + n_batch, self.attention_rnn_dim, dtype=dtype, device=device + ) + + decoder_hidden = torch.zeros( + n_batch, self.decoder_rnn_dim, dtype=dtype, device=device + ) + decoder_cell = torch.zeros( + n_batch, self.decoder_rnn_dim, dtype=dtype, device=device + ) + + attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device) + attention_weights_cum = torch.zeros( + n_batch, max_time, dtype=dtype, device=device + ) + attention_context = torch.zeros( + n_batch, self.encoder_embedding_dim, dtype=dtype, device=device + ) + + processed_memory = self.attention_layer.memory_layer(memory) + + return ( + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + processed_memory, + ) + + def _parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor: + r"""Prepares decoder inputs. + + Args: + decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs, + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``) + + Returns: + inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``). 
+ """ + # (n_batch, n_mels, mel_specgram_lengths.max()) -> (n_batch, mel_specgram_lengths.max(), n_mels) + decoder_inputs = decoder_inputs.transpose(1, 2) + decoder_inputs = decoder_inputs.view( + decoder_inputs.size(0), + int(decoder_inputs.size(1) / self.n_frames_per_step), + -1, + ) + # (n_batch, mel_specgram_lengths.max(), n_mels) -> (mel_specgram_lengths.max(), n_batch, n_mels) + decoder_inputs = decoder_inputs.transpose(0, 1) + return decoder_inputs + + def _parse_decoder_outputs( + self, mel_outputs: Tensor, gate_outputs: Tensor, alignments: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Prepares decoder outputs for output + + Args: + mel_outputs (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``) + gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch) + alignments (Tensor): sequence of attention weights from the decoder + with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``) + + Returns: + mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``) + gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``) + alignments (Tensor): sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``) + """ + # (mel_specgram_lengths.max(), n_batch, text_lengths.max()) + # -> (n_batch, mel_specgram_lengths.max(), text_lengths.max()) + alignments = alignments.transpose(0, 1).contiguous() + # (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max()) + gate_outputs = gate_outputs.transpose(0, 1).contiguous() + # (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels) + mel_specgram = mel_outputs.transpose(0, 1).contiguous() + # decouple frames per step + shape = (mel_specgram.shape[0], -1, self.n_mels) + mel_specgram = mel_specgram.view(*shape) + # (n_batch, mel_specgram_lengths.max(), n_mels) -> (n_batch, n_mels, T_out) + mel_specgram = mel_specgram.transpose(1, 2) + + return mel_specgram, gate_outputs, alignments + + def decode( + self, + decoder_input: Tensor, + attention_hidden: Tensor, + attention_cell: Tensor, + decoder_hidden: Tensor, + decoder_cell: Tensor, + attention_weights: Tensor, + attention_weights_cum: Tensor, + attention_context: Tensor, + memory: Tensor, + processed_memory: Tensor, + mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Decoder step using stored states, attention and memory + + Args: + decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``). + attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). + attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``). + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). 
+ processed_memory (Tensor): Processed Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``). + mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames). + + Returns: + decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``). + gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``). + attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). + attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``). + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + """ + cell_input = torch.cat((decoder_input, attention_context), -1) + + attention_hidden, attention_cell = self.attention_rnn( + cell_input, (attention_hidden, attention_cell) + ) + attention_hidden = F.dropout( + attention_hidden, self.attention_dropout, self.training + ) + + attention_weights_cat = torch.cat( + (attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1 + ) + attention_context, attention_weights = self.attention_layer( + attention_hidden, memory, processed_memory, attention_weights_cat, mask + ) + + attention_weights_cum += attention_weights + decoder_input = torch.cat((attention_hidden, attention_context), -1) + + decoder_hidden, decoder_cell = self.decoder_rnn( + decoder_input, (decoder_hidden, decoder_cell) + ) + decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training) + + decoder_hidden_attention_context = torch.cat( + (decoder_hidden, attention_context), dim=1 + ) + decoder_output = self.linear_projection(decoder_hidden_attention_context) + + gate_prediction = self.gate_layer(decoder_hidden_attention_context) + + return ( + decoder_output, + gate_prediction, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + ) + + def forward( + self, memory: Tensor, mel_specgram_truth: Tensor, memory_lengths: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Decoder forward pass for training. + + Args: + memory (Tensor): Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + memory_lengths (Tensor): Encoder output lengths for attention masking + (the same as ``text_lengths``) with shape (n_batch, ). + + Returns: + mel_specgram (Tensor): Predicted mel spectrogram + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + gate_outputs (Tensor): Predicted stop token for each timestep + with shape (n_batch, max of ``mel_specgram_lengths``). + alignments (Tensor): Sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``). 
+ """ + + decoder_input = self._get_initial_frame(memory).unsqueeze(0) + decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth) + decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) + decoder_inputs = self.prenet(decoder_inputs) + + mask = _get_mask_from_lengths(memory_lengths) + ( + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + processed_memory, + ) = self._initialize_decoder_states(memory) + + mel_outputs, gate_outputs, alignments = [], [], [] + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] + ( + mel_output, + gate_output, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + ) = self.decode( + decoder_input, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + memory, + processed_memory, + mask, + ) + + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze()] + alignments += [attention_weights] + + mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs( + torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments) + ) + + return mel_specgram, gate_outputs, alignments + + +class Tacotron2(nn.Module): + r"""Tacotron2 model based on the implementation from + `Nvidia `_. + + The original implementation was introduced in + *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions* + [:footcite:`shen2018natural`]. + + Args: + mask_padding (bool, optional): Use mask padding (Default: ``False``). + n_mels (int, optional): Number of mel bins (Default: ``80``). + n_symbol (int, optional): Number of symbols for the input text (Default: ``148``). + n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``). + symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``). + encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``). + encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``). + encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``). + decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``). + decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``). + decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``). + decoder_early_stopping (bool, optional): Continue decoding after all samples are finished (Default: ``True``). + attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``). + attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``). + attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``). + attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``). + attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``). + prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``). + postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``). + postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``). 
+ postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``). + gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``). + """ + + def __init__( + self, + mask_padding: bool = False, + n_mels: int = 80, + n_symbol: int = 148, + n_frames_per_step: int = 1, + symbol_embedding_dim: int = 512, + encoder_embedding_dim: int = 512, + encoder_n_convolution: int = 3, + encoder_kernel_size: int = 5, + decoder_rnn_dim: int = 1024, + decoder_max_step: int = 2000, + decoder_dropout: float = 0.1, + decoder_early_stopping: bool = True, + attention_rnn_dim: int = 1024, + attention_hidden_dim: int = 128, + attention_location_n_filter: int = 32, + attention_location_kernel_size: int = 31, + attention_dropout: float = 0.1, + prenet_dim: int = 256, + postnet_n_convolution: int = 5, + postnet_kernel_size: int = 5, + postnet_embedding_dim: int = 512, + gate_threshold: float = 0.5, + ) -> None: + super().__init__() + + self.mask_padding = mask_padding + self.n_mels = n_mels + self.n_frames_per_step = n_frames_per_step + self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim) + std = sqrt(2.0 / (n_symbol + symbol_embedding_dim)) + val = sqrt(3.0) * std + self.embedding.weight.data.uniform_(-val, val) + self.encoder = _Encoder( + encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size + ) + self.decoder = _Decoder( + n_mels, + n_frames_per_step, + encoder_embedding_dim, + decoder_rnn_dim, + decoder_max_step, + decoder_dropout, + decoder_early_stopping, + attention_rnn_dim, + attention_hidden_dim, + attention_location_n_filter, + attention_location_kernel_size, + attention_dropout, + prenet_dim, + gate_threshold, + ) + self.postnet = _Postnet( + n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution + ) + + def forward( + self, + text: Tensor, + text_lengths: Tensor, + mel_specgram: Tensor, + mel_specgram_lengths: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the Tacotron2 model. This is in teacher + forcing mode, which is generally used for training. + + The input ``text`` should be padded with zeros to length max of ``text_lengths``. + The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``. + + Args: + text (Tensor): The input text to Tacotron2 with shape (n_batch, max of ``text_lengths``). + text_lengths (Tensor): The length of each text with shape (n_batch). + mel_specgram (Tensor): The target mel spectrogram + with shape (n_batch, n_mels, max of ``mel_specgram_lengths``). + mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape (n_batch). + + Returns: + mel_specgram (Tensor): Mel spectrogram before Postnet + with shape (n_batch, n_mels, max of ``mel_specgram_lengths``). + mel_specgram_postnet (Tensor): Mel spectrogram after Postnet + with shape (n_batch, n_mels, max of ``mel_specgram_lengths``). + stop_token (Tensor): The output for stop token at each time step + with shape (n_batch, max of ``mel_specgram_lengths``). + alignment (Tensor): Sequence of attention weights from the decoder. + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``). 
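+ Example (an illustrative teacher-forced forward pass; the shapes mirror
+ those used in the unit tests in this patch):
+     >>> tacotron2 = Tacotron2()
+     >>> text = torch.randint(0, 148, (2, 100), dtype=torch.int32)
+     >>> text_lengths = torch.tensor([100, 100], dtype=torch.int32)
+     >>> mel_specgram = torch.rand(2, 80, 300)
+     >>> mel_specgram_lengths = torch.tensor([300, 300], dtype=torch.int32)
+     >>> mel_out, mel_out_postnet, gate, alignments = tacotron2(
+     ...     text, text_lengths, mel_specgram, mel_specgram_lengths)
+     >>> mel_out.shape, gate.shape, alignments.shape
+     (torch.Size([2, 80, 300]), torch.Size([2, 300]), torch.Size([2, 300, 100]))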
+ """ + + embedded_inputs = self.embedding(text).transpose(1, 2) + + encoder_outputs = self.encoder(embedded_inputs, text_lengths) + mel_specgram, gate_outputs, alignments = self.decoder( + encoder_outputs, mel_specgram, memory_lengths=text_lengths + ) + + mel_specgram_postnet = self.postnet(mel_specgram) + mel_specgram_postnet = mel_specgram + mel_specgram_postnet + + if self.mask_padding: + mask = _get_mask_from_lengths(mel_specgram_lengths) + mask = mask.expand(self.n_mels, mask.size(0), mask.size(1)) + mask = mask.permute(1, 0, 2) + + mel_specgram.masked_fill_(mask, 0.0) + mel_specgram_postnet.masked_fill_(mask, 0.0) + gate_outputs.masked_fill_(mask[:, 0, :], 1e3) + + return mel_specgram, mel_specgram_postnet, gate_outputs, alignments