class TestUpsampleNetwork(common_utils.TorchaudioTestCase):

    def test_waveform(self):
        """Check that _UpsampleNetwork produces outputs of the expected shapes.

        Both returned tensors should have their time dimension equal to
        (n_time - kernel_size + 1) scaled by the product of upsample_scales.
        """
        upsample_scales = [5, 5, 8]
        n_batch = 2
        n_time = 200
        n_freq = 100
        n_output = 256
        n_res_block = 10
        n_hidden = 128
        kernel_size = 5

        total_scale = 1
        for scale in upsample_scales:
            total_scale = total_scale * scale
        # The ResNet branch uses an un-padded Conv1d, so kernel_size - 1
        # frames are lost before the time axis is stretched.
        expected_time = total_scale * (n_time - kernel_size + 1)

        model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq,
                                 n_hidden, n_output, kernel_size)

        specgram = torch.rand(n_batch, n_freq, n_time)
        upsampled, residual = model(specgram)

        assert upsampled.size() == (n_batch, n_freq, expected_time)
        assert residual.size() == (n_batch, n_output, expected_time)
+ + Args: + time_scale: the scale factor in time dimension + freq_scale: the scale factor in frequency dimension + + Examples + >>> stretch2d = _Stretch2d(time_scale=10, freq_scale=5) + + >>> input = torch.rand(10, 100, 512) # a random spectrogram + >>> output = stretch2d(input) # shape: (10, 500, 5120) + """ + + def __init__(self, + time_scale: int, + freq_scale: int) -> None: + super().__init__() + + self.freq_scale = freq_scale + self.time_scale = time_scale + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the _Stretch2d layer. + + Args: + specgram (Tensor): the input sequence to the _Stretch2d layer (..., n_freq, n_time). + + Return: + Tensor shape: (..., n_freq * freq_scale, n_time * time_scale) + """ + + return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1) + + +class _UpsampleNetwork(nn.Module): + r"""Upscale the dimensions of a spectrogram. + + Args: + upsample_scales: the list of upsample scales + n_res_block: the number of ResBlock in stack (default=10) + n_freq: the number of bins in a spectrogram (default=128) + n_hidden: the number of hidden dimensions (default=128) + n_output: the number of output dimensions (default=128) + kernel_size: the number of kernel size in the first Conv1d layer (default=5) + + Examples + >>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16]) + >>> input = torch.rand(10, 128, 10) # a random spectrogram + >>> output = upsamplenetwork(input) # shape: (10, 1536, 128), (10, 1536, 128) + """ + + def __init__(self, + upsample_scales: List[int], + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5) -> None: + super().__init__() + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + + self.indent = (kernel_size - 1) // 2 * total_scale + self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.resnet_stretch = 
from typing import Tuple  # needed for the corrected forward() return annotation


class _UpsampleNetwork(nn.Module):
    r"""Upscale the dimensions of a spectrogram.

    Args:
        upsample_scales: the list of upsample scales
        n_res_block: the number of ResBlock in stack (default=10)
        n_freq: the number of bins in a spectrogram (default=128)
        n_hidden: the number of hidden dimensions (default=128)
        n_output: the number of output dimensions (default=128)
        kernel_size: the number of kernel size in the first Conv1d layer (default=5)

    Examples
        >>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16])
        >>> input = torch.rand(10, 128, 10)  # a random spectrogram
        >>> output = upsamplenetwork(input)  # shape: (10, 1536, 128), (10, 1536, 128)
    """

    def __init__(self,
                 upsample_scales: List[int],
                 n_res_block: int = 10,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128,
                 kernel_size: int = 5) -> None:
        super().__init__()

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale

        # Frames trimmed from each end of the upsampled branch so it aligns
        # with the ResNet branch, which loses (kernel_size - 1) frames to an
        # un-padded Conv1d before being stretched by total_scale.
        self.indent = (kernel_size - 1) // 2 * total_scale
        self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
        self.resnet_stretch = _Stretch2d(total_scale, 1)

        up_layers = []
        for scale in upsample_scales:
            stretch = _Stretch2d(scale, 1)
            conv = nn.Conv2d(in_channels=1,
                             out_channels=1,
                             kernel_size=(1, scale * 2 + 1),
                             padding=(0, scale),
                             bias=False)
            # Initialize as a moving-average filter so the nearest-neighbor
            # stretch is smoothed rather than left blocky.
            conv.weight.data.fill_(1. / (scale * 2 + 1))
            up_layers.append(stretch)
            up_layers.append(conv)
        self.upsample_layers = nn.Sequential(*up_layers)

    def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the _UpsampleNetwork layer.

        Args:
            specgram (Tensor): the input sequence to the _UpsampleNetwork layer (n_batch, n_freq, n_time)

        Return:
            Tuple of two Tensors with shapes
            (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale) and
            (n_batch, n_output, (n_time - kernel_size + 1) * total_scale),
            where total_scale is the product of all elements in upsample_scales.
        """

        resnet_output = self.resnet(specgram).unsqueeze(1)
        resnet_output = self.resnet_stretch(resnet_output)
        resnet_output = resnet_output.squeeze(1)

        specgram = specgram.unsqueeze(1)
        upsampling_output = self.upsample_layers(specgram)
        upsampling_output = upsampling_output.squeeze(1)
        # Guard against indent == 0 (kernel_size == 1): the slice [..., 0:-0]
        # would produce an empty tensor, so only trim when there is padding
        # to drop.
        if self.indent:
            upsampling_output = upsampling_output[:, :, self.indent:-self.indent]

        return upsampling_output, resnet_output