From 3ef78e2d180ca367d06de5ca11e7bcf13939df35 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 17 Jul 2020 10:28:48 -0700 Subject: [PATCH 1/3] update variable names in wavernn --- test/test_models.py | 47 ++++++++++++---------- torchaudio/models/_wavernn.py | 73 +++++++++++++++++++---------------- 2 files changed, 65 insertions(+), 55 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index c54a57cebd..3ca8e8d44c 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -42,17 +42,17 @@ def test_waveform(self): n_batch = 2 n_time = 200 n_freq = 100 - n_output = 128 + n_output_melresnet = 128 n_res_block = 10 - n_hidden = 128 + n_hidden_resblock = 128 kernel_size = 5 - model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + model = _MelResNet(n_res_block, n_freq, n_hidden_resblock, n_output_melresnet, kernel_size) x = torch.rand(n_batch, n_freq, n_time) out = model(x) - assert out.size() == (n_batch, n_output, n_time - kernel_size + 1) + assert out.size() == (n_batch, n_output_melresnet, n_time - kernel_size + 1) class TestUpsampleNetwork(common_utils.TorchaudioTestCase): @@ -65,22 +65,27 @@ def test_waveform(self): n_batch = 2 n_time = 200 n_freq = 100 - n_output = 256 + n_output_melresnet = 256 n_res_block = 10 - n_hidden = 128 + n_hidden_resblock = 128 kernel_size = 5 total_scale = 1 for upsample_scale in upsample_scales: total_scale *= upsample_scale - model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) + model = _UpsampleNetwork(upsample_scales, + n_res_block, + n_freq, + n_hidden_resblock, + n_output_melresnet, + kernel_size) x = torch.rand(n_batch, n_freq, n_time) out1, out2 = model(x) assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1)) - assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1)) + assert out2.size() == (n_batch, n_output_melresnet, total_scale * (n_time - kernel_size + 1)) class 
TestWaveRNN(common_utils.TorchaudioTestCase): @@ -92,26 +97,26 @@ def test_waveform(self): upsample_scales = [5, 5, 8] n_rnn = 512 n_fc = 512 - n_bits = 9 + n_classes = 512 sample_rate = 24000 hop_length = 200 n_batch = 2 n_time = 200 n_freq = 100 - n_output = 256 + n_output_melresnet = 256 n_res_block = 10 - n_hidden = 128 + n_hidden_resblock = 128 kernel_size = 5 - mode = 'waveform' + loss = 'waveform' - model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block, - n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode) + model = _WaveRNN(upsample_scales, n_classes, sample_rate, hop_length, n_res_block, + n_rnn, n_fc, kernel_size, n_freq, n_hidden_resblock, n_output_melresnet, loss) x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) mels = torch.rand(n_batch, 1, n_freq, n_time) out = model(x, mels) - assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits) + assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes) def test_mol(self): """Validate the output dimensions of a _WaveRNN model in mol mode. 
@@ -120,20 +125,20 @@ def test_mol(self): upsample_scales = [5, 5, 8] n_rnn = 512 n_fc = 512 - n_bits = 9 + n_classes = 512 sample_rate = 24000 hop_length = 200 n_batch = 2 n_time = 200 n_freq = 100 - n_output = 256 + n_output_melresnet = 256 n_res_block = 10 - n_hidden = 128 + n_hidden_resblock = 128 kernel_size = 5 - mode = 'mol' + loss = 'mol' - model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block, - n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode) + model = _WaveRNN(upsample_scales, n_classes, sample_rate, hop_length, n_res_block, + n_rnn, n_fc, kernel_size, n_freq, n_hidden_resblock, n_output_melresnet, loss) x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) mels = torch.rand(n_batch, 1, n_freq, n_time) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index cd2e89a10c..93ec95a5cd 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -50,8 +50,8 @@ class _MelResNet(nn.Module): Args: n_res_block: the number of ResBlock in stack (default=10) n_freq: the number of bins in a spectrogram (default=128) - n_hidden: the number of hidden dimensions (default=128) - n_output: the number of output dimensions (default=128) + n_hidden_resblock: the number of hidden dimensions of resblock (default=128) + n_output_melresnet: the number of output dimensions of melresnet (default=128) kernel_size: the number of kernel size in the first Conv1d layer (default=5) Examples @@ -63,19 +63,19 @@ class _MelResNet(nn.Module): def __init__(self, n_res_block: int = 10, n_freq: int = 128, - n_hidden: int = 128, - n_output: int = 128, + n_hidden_resblock: int = 128, + n_output_melresnet: int = 128, kernel_size: int = 5) -> None: super().__init__() - ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)] + ResBlocks = [_ResBlock(n_hidden_resblock) for _ in range(n_res_block)] self.melresnet_model = nn.Sequential( - nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, 
kernel_size=kernel_size, bias=False), - nn.BatchNorm1d(n_hidden), + nn.Conv1d(in_channels=n_freq, out_channels=n_hidden_resblock, kernel_size=kernel_size, bias=False), + nn.BatchNorm1d(n_hidden_resblock), nn.ReLU(inplace=True), *ResBlocks, - nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1) + nn.Conv1d(in_channels=n_hidden_resblock, out_channels=n_output_melresnet, kernel_size=1) ) def forward(self, specgram: Tensor) -> Tensor: @@ -84,7 +84,7 @@ def forward(self, specgram: Tensor) -> Tensor: specgram (Tensor): the input sequence to the _MelResNet layer (n_batch, n_freq, n_time). Return: - Tensor shape: (n_batch, n_output, n_time - kernel_size + 1) + Tensor shape: (n_batch, n_output_melresnet, n_time - kernel_size + 1) """ return self.melresnet_model(specgram) @@ -132,8 +132,8 @@ class _UpsampleNetwork(nn.Module): upsample_scales: the list of upsample scales n_res_block: the number of ResBlock in stack (default=10) n_freq: the number of bins in a spectrogram (default=128) - n_hidden: the number of hidden dimensions (default=128) - n_output: the number of output dimensions (default=128) + n_hidden_resblock: the number of hidden dimensions of resblock (default=128) + n_output_melresnet: the number of output dimensions of melresnet (default=128) kernel_size: the number of kernel size in the first Conv1d layer (default=5) Examples @@ -146,8 +146,8 @@ def __init__(self, upsample_scales: List[int], n_res_block: int = 10, n_freq: int = 128, - n_hidden: int = 128, - n_output: int = 128, + n_hidden_resblock: int = 128, + n_output_melresnet: int = 128, kernel_size: int = 5) -> None: super().__init__() @@ -156,7 +156,7 @@ def __init__(self, total_scale *= upsample_scale self.indent = (kernel_size - 1) // 2 * total_scale - self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.resnet = _MelResNet(n_res_block, n_freq, n_hidden_resblock, n_output_melresnet, kernel_size) self.resnet_stretch = _Stretch2d(total_scale, 1) up_layers = 
[] @@ -180,7 +180,7 @@ def forward(self, specgram: Tensor) -> Tensor: Return: Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale), - (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) + (n_batch, n_output_melresnet, (n_time - kernel_size + 1) * total_scale) where total_scale is the product of all elements in upsample_scales. """ @@ -205,7 +205,7 @@ class _WaveRNN(nn.Module): Args: upsample_scales: the list of upsample scales - n_bits: the bits of output waveform + n_classes: the number of output classes sample_rate: the rate of audio dimensions (samples per second) hop_length: the number of samples between the starts of consecutive frames n_res_block: the number of ResBlock in stack (default=10) @@ -213,22 +213,22 @@ class _WaveRNN(nn.Module): n_fc: the dimension of fully connected layer (default=512) kernel_size: the number of kernel size in the first Conv1d layer (default=5) n_freq: the number of bins in a spectrogram (default=128) - n_hidden: the number of hidden dimensions (default=128) - n_output: the number of output dimensions (default=128) - mode: the mode of waveform in ['waveform', 'mol'] (default='waveform') + n_hidden_resblock: the number of hidden dimensions of resblock (default=128) + n_output_melresnet: the number of output dimensions of melresnet (default=128) + loss: the type of loss in ['waveform', 'mol'] (default='waveform') Example - >>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200) + >>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_classes=512, sample_rate=24000, hop_length=200) >>> waveform, sample_rate = torchaudio.load(file) >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) >>> output = wavernn(waveform, specgram) - >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits) + >>> # output shape: 
(n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes) """ def __init__(self, upsample_scales: List[int], - n_bits: int, + n_classes: int, sample_rate: int, hop_length: int, n_res_block: int = 10, @@ -236,23 +236,23 @@ def __init__(self, n_fc: int = 512, kernel_size: int = 5, n_freq: int = 128, - n_hidden: int = 128, - n_output: int = 128, - mode: str = 'waveform') -> None: + n_hidden_resblock: int = 128, + n_output_melresnet: int = 128, + loss: str = 'waveform') -> None: super().__init__() - self.mode = mode + self.loss = loss self.kernel_size = kernel_size - if self.mode == 'waveform': - self.n_classes = 2 ** n_bits - elif self.mode == 'mol': + if self.loss == 'waveform': + self.n_classes = n_classes + elif self.loss == 'mol': self.n_classes = 30 else: - raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}") + raise ValueError(f"Expected loss: `waveform` or `mol`, but found {self.loss}") self.n_rnn = n_rnn - self.n_aux = n_output // 4 + self.n_aux = n_output_melresnet // 4 self.hop_length = hop_length self.sample_rate = sample_rate @@ -262,7 +262,12 @@ def __init__(self, if total_scale != self.hop_length: raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}") - self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.upsample = _UpsampleNetwork(upsample_scales, + n_res_block, + n_freq, + n_hidden_resblock, + n_output_melresnet, + kernel_size) self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) @@ -283,7 +288,7 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time) Return: - Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits) + Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes) """ assert waveform.size(1) == 1, 'Require the 
input channel of waveform is 1' @@ -296,7 +301,7 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) # output of upsample: # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale) - # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) + # aux: (n_batch, n_output_melresnet, (n_time - kernel_size + 1) * total_scale) specgram, aux = self.upsample(specgram) specgram = specgram.transpose(1, 2) aux = aux.transpose(1, 2) From c0dff02cc64dd0785d0d6b7e020d65982f6044fd Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 17 Jul 2020 11:35:00 -0700 Subject: [PATCH 2/3] change mode name --- test/test_models.py | 6 +++--- torchaudio/models/_wavernn.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 3ca8e8d44c..b9023ca098 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -91,7 +91,7 @@ def test_waveform(self): class TestWaveRNN(common_utils.TorchaudioTestCase): def test_waveform(self): - """Validate the output dimensions of a _WaveRNN model in waveform mode. + """Validate the output dimensions of a _WaveRNN model in crossentropy loss. """ upsample_scales = [5, 5, 8] @@ -107,7 +107,7 @@ def test_waveform(self): n_res_block = 10 n_hidden_resblock = 128 kernel_size = 5 - loss = 'waveform' + loss = 'crossentropy' model = _WaveRNN(upsample_scales, n_classes, sample_rate, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden_resblock, n_output_melresnet, loss) @@ -119,7 +119,7 @@ def test_waveform(self): assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes) def test_mol(self): - """Validate the output dimensions of a _WaveRNN model in mol mode. + """Validate the output dimensions of a _WaveRNN model in mol loss. 
""" upsample_scales = [5, 5, 8] diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 93ec95a5cd..35c2b24d46 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -215,7 +215,7 @@ class _WaveRNN(nn.Module): n_freq: the number of bins in a spectrogram (default=128) n_hidden_resblock: the number of hidden dimensions of resblock (default=128) n_output_melresnet: the number of output dimensions of melresnet (default=128) - loss: the type of loss in ['waveform', 'mol'] (default='waveform') + loss: the type of loss in ['crossentropy', 'mol'] (default='crossentropy') Example >>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_classes=512, sample_rate=24000, hop_length=200) @@ -238,18 +238,18 @@ def __init__(self, n_freq: int = 128, n_hidden_resblock: int = 128, n_output_melresnet: int = 128, - loss: str = 'waveform') -> None: + loss: str = 'crossentropy') -> None: super().__init__() self.loss = loss self.kernel_size = kernel_size - if self.loss == 'waveform': + if self.loss == 'crossentropy': self.n_classes = n_classes elif self.loss == 'mol': self.n_classes = 30 else: - raise ValueError(f"Expected loss: `waveform` or `mol`, but found {self.loss}") + raise ValueError(f"Expected loss: `crossentropy` or `mol`, but found {self.loss}") self.n_rnn = n_rnn self.n_aux = n_output_melresnet // 4 From 4168fc518168538d2da58678e7b0290fbab60578 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 17 Jul 2020 13:21:17 -0700 Subject: [PATCH 3/3] update variable and remove mode --- test/test_models.py | 58 ++++++++------------------------ torchaudio/models/_wavernn.py | 63 ++++++++++++++--------------------- 2 files changed, 39 insertions(+), 82 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index b9023ca098..10ca5c827b 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -42,17 +42,17 @@ def test_waveform(self): n_batch = 2 n_time = 200 n_freq = 100 - n_output_melresnet = 128 + n_output = 128 n_res_block = 
10 - n_hidden_resblock = 128 + n_hidden = 128 kernel_size = 5 - model = _MelResNet(n_res_block, n_freq, n_hidden_resblock, n_output_melresnet, kernel_size) + model = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) x = torch.rand(n_batch, n_freq, n_time) out = model(x) - assert out.size() == (n_batch, n_output_melresnet, n_time - kernel_size + 1) + assert out.size() == (n_batch, n_output, n_time - kernel_size + 1) class TestUpsampleNetwork(common_utils.TorchaudioTestCase): @@ -65,9 +65,9 @@ def test_waveform(self): n_batch = 2 n_time = 200 n_freq = 100 - n_output_melresnet = 256 + n_output = 256 n_res_block = 10 - n_hidden_resblock = 128 + n_hidden = 128 kernel_size = 5 total_scale = 1 @@ -77,71 +77,41 @@ def test_waveform(self): model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, - n_hidden_resblock, - n_output_melresnet, + n_hidden, + n_output, kernel_size) x = torch.rand(n_batch, n_freq, n_time) out1, out2 = model(x) assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1)) - assert out2.size() == (n_batch, n_output_melresnet, total_scale * (n_time - kernel_size + 1)) + assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1)) class TestWaveRNN(common_utils.TorchaudioTestCase): def test_waveform(self): - """Validate the output dimensions of a _WaveRNN model in crossentropy loss. + """Validate the output dimensions of a _WaveRNN model. 
""" upsample_scales = [5, 5, 8] n_rnn = 512 n_fc = 512 n_classes = 512 - sample_rate = 24000 hop_length = 200 n_batch = 2 n_time = 200 n_freq = 100 - n_output_melresnet = 256 + n_output = 256 n_res_block = 10 - n_hidden_resblock = 128 + n_hidden = 128 kernel_size = 5 - loss = 'crossentropy' - model = _WaveRNN(upsample_scales, n_classes, sample_rate, hop_length, n_res_block, - n_rnn, n_fc, kernel_size, n_freq, n_hidden_resblock, n_output_melresnet, loss) + model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, + n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output) x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) mels = torch.rand(n_batch, 1, n_freq, n_time) out = model(x, mels) assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes) - - def test_mol(self): - """Validate the output dimensions of a _WaveRNN model in mol loss. - """ - - upsample_scales = [5, 5, 8] - n_rnn = 512 - n_fc = 512 - n_classes = 512 - sample_rate = 24000 - hop_length = 200 - n_batch = 2 - n_time = 200 - n_freq = 100 - n_output_melresnet = 256 - n_res_block = 10 - n_hidden_resblock = 128 - kernel_size = 5 - loss = 'mol' - - model = _WaveRNN(upsample_scales, n_classes, sample_rate, hop_length, n_res_block, - n_rnn, n_fc, kernel_size, n_freq, n_hidden_resblock, n_output_melresnet, loss) - - x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) - mels = torch.rand(n_batch, 1, n_freq, n_time) - out = model(x, mels) - - assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30) diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py index 35c2b24d46..afe70a39b2 100644 --- a/torchaudio/models/_wavernn.py +++ b/torchaudio/models/_wavernn.py @@ -50,8 +50,8 @@ class _MelResNet(nn.Module): Args: n_res_block: the number of ResBlock in stack (default=10) n_freq: the number of bins in a spectrogram (default=128) - n_hidden_resblock: the number of hidden dimensions of resblock (default=128) - 
n_output_melresnet: the number of output dimensions of melresnet (default=128) + n_hidden: the number of hidden dimensions of resblock (default=128) + n_output: the number of output dimensions of melresnet (default=128) kernel_size: the number of kernel size in the first Conv1d layer (default=5) Examples @@ -63,19 +63,19 @@ class _MelResNet(nn.Module): def __init__(self, n_res_block: int = 10, n_freq: int = 128, - n_hidden_resblock: int = 128, - n_output_melresnet: int = 128, + n_hidden: int = 128, + n_output: int = 128, kernel_size: int = 5) -> None: super().__init__() - ResBlocks = [_ResBlock(n_hidden_resblock) for _ in range(n_res_block)] + ResBlocks = [_ResBlock(n_hidden) for _ in range(n_res_block)] self.melresnet_model = nn.Sequential( - nn.Conv1d(in_channels=n_freq, out_channels=n_hidden_resblock, kernel_size=kernel_size, bias=False), - nn.BatchNorm1d(n_hidden_resblock), + nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False), + nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True), *ResBlocks, - nn.Conv1d(in_channels=n_hidden_resblock, out_channels=n_output_melresnet, kernel_size=1) + nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1) ) def forward(self, specgram: Tensor) -> Tensor: @@ -84,7 +84,7 @@ def forward(self, specgram: Tensor) -> Tensor: specgram (Tensor): the input sequence to the _MelResNet layer (n_batch, n_freq, n_time). 
Return: - Tensor shape: (n_batch, n_output_melresnet, n_time - kernel_size + 1) + Tensor shape: (n_batch, n_output, n_time - kernel_size + 1) """ return self.melresnet_model(specgram) @@ -132,8 +132,8 @@ class _UpsampleNetwork(nn.Module): upsample_scales: the list of upsample scales n_res_block: the number of ResBlock in stack (default=10) n_freq: the number of bins in a spectrogram (default=128) - n_hidden_resblock: the number of hidden dimensions of resblock (default=128) - n_output_melresnet: the number of output dimensions of melresnet (default=128) + n_hidden: the number of hidden dimensions of resblock (default=128) + n_output: the number of output dimensions of melresnet (default=128) kernel_size: the number of kernel size in the first Conv1d layer (default=5) Examples @@ -146,8 +146,8 @@ def __init__(self, upsample_scales: List[int], n_res_block: int = 10, n_freq: int = 128, - n_hidden_resblock: int = 128, - n_output_melresnet: int = 128, + n_hidden: int = 128, + n_output: int = 128, kernel_size: int = 5) -> None: super().__init__() @@ -156,7 +156,7 @@ def __init__(self, total_scale *= upsample_scale self.indent = (kernel_size - 1) // 2 * total_scale - self.resnet = _MelResNet(n_res_block, n_freq, n_hidden_resblock, n_output_melresnet, kernel_size) + self.resnet = _MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) self.resnet_stretch = _Stretch2d(total_scale, 1) up_layers = [] @@ -180,7 +180,7 @@ def forward(self, specgram: Tensor) -> Tensor: Return: Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale), - (n_batch, n_output_melresnet, (n_time - kernel_size + 1) * total_scale) + (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) where total_scale is the product of all elements in upsample_scales. 
""" @@ -206,19 +206,17 @@ class _WaveRNN(nn.Module): Args: upsample_scales: the list of upsample scales n_classes: the number of output classes - sample_rate: the rate of audio dimensions (samples per second) hop_length: the number of samples between the starts of consecutive frames n_res_block: the number of ResBlock in stack (default=10) n_rnn: the dimension of RNN layer (default=512) n_fc: the dimension of fully connected layer (default=512) kernel_size: the number of kernel size in the first Conv1d layer (default=5) n_freq: the number of bins in a spectrogram (default=128) - n_hidden_resblock: the number of hidden dimensions of resblock (default=128) - n_output_melresnet: the number of output dimensions of melresnet (default=128) - loss: the type of loss in ['crossentropy', 'mol'] (default='crossentropy') + n_hidden: the number of hidden dimensions of resblock (default=128) + n_output: the number of output dimensions of melresnet (default=128) Example - >>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_classes=512, sample_rate=24000, hop_length=200) + >>> wavernn = _WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200) >>> waveform, sample_rate = torchaudio.load(file) >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) @@ -229,32 +227,21 @@ class _WaveRNN(nn.Module): def __init__(self, upsample_scales: List[int], n_classes: int, - sample_rate: int, hop_length: int, n_res_block: int = 10, n_rnn: int = 512, n_fc: int = 512, kernel_size: int = 5, n_freq: int = 128, - n_hidden_resblock: int = 128, - n_output_melresnet: int = 128, - loss: str = 'crossentropy') -> None: + n_hidden: int = 128, + n_output: int = 128) -> None: super().__init__() - self.loss = loss self.kernel_size = kernel_size - - if self.loss == 'crossentropy': - self.n_classes = n_classes - elif self.loss == 'mol': - self.n_classes = 30 - else: - raise
ValueError(f"Expected loss: `crossentropy` or `mol`, but found {self.loss}") - self.n_rnn = n_rnn - self.n_aux = n_output_melresnet // 4 + self.n_aux = n_output // 4 self.hop_length = hop_length - self.sample_rate = sample_rate + self.n_classes = n_classes total_scale = 1 for upsample_scale in upsample_scales: @@ -265,8 +252,8 @@ def __init__(self, self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, - n_hidden_resblock, - n_output_melresnet, + n_hidden, + n_output, kernel_size) self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) @@ -301,7 +288,7 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) # output of upsample: # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale) - # aux: (n_batch, n_output_melresnet, (n_time - kernel_size + 1) * total_scale) + # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) specgram, aux = self.upsample(specgram) specgram = specgram.transpose(1, 2) aux = aux.transpose(1, 2)