
Conversation

@yangarbiter
Contributor

@yangarbiter yangarbiter commented Jul 6, 2021

Closes #855.

Porting Tacotron2 from Nvidia.

Here are the differences between Nvidia's version and this one:

  • These functions and classes are moved into tacotron2.py instead of separate files:
from tacotron2_common.layers import ConvNorm, LinearNorm
from tacotron2_common.utils import to_gpu, get_mask_from_lengths
  • The forward function of the original Encoder is not torchscriptable due to this line, but removing it does not seem to affect the results, so the current version of Encoder's forward is torchscriptable.

  • New tests added.

The tests are similar to the ones in #855; they check the output dimensions of Encoder, Decoder, and Tacotron2. I've also created separate tests for GPU.

I've also added a TorchScript consistency test for the Decoder. Encoder and Tacotron2 were originally not torchscriptable because they call the flatten_parameters function, which is not torchscriptable (pytorch/pytorch#46375). After checking with @gibiansky, flatten_parameters can be moved to __init__, so Encoder and Tacotron2 are now torchscriptable as well.
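For reference, a minimal sketch of that pattern (a stripped-down hypothetical module, not the actual torchaudio Encoder): calling flatten_parameters once in __init__ keeps forward free of the non-scriptable call.

```python
import torch
from torch import nn, Tensor


class _SketchEncoder(nn.Module):
    """Hypothetical, stripped-down encoder used only to illustrate the change."""

    def __init__(self, encoder_embedding_dim: int = 512) -> None:
        super().__init__()
        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            encoder_embedding_dim // 2,
            batch_first=True,
            bidirectional=True,
        )
        # Moved here from forward(): flatten_parameters is not torchscriptable,
        # but calling it at construction time is fine.
        self.lstm.flatten_parameters()

    def forward(self, x: Tensor) -> Tensor:
        outputs, _ = self.lstm(x)
        return outputs


# The scripted module no longer hits the non-scriptable call:
scripted = torch.jit.script(_SketchEncoder())
```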

CC @vincentqb @mthrok @carolineechen. It would be great if you could leave some comments. Thanks in advance.

@yangarbiter yangarbiter changed the title from "add Tacotron2 model" to "Add Tacotron2 model" Jul 6, 2021
@yangarbiter yangarbiter mentioned this pull request Jul 6, 2021
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch 3 times, most recently from dfe5d71 to 35218f7 Compare July 7, 2021 05:04
@yangarbiter yangarbiter requested a review from vincentqb July 7, 2021 05:05
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch 7 times, most recently from 470d0d9 to 4e19648 Compare July 7, 2021 22:14
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from 4e19648 to e129b8b Compare July 7, 2021 22:20
Contributor

@mthrok mthrok left a comment

Had a quick review.

'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
'_Tacotron2',
Contributor

Why export the class suffixed with an underscore?

Contributor Author

Previously I used the underscore to mark it as a prototype.
I've moved Tacotron2 to the torchaudio.prototype directory to avoid confusion.
Please let me know if this is the right place to put it at this stage.
Thanks.



__all__ = [
"_Tacotron2",
Contributor

It looks strange to have the class marked private in __all__.

Contributor Author

For now, we want to keep it private because we only have the model in this PR and no training/inference pipeline yet. After several more PRs, when we have everything merged, we will change Tacotron2 from a private to a public class.

Contributor Author

@yangarbiter yangarbiter Jul 8, 2021

Same as above (#1621 (comment)).
I've moved Tacotron2 to the torchaudio.prototype directory to avoid confusion.
Please let me know if this is the right place to put it at this stage.
Thanks.

assert out.size() == (n_batch, n_seq, encoder_embedding_dim)

@parameterized.expand([(32, 64, 512, torch.float32, torch.device('cpu'))])
def test_cpu_encoder_output(self, n_batch, n_seq, encoder_embedding_dim, dtype, device):
Contributor

Please write a short docstring describing what is "intended" in this test.
For example: "A should meet condition B".

Running pytest --collect-only will give this one-liner as the summary of the tests.

Contributor Author

@yangarbiter yangarbiter Jul 8, 2021

I've added more comments for each of the test functions here, but I am not sure why they didn't show up in my pytest --collect-only output.

Contributor Author

From an offline chat with @mthrok: pytest --collect-only -v will do the job.
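To illustrate the point being discussed, a hedged, self-contained sketch (the class and test names are made up, and nn.Identity stands in for the real model): the one-line docstring is what the verbose collection run surfaces.

```python
import unittest

import torch
from torch import nn
from parameterized import parameterized


class TestEncoderOutput(unittest.TestCase):
    @parameterized.expand([(32, 64, 512, torch.float32, torch.device("cpu"))])
    def test_cpu_encoder_output(self, n_batch, n_seq, encoder_embedding_dim, dtype, device):
        """Encoder should produce an output of shape (n_batch, n_seq, encoder_embedding_dim)."""
        # nn.Identity stands in for the real _Encoder; the one-line docstring above is the point.
        model = nn.Identity().to(device=device, dtype=dtype).eval()
        x = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=dtype, device=device)
        out = model(x)
        self.assertEqual(out.size(), (n_batch, n_seq, encoder_embedding_dim))
```

As reported in this thread, running pytest --collect-only -v then lists each collected test together with such one-liner summaries.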


@parameterized.expand([(32, 64, 512, torch.float32, torch.device('cuda'))])
@skipIfNoCuda
def test_gpu_encoder_output(self, n_batch, n_seq, encoder_embedding_dim, dtype, device):
Contributor

Please split CPU tests and CUDA tests into separate files. Currently we do not run CUDA tests on FB infra, and skipping them there will trigger unnecessary alerts.
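One common way to do this split, sketched here with made-up file and class names (not necessarily the layout used in torchaudio), is to keep the device-agnostic test bodies in a shared "impl" mixin and expose thin per-device test files:

```python
# --- tacotron2_test_impl.py (hypothetical shared module) ---
import unittest

import torch


class Tacotron2TestImpl:
    """Device-agnostic test bodies; subclasses pin the device."""

    device = torch.device("cpu")

    def test_encoder_output(self):
        """Encoder should produce an output of shape (n_batch, n_seq, encoder_embedding_dim)."""
        ...


# --- tacotron2_cpu_test.py (hypothetical): always collected ---
class Tacotron2CPUTest(Tacotron2TestImpl, unittest.TestCase):
    device = torch.device("cpu")


# --- tacotron2_cuda_test.py (hypothetical) ---
# Living in its own file means CPU-only infrastructure can simply not run this
# file, instead of collecting and skipping the tests.
class Tacotron2CUDATest(Tacotron2TestImpl, unittest.TestCase):
    device = torch.device("cuda")
```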

Contributor Author

@yangarbiter yangarbiter Jul 8, 2021

n_seq,
device,
dtype):
"""Validate the output dimensions of a Decoder.
Contributor

Can you replace "validate" with something more descriptive?

What conditions should the output dimensions satisfy?

Contributor Author

@yangarbiter yangarbiter Jul 8, 2021

I've changed it to

        """Feed tensors with specific shape to Tacotron2 Decoder and validate
        that it outputs with a tensor with expected shape.
        """

Please let me know if this is clear enough or whether more detailed comments should be written.
Thanks.

eprint={1712.05884},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Contributor

Looks like docs/source/models.rst is not updated.
But I cannot tell if the Tacotron2 class is public or private.
If it's public, it has to be added to models.rst.

Contributor Author

We plan to make it public after we include the training pipeline in the examples. If this is the case, should I remove this for now?

Contributor

In that case, it's fine. It's referenced from the docstring.
You can update models.rst when you make the model public.

@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch 2 times, most recently from 105f596 to ade6ecc Compare July 8, 2021 17:50
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from ade6ecc to 4c9ca00 Compare July 8, 2021 17:59
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from c6f30a3 to 88242a3 Compare July 8, 2021 19:45
with shape (n_batch, mel_specgram_lengths.max(), text_lengths.max()).
"""

text_lengths, mel_specgram_lengths = text_lengths.data, mel_specgram_lengths.data
Contributor

Why is .data needed here? In general, we should never need to call .data.

Contributor Author

Addressed in the new commit.
Thanks.

Comment on lines 796 to 798
def _parse_output(self,
outputs: Tuple[Tensor, Tensor, Tensor, Tensor],
output_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
Contributor

Since this is a single-purpose function, I'd recommend moving this code in place, with a comment.

Contributor Author

Addressed in the new commit.
Thanks.

Comment on lines 823 to 826
text (Tensor): The input text to Tacotron2 with shape (n_batch, text_lengths.max()).
text_lengths (Tensor): The length of each text with shape (n_batch).
mel_specgram (Tensor): The target mel spectrogram with shape (n_batch, n_mels, mel_specgram_lengths.max()).
mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape (n_batch).
Contributor

Suggested change
- text (Tensor): The input text to Tacotron2 with shape (n_batch, text_lengths.max()).
- text_lengths (Tensor): The length of each text with shape (n_batch).
- mel_specgram (Tensor): The target mel spectrogram with shape (n_batch, n_mels, mel_specgram_lengths.max()).
- mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape (n_batch).
+ text (Tensor): The input text to Tacotron2 with shape (n_batch, max of text_lengths).
+ text_lengths (Tensor): The length of each text with shape (n_batch).
+ mel_specgram (Tensor): The target mel spectrogram with shape (n_batch, n_mel, max of mel_specgram_lengths).
+ mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape (n_batch).

Or follow the convention here, for instance. Thoughts?


class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):

def _get_model(self, encoder_embedding_dim):
Contributor

I would simply place the construction of _Encoder in each method, as this helper method does nothing additional.

However, I do see that both use cases are followed by .to(self.device).eval(); moving that inside _get_model would better justify the helper method (otherwise self is not used, so it could be a plain function as well).
Given that the input tensors are also placed on a specific device/dtype in each method, keeping the device/dtype placement logic in the same function makes the test easier to grasp. So I recommend simply getting rid of the _get_model method.

Contributor Author

Thanks for pointing it out!
It is addressed here


class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):

def _get_model(self, n_mel=80, encoder_embedding_dim=512):
Contributor

This one is more justified than the _Encoder case, but self is not used, so I would turn this into a plain function (or a class method, though it is very rare that class/static methods have to be used).

PyLint would complain about this: http://pylint-messages.wikidot.com/messages:r0201

Contributor Author

Thanks for pointing it out!
It is addressed here.

]


class _LinearNorm(torch.nn.Module):
Contributor

I understand that this is being ported from Nvidia's code, but I wonder if we can change this to a function. After all, what this class does is instantiate a class from PyTorch and initialize it with the Xavier formula. It can simply be replaced with:

def _get_linear_norm(...):
    linear = torch.nn.Linear(...)
    torch.nn.init.xavier_uniform_(linear.weight, ...)
    return linear

Of course, the resulting model structure is slightly different, and if you already have a pre-trained model, then model surgery will be necessary to use it.
But in my opinion, it is more maintainable if there are fewer custom layers.
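For completeness, a hedged sketch of what such a factory function could look like, following the Xavier initialization used in Nvidia's LinearNorm (the argument names are assumptions, not the signature actually merged):

```python
import torch
from torch import nn


def _get_linear_norm(in_dim: int, out_dim: int, bias: bool = True,
                     w_init_gain: str = "linear") -> nn.Linear:
    """Create a Linear layer with Xavier-uniform initialized weights (sketch only)."""
    linear = nn.Linear(in_dim, out_dim, bias=bias)
    nn.init.xavier_uniform_(
        linear.weight, gain=nn.init.calculate_gain(w_init_gain)
    )
    return linear
```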

Contributor Author

I've changed it this way. Please let me know what you think :).

Contributor

Similarly, please do the same for _ConvNorm and replace it with a _get_conv_norm of similar structure.

Contributor Author

Added here

return self.conv(signal)


def get_mask_from_lengths(lengths: Tensor) -> Tensor:
Contributor

Nit: since other non-public interfaces are prefixed with an underscore, this function looks like a public API. (I am guessing it is not.)

Contributor Author

Thanks for pointing it out. It's addressed here
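For illustration, a hedged sketch of the renamed helper (the exact masking convention, True at padded positions here, is an assumption):

```python
import torch
from torch import Tensor


def _get_mask_from_lengths(lengths: Tensor) -> Tensor:
    """Return a (n_batch, max_length) boolean mask that is True at padded positions."""
    max_length = int(torch.max(lengths).item())
    ids = torch.arange(max_length, device=lengths.device)
    return ids >= lengths.unsqueeze(1)


# Example: two sequences of lengths 3 and 5.
mask = _get_mask_from_lengths(torch.tensor([3, 5]))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```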

decoder_input (Tensor): all zeros frames (n_batch, text_lengths.max(), n_mel * n_frames_per_step)
"""

B = memory.size(0)
Contributor

Can we replace B with something more readable, like batch_size?

Contributor Author

Thanks for pointing it out. It's addressed here.

with shape (n_batch, text_lengths.max(), attention_hidden_dim).
"""
B = memory.size(0)
MAX_TIME = memory.size(1)
Contributor

Can we use lowercase max_time?
All-capitalized names look like global constants.

Contributor Author

Thanks for pointing it out. It's addressed here.

@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from c7a6aa9 to 1ec2255 Compare July 12, 2021 21:02
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from 1ec2255 to 68b1273 Compare July 12, 2021 21:15
Comment on lines 273 to 300
self.convolutions.append(
    nn.Sequential(
        _ConvNorm(n_mel, postnet_embedding_dim,
                  kernel_size=postnet_kernel_size, stride=1,
                  padding=int((postnet_kernel_size - 1) / 2),
                  dilation=1, w_init_gain='tanh'),
        nn.BatchNorm1d(postnet_embedding_dim))
)

for _ in range(1, postnet_n_convolution - 1):
    self.convolutions.append(
        nn.Sequential(
            _ConvNorm(postnet_embedding_dim,
                      postnet_embedding_dim,
                      kernel_size=postnet_kernel_size, stride=1,
                      padding=int((postnet_kernel_size - 1) / 2),
                      dilation=1, w_init_gain='tanh'),
            nn.BatchNorm1d(postnet_embedding_dim))
    )

self.convolutions.append(
    nn.Sequential(
        _ConvNorm(postnet_embedding_dim, n_mel,
                  kernel_size=postnet_kernel_size, stride=1,
                  padding=int((postnet_kernel_size - 1) / 2),
                  dilation=1, w_init_gain='linear'),
        nn.BatchNorm1d(n_mel))
)
Contributor

Combining these blocks is more readable.

for i in range(postnet_n_convolution):
    in_channels = n_mel if i == 0 else postnet_embedding_dim
    out_channels = n_mel if i == postnet_n_convolution - 1 else postnet_embedding_dim
    init_gain = 'linear' if i == postnet_n_convolution - 1 else 'tanh'
    self.convolutions.append(
        nn.Sequential(
            _ConvNorm(in_channels, out_channels,
                      kernel_size=postnet_kernel_size, stride=1,
                      padding=int((postnet_kernel_size - 1) / 2),
                      dilation=1, w_init_gain=init_gain),
            nn.BatchNorm1d(out_channels)
        )
    )

I'd also prefer a way of removing the Sequential, but given how it is used below, I'd be OK keeping it as is.

Contributor Author

Thanks. Addressed here

) -> None:
super().__init__()

convolutions = []
Contributor

Please create the nn.ModuleList directly here; that will remove the line where the list is converted to a ModuleList.

Contributor Author

Thanks. Addressed here.
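A minimal sketch of the suggested change, with placeholder channel sizes (the real layer configuration comes from the Postnet code above):

```python
from torch import nn

postnet_n_convolution, n_channels, kernel_size = 5, 512, 5

# Append to an nn.ModuleList directly instead of building a plain list and
# converting it with nn.ModuleList(convolutions) afterwards.
convolutions = nn.ModuleList()
for _ in range(postnet_n_convolution):
    convolutions.append(
        nn.Sequential(
            nn.Conv1d(n_channels, n_channels, kernel_size,
                      padding=(kernel_size - 1) // 2),
            nn.BatchNorm1d(n_channels),
        )
    )
```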

@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch 2 times, most recently from 6290e10 to 06d23a5 Compare July 13, 2021 18:02
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from 06d23a5 to 318b977 Compare July 13, 2021 18:07
Comment on lines 127 to 133
n_batch, n_mel, n_seq, encoder_embedding_dim, n_time_steps = (
    16,
    80,
    200,
    256,
    150,
)
Contributor

nit: I recommend:

Suggested change
- n_batch, n_mel, n_seq, encoder_embedding_dim, n_time_steps = (
-     16,
-     80,
-     200,
-     256,
-     150,
- )
+ n_batch = 16
+ n_mel = 80
+ n_seq = 200
+ encoder_embedding_dim = 256
+ n_time_steps = 150

Contributor Author

Thanks. Addressed here.

@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from 68ef62e to 335eef9 Compare July 13, 2021 18:47
@yangarbiter
Contributor Author

Following @vincentqb's suggestion, I've also run the code through black.

@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch 5 times, most recently from 3ef0820 to 8ca29d6 Compare July 13, 2021 22:59
Contributor

@vincentqb vincentqb left a comment

LGTM, @mthrok do you have any other feedback?

@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from 8ca29d6 to 8dcb175 Compare July 14, 2021 20:02
@yangarbiter yangarbiter force-pushed the port_tacotron2_model branch from a587d67 to 6a3d32d Compare July 19, 2021 18:23

mel_out.sum().backward(retain_graph=True)
mel_out_postnet.sum().backward(retain_graph=True)
gate_outputs.sum().backward()
Contributor

Note: if we decide to land a loss function in the library, then we can use it here to mimic the expected use case more closely.
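For reference, a hedged sketch of what such a loss could look like, following the common Tacotron2 formulation (MSE on the mel outputs before and after the postnet plus BCE on the gate logits); this is only an illustration, not a torchaudio API:

```python
from typing import Tuple

from torch import nn, Tensor


class _Tacotron2Loss(nn.Module):
    """Sketch of the usual Tacotron2 training objective."""

    def forward(self,
                model_outputs: Tuple[Tensor, Tensor, Tensor],
                targets: Tuple[Tensor, Tensor]) -> Tensor:
        mel_out, mel_out_postnet, gate_out = model_outputs
        mel_target, gate_target = targets
        mel_loss = nn.functional.mse_loss(mel_out, mel_target) \
            + nn.functional.mse_loss(mel_out_postnet, mel_target)
        gate_loss = nn.functional.binary_cross_entropy_with_logits(gate_out, gate_target)
        return mel_loss + gate_loss
```

Backpropagating through a single combined scalar like this would replace the three separate .backward() calls in the quoted test.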
