Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/pipeline_wavernn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models._wavernn import _WaveRNN
from torchaudio.models.wavernn import WaveRNN

from datasets import collate_factory, split_process_ljspeech
from losses import LongCrossEntropyLoss, MoLLoss
Expand Down Expand Up @@ -297,7 +297,7 @@ def main(args):

n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30

model = _WaveRNN(
model = WaveRNN(
upsample_scales=args.upsample_scales,
n_classes=n_classes,
hop_length=args.hop_length,
Expand Down
26 changes: 13 additions & 13 deletions test/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork, _WaveRNN
from torchaudio.models import Wav2Letter, MelResNet, UpsampleNetwork, WaveRNN

from . import common_utils

Expand Down Expand Up @@ -36,7 +36,7 @@ def test_mfcc(self):
class TestMelResNet(common_utils.TorchaudioTestCase):

def test_waveform(self):
"""Validate the output dimensions of a _MelResNet block.
"""Validate the output dimensions of a MelResNet block.
"""

n_batch = 2
Expand All @@ -47,7 +47,7 @@ def test_waveform(self):
n_hidden = 128
kernel_size = 5

model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

x = torch.rand(n_batch, n_freq, n_time)
out = model(x)
Expand All @@ -58,7 +58,7 @@ def test_waveform(self):
class TestUpsampleNetwork(common_utils.TorchaudioTestCase):

def test_waveform(self):
"""Validate the output dimensions of a _UpsampleNetwork block.
"""Validate the output dimensions of a UpsampleNetwork block.
"""

upsample_scales = [5, 5, 8]
Expand All @@ -74,12 +74,12 @@ def test_waveform(self):
for upsample_scale in upsample_scales:
total_scale *= upsample_scale

model = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
model = UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)

x = torch.rand(n_batch, n_freq, n_time)
out1, out2 = model(x)
Expand All @@ -91,7 +91,7 @@ def test_waveform(self):
class TestWaveRNN(common_utils.TorchaudioTestCase):

def test_waveform(self):
"""Validate the output dimensions of a _WaveRNN model.
"""Validate the output dimensions of a WaveRNN model.
"""

upsample_scales = [5, 5, 8]
Expand All @@ -107,8 +107,8 @@ def test_waveform(self):
n_hidden = 128
kernel_size = 5

model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)

x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .wav2letter import *
from ._wavernn import *
from .wavernn import *
64 changes: 32 additions & 32 deletions torchaudio/models/_wavernn.py → torchaudio/models/wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from torch import Tensor
from torch import nn

__all__ = ["_ResBlock", "_MelResNet", "_Stretch2d", "_UpsampleNetwork", "_WaveRNN"]
__all__ = ["ResBlock", "MelResNet", "Stretch2d", "UpsampleNetwork", "WaveRNN"]


class _ResBlock(nn.Module):
class ResBlock(nn.Module):
r"""ResNet block based on "Deep Residual Learning for Image Recognition"

The paper link is https://arxiv.org/pdf/1512.03385.pdf.
Expand All @@ -16,7 +16,7 @@ class _ResBlock(nn.Module):
n_freq: the number of bins in a spectrogram. (Default: ``128``)

Examples
>>> resblock = _ResBlock()
>>> resblock = ResBlock()
>>> input = torch.rand(10, 128, 512) # a random spectrogram
>>> output = resblock(input) # shape: (10, 128, 512)
"""
Expand All @@ -33,9 +33,9 @@ def __init__(self, n_freq: int = 128) -> None:
)

def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _ResBlock layer.
r"""Pass the input through the ResBlock layer.
Args:
specgram (Tensor): the input sequence to the _ResBlock layer (n_batch, n_freq, n_time).
specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).

Return:
Tensor shape: (n_batch, n_freq, n_time)
Expand All @@ -44,7 +44,7 @@ def forward(self, specgram: Tensor) -> Tensor:
return self.resblock_model(specgram) + specgram


class _MelResNet(nn.Module):
class MelResNet(nn.Module):
r"""MelResNet layer uses a stack of ResBlocks on spectrogram.

Args:
Expand All @@ -55,7 +55,7 @@ class _MelResNet(nn.Module):
kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)

Examples
>>> melresnet = _MelResNet()
>>> melresnet = MelResNet()
>>> input = torch.rand(10, 128, 512) # a random spectrogram
>>> output = melresnet(input) # shape: (10, 128, 508)
"""
Expand All @@ -68,7 +68,7 @@ def __init__(self,
kernel_size: int = 5) -> None:
super().__init__()

ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)]
ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]

self.melresnet_model = nn.Sequential(
nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
Expand All @@ -79,9 +79,9 @@ def __init__(self,
)

def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _MelResNet layer.
r"""Pass the input through the MelResNet layer.
Args:
specgram (Tensor): the input sequence to the _MelResNet layer (n_batch, n_freq, n_time).
specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).

Return:
Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
Expand All @@ -90,15 +90,15 @@ def forward(self, specgram: Tensor) -> Tensor:
return self.melresnet_model(specgram)


class _Stretch2d(nn.Module):
class Stretch2d(nn.Module):
r"""Upscale the frequency and time dimensions of a spectrogram.

Args:
time_scale: the scale factor in time dimension
freq_scale: the scale factor in frequency dimension

Examples
>>> stretch2d = _Stretch2d(time_scale=10, freq_scale=5)
>>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)

>>> input = torch.rand(10, 100, 512) # a random spectrogram
>>> output = stretch2d(input) # shape: (10, 500, 5120)
Expand All @@ -113,10 +113,10 @@ def __init__(self,
self.time_scale = time_scale

def forward(self, specgram: Tensor) -> Tensor:
r"""Pass the input through the _Stretch2d layer.
r"""Pass the input through the Stretch2d layer.

Args:
specgram (Tensor): the input sequence to the _Stretch2d layer (..., n_freq, n_time).
specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).

Return:
Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
Expand All @@ -125,7 +125,7 @@ def forward(self, specgram: Tensor) -> Tensor:
return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1)


class _UpsampleNetwork(nn.Module):
class UpsampleNetwork(nn.Module):
r"""Upscale the dimensions of a spectrogram.

Args:
Expand All @@ -137,7 +137,7 @@ class _UpsampleNetwork(nn.Module):
kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)

Examples
>>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16])
>>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
>>> input = torch.rand(10, 128, 10) # a random spectrogram
>>> output = upsamplenetwork(input) # shape: (10, 128, 1536), (10, 128, 1536)
"""
Expand All @@ -156,12 +156,12 @@ def __init__(self,
total_scale *= upsample_scale

self.indent = (kernel_size - 1) // 2 * total_scale
self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
self.resnet_stretch = _Stretch2d(total_scale, 1)
self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
self.resnet_stretch = Stretch2d(total_scale, 1)

up_layers = []
for scale in upsample_scales:
stretch = _Stretch2d(scale, 1)
stretch = Stretch2d(scale, 1)
conv = nn.Conv2d(in_channels=1,
out_channels=1,
kernel_size=(1, scale * 2 + 1),
Expand All @@ -173,10 +173,10 @@ def __init__(self,
self.upsample_layers = nn.Sequential(*up_layers)

def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
r"""Pass the input through the _UpsampleNetwork layer.
r"""Pass the input through the UpsampleNetwork layer.

Args:
specgram (Tensor): the input sequence to the _UpsampleNetwork layer (n_batch, n_freq, n_time)
specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)

Return:
Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
Expand All @@ -195,7 +195,7 @@ def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
return upsampling_output, resnet_output


class _WaveRNN(nn.Module):
class WaveRNN(nn.Module):
r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.

The original implementation was introduced in
Expand All @@ -216,7 +216,7 @@ class _WaveRNN(nn.Module):
n_output: the number of output dimensions of melresnet. (Default: ``128``)

Example
>>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
>>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
>>> waveform, sample_rate = torchaudio.load(file)
>>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
>>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
Expand Down Expand Up @@ -249,12 +249,12 @@ def __init__(self,
if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

self.upsample = _UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.upsample = UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
Expand All @@ -268,11 +268,11 @@ def __init__(self,
self.fc3 = nn.Linear(n_fc, self.n_classes)

def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
r"""Pass the input through the _WaveRNN model.
r"""Pass the input through the WaveRNN model.

Args:
waveform: the input waveform to the _WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)
waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)

Return:
Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
Expand Down