From ae2d33f4f16e8cb171f223904ec94329019b87c8 Mon Sep 17 00:00:00 2001
From: Ji Chen
Date: Wed, 29 Jul 2020 06:47:58 -0700
Subject: [PATCH 1/2] Add encoder

---
 test/test_models.py            |  21 +++++
 torchaudio/models/_tacotron.py | 159 +++++++++++++++++++++++++++++++++
 2 files changed, 180 insertions(+)
 create mode 100644 torchaudio/models/_tacotron.py

diff --git a/test/test_models.py b/test/test_models.py
index a37ef66d77..272a82b701 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -115,3 +115,24 @@ def test_waveform(self):
         out = model(x, mels)
 
         assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
+
+
+class TestEncoder(common_utils.TorchaudioTestCase):
+    def test_output(self):
+        """Validate the output dimensions of a _Encoder block.
+        """
+
+        n_encoder_convolutions = 3
+        n_encoder_embedding = 512
+        n_encoder_kernel_size = 5
+        n_batch = 32
+        n_seq = 64
+
+        model = _Encoder(n_encoder_convolutions, n_encoder_embedding, n_encoder_kernel_size)
+
+        x = torch.rand(n_batch, n_encoder_embedding, n_seq)
+        input_length = [n_seq for i in range(n_batch)]
+        out = model(x, input_length)
+
+        assert out.size() == (n_batch, n_seq, n_encoder_embedding)
+        
\ No newline at end of file
diff --git a/torchaudio/models/_tacotron.py b/torchaudio/models/_tacotron.py
new file mode 100644
index 0000000000..391cf37608
--- /dev/null
+++ b/torchaudio/models/_tacotron.py
@@ -0,0 +1,159 @@
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+__all__ = ["_ConvNorm", "_Encoder"]
+
+
+class _ConvNorm(nn.Module):
+    r"""1-d convolution layer whose weight is initialized with Xavier-uniform
+
+    Args:
+        n_input: the number of input channels.
+        n_output: the number of output channels.
+
+    Examples
+        >>> convnorm = _ConvNorm(10, 20)
+        >>> input = torch.rand(32, 10, 512)
+        >>> output = convnorm(input) # shape: (32, 20, 512)
+    """
+
+    def __init__(
+        self,
+        n_input,
+        n_output,
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        dilation: int = 1,
+        bias: bool = True,
+        w_init_gain: str = "linear",
+    ) -> None:
+        super(_ConvNorm, self).__init__()
+
+        if padding is None:
+            assert kernel_size % 2 == 1
+            padding = dilation * (kernel_size - 1) // 2
+
+        self.conv = nn.Conv1d(
+            n_input,
+            n_output,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+
+        nn.init.xavier_uniform_(
+            self.conv.weight, gain=nn.init.calculate_gain(w_init_gain),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        r"""Pass the input through the _ConvNorm layer.
+
+        Args:
+            x (Tensor): the input sequence to the _ConvNorm layer (n_batch, n_input, n_seq).
+
+        Return:
+            Tensor shape: (n_batch, n_output, n_seq)
+        """
+
+        return self.conv(x)
+
+
+class _Encoder(nn.Module):
+    r"""Encoder Module
+
+    Args:
+        n_encoder_convolutions: the number of convolution layers in the encoder.
+        n_encoder_embedding: the number of embedding dimensions in the encoder.
+        n_encoder_kernel_size: the kernel size in the encoder.
+
+    Examples
+        >>> encoder = _Encoder(3, 512, 5)
+        >>> input = torch.rand(10, 512, 30)
+        >>> output = encoder(input, torch.tensor([30] * 10)) # shape: (10, 30, 512)
+    """
+
+    def __init__(
+        self, n_encoder_convolutions, n_encoder_embedding, n_encoder_kernel_size
+    ) -> None:
+        super(_Encoder, self).__init__()
+
+        convolutions = []
+        for _ in range(n_encoder_convolutions):
+            conv_layer = nn.Sequential(
+                _ConvNorm(
+                    n_encoder_embedding,
+                    n_encoder_embedding,
+                    kernel_size=n_encoder_kernel_size,
+                    stride=1,
+                    padding=(n_encoder_kernel_size - 1) // 2,
+                    dilation=1,
+                    w_init_gain="relu",
+                ),
+                nn.BatchNorm1d(n_encoder_embedding),
+            )
+            convolutions.append(conv_layer)
+
+        self.convolutions = nn.ModuleList(convolutions)
+
+        self.lstm = nn.LSTM(
+            n_encoder_embedding,
+            n_encoder_embedding // 2,
+            1,
+            batch_first=True,
+            bidirectional=True,
+        )
+
+    def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor:
+        r"""Pass the input through the _Encoder layer.
+
+        Args:
+            x (Tensor): the input sequence to the _Encoder layer (n_batch, n_encoder_embedding, n_seq).
+            input_lengths (Tensor): the length of input sequence to the _Encoder layer (n_batch,).
+
+        Return:
+            Tensor shape: (n_batch, n_seq, n_encoder_embedding)
+        """
+
+        for conv in self.convolutions:
+            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
+
+        x = x.transpose(1, 2)
+
+        # pack_padded_sequence defaults to enforce_sorted=True: lengths must be in decreasing order
+        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
+
+        self.lstm.flatten_parameters()
+        outputs, _ = self.lstm(x)
+        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+        return outputs
+
+    def infer(self, x: Tensor, input_lengths: Tensor) -> Tensor:
+        r"""Pass the input through the _Encoder layer for inference.
+
+        Args:
+            x (Tensor): the input sequence to the _Encoder layer (n_batch, n_encoder_embedding, n_seq).
+            input_lengths (Tensor): the length of input sequence to the _Encoder layer (n_batch,).
+
+        Return:
+            Tensor shape: (n_batch, n_seq, n_encoder_embedding)
+        """
+
+        for conv in self.convolutions:
+            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
+
+        x = x.transpose(1, 2)
+
+        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
+
+        outputs, _ = self.lstm(x)
+
+        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+        return outputs

From a1619f508644930fa9624cad1a9e41293749a0cb Mon Sep 17 00:00:00 2001
From: Ji Chen
Date: Wed, 29 Jul 2020 06:56:31 -0700
Subject: [PATCH 2/2] Add test

---
 test/test_models.py           | 3 +--
 torchaudio/models/__init__.py | 1 +
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index 272a82b701..78ba52f8b9 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,5 +1,5 @@
 import torch
-from torchaudio.models import Wav2Letter, MelResNet, UpsampleNetwork, WaveRNN
+from torchaudio.models import Wav2Letter, MelResNet, UpsampleNetwork, WaveRNN, _Encoder
 
 from . import common_utils
 
@@ -135,4 +135,3 @@ def test_output(self):
         out = model(x, input_length)
 
         assert out.size() == (n_batch, n_seq, n_encoder_embedding)
-        
\ No newline at end of file
diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py
index 8e05b8b509..a27a1a1fc4 100644
--- a/torchaudio/models/__init__.py
+++ b/torchaudio/models/__init__.py
@@ -1,2 +1,3 @@
 from .wav2letter import *
 from .wavernn import *
+from ._tacotron import *