From 1ddd3bb573aa05725b3fc556831574b01b71631e Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 17 Jun 2020 06:37:54 -0700 Subject: [PATCH 1/8] upsamplenetwork --- test/test_models.py | 30 +++++++++- torchaudio/models/_wavernn.py | 107 +++++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 2 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 109112ee15..78f2d9d00c 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,5 +1,5 @@ import torch -from torchaudio.models import Wav2Letter, _MelResNet +from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork from . import common_utils @@ -53,3 +53,31 @@ def test_waveform(self): out = model(x) assert out.size() == (n_batch, n_output, n_time - kernel_size + 1) + + +class TestUpsampleNetwork(common_utils.TorchaudioTestCase): + + def test_waveform(self): + """Validate the output dimensions of a _UpsampleNetwork block. + """ + + upsample_scales = [5, 5, 8] + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 256 + n_res_block = 10 + n_hidden = 128 + kernel_size = 5 + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + + model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) + + x = torch.rand(n_batch, n_freq, n_time) + out1, out2 = model(x) + + assert out1.size() == (n_batch, total_scale * (n_time - kernel_size + 1), n_freq) + assert out2.size() == (n_batch, total_scale * (n_time - kernel_size + 1), n_output) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 4d9dc3144a..f89db550e8 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -1,7 +1,9 @@ +from typing import List + from torch import Tensor from torch import nn -__all__ = ["_ResBlock", "_MelResNet"] +__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork"] class _ResBlock(nn.Module): @@ -85,3 +87,106 @@ def forward(self, specgram: Tensor) -> Tensor: """ return self.melresnet_model(specgram) + + +class _Stretch2d(nn.Module): + r"""Stretch layer upscales the frequency and time dimensions of a spectrogram. + + Args: + time_scale: the scale factor in time dimension + freq_scale: the scale factor in frequency dimension + + Examples + >>> stretch2d = _Stretch2d(time_scale=10, freq_scale=5) + + >>> input = torch.rand(10, 100, 512) # a random spectrogram + >>> output = stretch2d(input) # shape: (10, 500, 5120) + """ + + def __init__(self, + time_scale: int, + freq_scale: int) -> None: + super().__init__() + + self.freq_scale = freq_scale + self.time_scale = time_scale + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the _Stretch2d layer. + Args: + specgram (Tensor): the input sequence to the _Stretch2d layer (..., n_freq, n_time). + + Return: + Tensor shape: (..., n_freq * freq_scale, n_time * time_scale) + """ + + return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1) + + +class _UpsampleNetwork(nn.Module): + r"""Upsample block upscales the dimensions of a spectrogram to match waveform. + + Args: + upsample_scales: the list of upsample scales + n_res_block: the number of ResBlock in stack (default=10) + n_freq: the number of bins in a spectrogram (default=128) + n_hidden: the number of hidden dimensions (default=128) + n_output: the number of output dimensions (default=128) + kernel_size: the number of kernel size in the first Conv1d layer (default=5) + + Examples + >>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16]) + >>> input = torch.rand(10, 128, 10) # a random spectrogram + >>> output = upsamplenetwork(input) # shape: (10, 1536, 128), (10, 1536, 128) + """ + + def __init__(self, + upsample_scales: List[int], + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5) -> None: + super().__init__() + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + + self.indent = (kernel_size - 1) // 2 * total_scale + self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.resnet_stretch = _Stretch2d(total_scale, 1) + + up_layers = [] + for scale in upsample_scales: + stretch = _Stretch2d(scale, 1) + conv = nn.Conv2d(in_channels=1, + out_channels=1, + kernel_size=(1, scale * 2 + 1), + padding=(0, scale), + bias=False) + conv.weight.data.fill_(1. / (scale * 2 + 1)) + up_layers.append(stretch) + up_layers.append(conv) + self.upsample_layers = nn.Sequential(*up_layers) + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the _UpsampleNetwork layer. + Args: + specgram (Tensor): the input sequence to the _UpsampleNetwork layer (n_batch, n_freq, n_time) + + Return: + Tensor shape: (n_batch, (n_time - kernel_size + 1) * total_scale, n_freq), + (n_batch, (n_time - kernel_size + 1) * total_scale, n_output) + where total_scale is the product of all elements in upsample_scales. + """ + + resnet_output = self.resnet(specgram).unsqueeze(1) + resnet_output = self.resnet_stretch(resnet_output) + resnet_output = resnet_output.squeeze(1) + + specgram = specgram.unsqueeze(1) + upsampling_output = self.upsample_layers(specgram) + upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent] + + return upsampling_output.transpose(1, 2), resnet_output.transpose(1, 2) From 62cc7d5e99165f8e7439b81e3158d306c083d03e Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Sun, 21 Jun 2020 09:07:55 -0700 Subject: [PATCH 2/8] update name --- test/test_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_models.py b/test/test_models.py index 78f2d9d00c..0daa6a7c38 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,3 +1,5 @@ +import unittest + import torch from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork From 1310586b3783aabc0c4d222ddde8a5d365c1219a Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 26 Jun 2020 06:02:01 -0700 Subject: [PATCH 3/8] update name and docstring --- test/test_models.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 0daa6a7c38..2c56a1f7ea 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,5 +1,3 @@ -import unittest - import torch from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork @@ -38,7 +36,11 @@ def test_mfcc(self): class TestMelResNet(common_utils.TorchaudioTestCase): def test_waveform(self): +<<<<<<< HEAD """Validate the output dimensions of a _MelResNet block. +======= + """test the output dimensions after _MelResNet block. +>>>>>>> update name and docstring """ n_batch = 2 @@ -60,7 +62,11 @@ def test_waveform(self): class TestUpsampleNetwork(common_utils.TorchaudioTestCase): def test_waveform(self): +<<<<<<< HEAD """Validate the output dimensions of a _UpsampleNetwork block. +======= + """test the output dimensions after _UpsampleNetwork block. +>>>>>>> update name and docstring """ upsample_scales = [5, 5, 8] From 52895459f501db5c8cb4f2bc211eb73507f5de2c Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 26 Jun 2020 09:33:32 -0700 Subject: [PATCH 4/8] update format --- test/test_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 2c56a1f7ea..ea7f617f50 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -36,11 +36,7 @@ def test_mfcc(self): class TestMelResNet(common_utils.TorchaudioTestCase): def test_waveform(self): -<<<<<<< HEAD """Validate the output dimensions of a _MelResNet block. -======= - """test the output dimensions after _MelResNet block. ->>>>>>> update name and docstring """ n_batch = 2 @@ -62,11 +58,15 @@ def test_waveform(self): class TestUpsampleNetwork(common_utils.TorchaudioTestCase): def test_waveform(self): +<<<<<<< HEAD <<<<<<< HEAD """Validate the output dimensions of a _UpsampleNetwork block. ======= """test the output dimensions after _UpsampleNetwork block. >>>>>>> update name and docstring +======= + """Validate the output dimensions of a _UpsampleNetwork block. +>>>>>>> update format """ upsample_scales = [5, 5, 8] From a8d1450ea758de2c88d832fffca135484009bc22 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Mon, 29 Jun 2020 09:07:17 -0700 Subject: [PATCH 5/8] rebase --- test/test_models.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index ea7f617f50..78f2d9d00c 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -58,15 +58,7 @@ def test_waveform(self): class TestUpsampleNetwork(common_utils.TorchaudioTestCase): def test_waveform(self): -<<<<<<< HEAD -<<<<<<< HEAD """Validate the output dimensions of a _UpsampleNetwork block. -======= - """test the output dimensions after _UpsampleNetwork block. ->>>>>>> update name and docstring -======= - """Validate the output dimensions of a _UpsampleNetwork block. ->>>>>>> update format """ upsample_scales = [5, 5, 8] From 0d56fb2a2d4b09269cc4dc19c386446000ed11fd Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Tue, 30 Jun 2020 08:42:46 -0700 Subject: [PATCH 6/8] update docstring --- torchaudio/models/_wavernn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index f89db550e8..f605861d6f 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -90,7 +90,7 @@ def forward(self, specgram: Tensor) -> Tensor: class _Stretch2d(nn.Module): - r"""Stretch layer upscales the frequency and time dimensions of a spectrogram. + r"""Upscale the frequency and time dimensions of a spectrogram. Args: time_scale: the scale factor in time dimension From b31fbb252f9adc4c9663615afec26f89839e9986 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Tue, 30 Jun 2020 08:45:53 -0700 Subject: [PATCH 7/8] update docstring --- torchaudio/models/_wavernn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index f605861d6f..8e2cb04a3d 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -124,7 +124,7 @@ def forward(self, specgram: Tensor) -> Tensor: class _UpsampleNetwork(nn.Module): - r"""Upsample block upscales the dimensions of a spectrogram to match waveform. + r"""Upscale the dimensions of a spectrogram to match waveform. Args: upsample_scales: the list of upsample scales From 44bff0423286f417618280d202c717e3f736880a Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Wed, 1 Jul 2020 11:38:23 -0700 Subject: [PATCH 8/8] remove transpose and update docstring --- test/test_models.py | 4 ++-- torchaudio/models/_wavernn.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 78f2d9d00c..519fbc7b26 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -79,5 +79,5 @@ def test_waveform(self): x = torch.rand(n_batch, n_freq, n_time) out1, out2 = model(x) - assert out1.size() == (n_batch, total_scale * (n_time - kernel_size + 1), n_freq) - assert out2.size() == (n_batch, total_scale * (n_time - kernel_size + 1), n_output) + assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1)) + assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1)) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 8e2cb04a3d..1df9eb0637 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -113,6 +113,7 @@ def __init__(self, def forward(self, specgram: Tensor) -> Tensor: r"""Pass the input through the _Stretch2d layer. + Args: specgram (Tensor): the input sequence to the _Stretch2d layer (..., n_freq, n_time). @@ -124,7 +125,7 @@ def forward(self, specgram: Tensor) -> Tensor: class _UpsampleNetwork(nn.Module): - r"""Upscale the dimensions of a spectrogram to match waveform. + r"""Upscale the dimensions of a spectrogram. Args: upsample_scales: the list of upsample scales @@ -172,12 +173,13 @@ def __init__(self, def forward(self, specgram: Tensor) -> Tensor: r"""Pass the input through the _UpsampleNetwork layer. + Args: specgram (Tensor): the input sequence to the _UpsampleNetwork layer (n_batch, n_freq, n_time) Return: - Tensor shape: (n_batch, (n_time - kernel_size + 1) * total_scale, n_freq), - (n_batch, (n_time - kernel_size + 1) * total_scale, n_output) + Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale), + (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) where total_scale is the product of all elements in upsample_scales. """ @@ -189,4 +191,4 @@ def forward(self, specgram: Tensor) -> Tensor: upsampling_output = self.upsample_layers(specgram) upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent] - return upsampling_output.transpose(1, 2), resnet_output.transpose(1, 2) + return upsampling_output, resnet_output