diff --git a/test/test_models.py b/test/test_models.py
index c54a57cebd..10ca5c827b 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -74,7 +74,12 @@ def test_waveform(self):
         for upsample_scale in upsample_scales:
             total_scale *= upsample_scale
 
-        model = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
+        model = _UpsampleNetwork(upsample_scales,
+                                 n_res_block,
+                                 n_freq,
+                                 n_hidden,
+                                 n_output,
+                                 kernel_size)
 
         x = torch.rand(n_batch, n_freq, n_time)
         out1, out2 = model(x)
@@ -86,14 +91,13 @@
 class TestWaveRNN(common_utils.TorchaudioTestCase):
 
     def test_waveform(self):
-        """Validate the output dimensions of a _WaveRNN model in waveform mode.
+        """Validate the output dimensions of a _WaveRNN model.
         """
 
         upsample_scales = [5, 5, 8]
         n_rnn = 512
         n_fc = 512
-        n_bits = 9
-        sample_rate = 24000
+        n_classes = 512
         hop_length = 200
         n_batch = 2
         n_time = 200
@@ -102,41 +106,12 @@ def test_waveform(self):
         n_res_block = 10
         n_hidden = 128
         kernel_size = 5
-        mode = 'waveform'
 
-        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
-                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
+        model = _WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
+                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
 
         x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
         mels = torch.rand(n_batch, 1, n_freq, n_time)
         out = model(x, mels)
 
-        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 2 ** n_bits)
-
-    def test_mol(self):
-        """Validate the output dimensions of a _WaveRNN model in mol mode.
-        """
-
-        upsample_scales = [5, 5, 8]
-        n_rnn = 512
-        n_fc = 512
-        n_bits = 9
-        sample_rate = 24000
-        hop_length = 200
-        n_batch = 2
-        n_time = 200
-        n_freq = 100
-        n_output = 256
-        n_res_block = 10
-        n_hidden = 128
-        kernel_size = 5
-        mode = 'mol'
-
-        model = _WaveRNN(upsample_scales, n_bits, sample_rate, hop_length, n_res_block,
-                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output, mode)
-
-        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
-        mels = torch.rand(n_batch, 1, n_freq, n_time)
-        out = model(x, mels)
-
-        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), 30)
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
diff --git a/torchaudio/models/_wavernn.py b/torchaudio/models/_wavernn.py
index cd2e89a10c..afe70a39b2 100644
--- a/torchaudio/models/_wavernn.py
+++ b/torchaudio/models/_wavernn.py
@@ -50,8 +50,8 @@ class _MelResNet(nn.Module):
     Args:
         n_res_block: the number of ResBlock in stack (default=10)
         n_freq: the number of bins in a spectrogram (default=128)
-        n_hidden: the number of hidden dimensions (default=128)
-        n_output: the number of output dimensions (default=128)
+        n_hidden: the number of hidden dimensions of resblock (default=128)
+        n_output: the number of output dimensions of melresnet (default=128)
         kernel_size: the number of kernel size in the first Conv1d layer (default=5)
 
     Examples
@@ -132,8 +132,8 @@ class _UpsampleNetwork(nn.Module):
         upsample_scales: the list of upsample scales
         n_res_block: the number of ResBlock in stack (default=10)
         n_freq: the number of bins in a spectrogram (default=128)
-        n_hidden: the number of hidden dimensions (default=128)
-        n_output: the number of output dimensions (default=128)
+        n_hidden: the number of hidden dimensions of resblock (default=128)
+        n_output: the number of output dimensions of melresnet (default=128)
         kernel_size: the number of kernel size in the first Conv1d layer (default=5)
 
     Examples
@@ -205,31 +205,28 @@ class _WaveRNN(nn.Module):
 
     Args:
         upsample_scales: the list of upsample scales
-        n_bits: the bits of output waveform
-        sample_rate: the rate of audio dimensions (samples per second)
+        n_classes: the number of output classes
        hop_length: the number of samples between the starts of consecutive frames
         n_res_block: the number of ResBlock in stack (default=10)
         n_rnn: the dimension of RNN layer (default=512)
         n_fc: the dimension of fully connected layer (default=512)
         kernel_size: the number of kernel size in the first Conv1d layer (default=5)
         n_freq: the number of bins in a spectrogram (default=128)
-        n_hidden: the number of hidden dimensions (default=128)
-        n_output: the number of output dimensions (default=128)
-        mode: the mode of waveform in ['waveform', 'mol'] (default='waveform')
+        n_hidden: the number of hidden dimensions of resblock (default=128)
+        n_output: the number of output dimensions of melresnet (default=128)
 
     Example
-        >>> wavernn = _waveRNN(upsample_scales=[5,5,8], n_bits=9, sample_rate=24000, hop_length=200)
+        >>> wavernn = _WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
         >>> waveform, sample_rate = torchaudio.load(file)
         >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
         >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
         >>> output = wavernn(waveform, specgram)
-        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
     """
 
     def __init__(self,
                  upsample_scales: List[int],
-                 n_bits: int,
-                 sample_rate: int,
+                 n_classes: int,
                  hop_length: int,
                  n_res_block: int = 10,
                  n_rnn: int = 512,
@@ -237,24 +234,14 @@ def __init__(self,
                  kernel_size: int = 5,
                  n_freq: int = 128,
                  n_hidden: int = 128,
-                 n_output: int = 128,
-                 mode: str = 'waveform') -> None:
+                 n_output: int = 128) -> None:
         super().__init__()
 
-        self.mode = mode
         self.kernel_size = kernel_size
-
-        if self.mode == 'waveform':
-            self.n_classes = 2 ** n_bits
-        elif self.mode == 'mol':
-            self.n_classes = 30
-        else:
-            raise ValueError(f"Expected mode: `waveform` or `mol`, but found {self.mode}")
-
         self.n_rnn = n_rnn
         self.n_aux = n_output // 4
         self.hop_length = hop_length
-        self.sample_rate = sample_rate
+        self.n_classes = n_classes
 
         total_scale = 1
         for upsample_scale in upsample_scales:
@@ -262,7 +249,12 @@ def __init__(self,
         if total_scale != self.hop_length:
             raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
 
-        self.upsample = _UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
+        self.upsample = _UpsampleNetwork(upsample_scales,
+                                         n_res_block,
+                                         n_freq,
+                                         n_hidden,
+                                         n_output,
+                                         kernel_size)
 
         self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
         self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
@@ -283,7 +275,7 @@ def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
             specgram: the input spectrogram to the _WaveRNN layer (n_batch, 1, n_freq, n_time)
 
         Return:
-            Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, 2 ** n_bits)
+            Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
         """
 
         assert waveform.size(1) == 1, 'Require the input channel of waveform is 1'
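
Reviewer note (not part of the patch): the snippet below is a minimal sketch of the new _WaveRNN call signature, using the same shapes as the updated test_waveform. It assumes this patch is applied to a torchaudio source checkout where the internal _WaveRNN class is importable from torchaudio.models._wavernn; the variable names are illustrative only.

import torch
from torchaudio.models._wavernn import _WaveRNN

upsample_scales = [5, 5, 8]   # product must equal hop_length (5 * 5 * 8 == 200)
hop_length = 200
n_classes = 512               # replaces the old n_bits/mode pair (2 ** 9 == 512)
kernel_size = 5               # matches the default used below
n_freq = 100
n_time = 200
n_batch = 2

# Remaining hyperparameters (n_res_block, n_rnn, n_fc, n_hidden) keep their defaults.
model = _WaveRNN(upsample_scales, n_classes, hop_length, n_freq=n_freq, n_output=256)

waveform = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
specgram = torch.rand(n_batch, 1, n_freq, n_time)
out = model(waveform, specgram)

# The last dimension is now n_classes directly, rather than 2 ** n_bits
# ('waveform' mode) or the hard-coded 30 of the removed 'mol' mode.
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)

Passing n_classes directly subsumes both old modes: the previous 'waveform' mode corresponds to n_classes = 2 ** n_bits, and the removed 'mol' mode to n_classes = 30.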