Add Tacotron2 model #1621
Merged: yangarbiter merged 11 commits into pytorch:master from yangarbiter:port_tacotron2_model on Jul 20, 2021.
Commits (11, all by yangarbiter):
e129b8b  Add Tacotron2 model
4c9ca00  Refactor tacotron2 tests and inherite style
88242a3  Move tacotrom2 to prototype
5fccd01  make the variable name more consistent
ba3dd48  move flatten_parameters to __init__ for torchscriptability
68b1273  fix some coding styles
318b977  refactor some more code
154bb5d  reformat all new files with black
5d7c7fb  fix backward test and some coding style
8dcb175  More style changes
6a3d32d  Change n_mel back to n_mels to be consistent with other functions
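For orientation before the file-by-file diff, here is a minimal usage sketch of the model this PR adds. It is assembled from the constructor arguments and input shapes used in model_test_impl.py below; the import path and signature are taken verbatim from the tests, but this is a prototype API and may change.

```python
# Sketch assembled from the tests in this PR; not an official example.
# Constructor arguments and shapes mirror _get_tacotron2_model() and
# _get_inputs() in model_test_impl.py.
import torch
from torchaudio.prototype.tacotron2 import Tacotron2

model = Tacotron2(
    mask_padding=False,
    n_mels=80,
    n_symbol=148,
    n_frames_per_step=1,
    symbol_embedding_dim=512,
    encoder_embedding_dim=512,
    encoder_n_convolution=3,
    encoder_kernel_size=5,
    decoder_rnn_dim=1024,
    decoder_max_step=2000,
    decoder_dropout=0.1,
    decoder_early_stopping=True,
    attention_rnn_dim=1024,
    attention_hidden_dim=128,
    attention_location_n_filter=32,
    attention_location_kernel_size=31,
    attention_dropout=0.1,
    prenet_dim=256,
    postnet_n_convolution=5,
    postnet_kernel_size=5,
    postnet_embedding_dim=512,
    gate_threshold=0.5,
).eval()

# Teacher-forced forward pass: token IDs plus target mel spectrograms.
n_batch, max_text_length, max_mel_length = 2, 100, 300
text = torch.randint(0, 148, (n_batch, max_text_length), dtype=torch.int32)
text_lengths = torch.full((n_batch,), max_text_length, dtype=torch.int32)
mel_specgram = torch.rand(n_batch, 80, max_mel_length)
mel_lengths = torch.full((n_batch,), max_mel_length, dtype=torch.int32)

mel_out, mel_out_postnet, gate_outputs, alignments = model(
    text, text_lengths, mel_specgram, mel_lengths
)
# mel_out, mel_out_postnet: (n_batch, 80, max_mel_length)
# gate_outputs: (n_batch, max_mel_length)
# alignments: (n_batch, max_mel_length, max_text_length)
```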
test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py (new file, 23 additions):

```python
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .model_test_impl import (
    Tacotron2EncoderTests,
    Tacotron2DecoderTests,
    Tacotron2Tests,
)


class TestTacotron2EncoderFloat32CPU(Tacotron2EncoderTests, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cpu")


class TestTacotron2DecoderFloat32CPU(Tacotron2DecoderTests, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cpu")


class TestTacotron2Float32CPU(Tacotron2Tests, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cpu")
```
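These classes pin the test matrix through dtype and device class attributes consumed by the shared test implementations. A hypothetical float64 variant, not part of this PR, would follow the same pattern:

```python
# Hypothetical addition (not in this PR): a float64 CPU variant reuses the
# shared tests by overriding only the dtype/device class attributes.
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .model_test_impl import Tacotron2Tests


class TestTacotron2Float64CPU(Tacotron2Tests, PytorchTestCase):
    dtype = torch.float64
    device = torch.device("cpu")
```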
test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py (new file, 26 additions):

```python
import torch

from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .model_test_impl import (
    Tacotron2EncoderTests,
    Tacotron2DecoderTests,
    Tacotron2Tests,
)


@skipIfNoCuda
class TestTacotron2EncoderFloat32CUDA(Tacotron2EncoderTests, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cuda")


@skipIfNoCuda
class TestTacotron2DecoderFloat32CUDA(Tacotron2DecoderTests, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cuda")


@skipIfNoCuda
class TestTacotron2Float32CUDA(Tacotron2Tests, PytorchTestCase):
    dtype = torch.float32
    device = torch.device("cuda")
```
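To run just these suites locally, one option is to point pytest at the two new files from the repository root. This is a hypothetical invocation, assuming pytest is installed and works with the torchaudio unittest layout; the GPU file skips itself via @skipIfNoCuda on machines without CUDA.

```python
# Hypothetical local runner (assumes pytest is installed and the checkout
# root is the working directory).
import pytest

pytest.main([
    "test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py",
    "test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py",
    "-v",
])
```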
test/torchaudio_unittest/models/tacotron2/model_test_impl.py (new file, 238 additions):

```python
import torch
from torchaudio.prototype.tacotron2 import Tacotron2, _Encoder, _Decoder
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    TempDirMixin,
)


class TorchscriptConsistencyMixin(TempDirMixin):
    r"""Mixin that provides an assertion for TorchScript consistency."""

    def _assert_torchscript_consistency(self, model, tensors):
        path = self.get_temp_path("func.zip")
        torch.jit.script(model).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        output = model(*tensors)

        torch.random.manual_seed(40)
        ts_output = ts_func(*tensors)

        self.assertEqual(ts_output, output)
```
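Seeding the RNG before both calls matters here: in the reference Tacotron2 design this PR ports, the decoder prenet applies dropout unconditionally, so even an eval-mode forward pass is stochastic, and eager and scripted runs only match under the same seed. As a standalone sketch of the same check, using nn.Dropout as a stand-in stochastic module:

```python
# Standalone sketch of the consistency check above: script a module, run
# the eager and scripted versions under the same seed, compare outputs.
# nn.Dropout (left in its default train mode) stands in for any module
# whose forward pass consumes the global RNG.
import torch

module = torch.nn.Dropout(p=0.5)
scripted = torch.jit.script(module)

x = torch.rand(4, 8)
torch.random.manual_seed(40)
eager_out = module(x)
torch.random.manual_seed(40)
scripted_out = scripted(x)
assert torch.equal(eager_out, scripted_out)
```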
model_test_impl.py (continued):

```python
class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
    def test_tacotron2_torchscript_consistency(self):
        r"""Validate the TorchScript consistency of an Encoder."""
        n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
        model = _Encoder(encoder_embedding_dim=encoder_embedding_dim,
                         encoder_n_convolution=3,
                         encoder_kernel_size=5).to(self.device).eval()

        x = torch.rand(
            n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype
        )
        input_lengths = (
            torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
        )

        self._assert_torchscript_consistency(model, (x, input_lengths))

    def test_encoder_output_shape(self):
        r"""Feed tensors with a specific shape to the Tacotron2 Encoder and
        validate that it outputs a tensor with the expected shape.
        """
        n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
        model = _Encoder(encoder_embedding_dim=encoder_embedding_dim,
                         encoder_n_convolution=3,
                         encoder_kernel_size=5).to(self.device).eval()

        x = torch.rand(
            n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype
        )
        input_lengths = (
            torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
        )
        out = model(x, input_lengths)

        assert out.size() == (n_batch, n_seq, encoder_embedding_dim)


def _get_decoder_model(n_mels=80, encoder_embedding_dim=512):
    model = _Decoder(
        n_mels=n_mels,
        n_frames_per_step=1,
        encoder_embedding_dim=encoder_embedding_dim,
        decoder_rnn_dim=1024,
        decoder_max_step=2000,
        decoder_dropout=0.1,
        decoder_early_stopping=False,
        attention_rnn_dim=1024,
        attention_hidden_dim=128,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=256,
        gate_threshold=0.5,
    )
    return model


class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
    def test_decoder_torchscript_consistency(self):
        r"""Validate the TorchScript consistency of a Decoder."""
        n_batch = 16
        n_mels = 80
        n_seq = 200
        encoder_embedding_dim = 256
        n_time_steps = 150

        model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim)
        model = model.to(self.device).eval()

        memory = torch.rand(
            n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
        )
        decoder_inputs = torch.rand(
            n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device
        )
        memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)

        self._assert_torchscript_consistency(
            model, (memory, decoder_inputs, memory_lengths)
        )

    def test_decoder_output_shape(self):
        r"""Feed tensors with a specific shape to the Tacotron2 Decoder and
        validate that it outputs a tensor with the expected shape.
        """
        n_batch = 16
        n_mels = 80
        n_seq = 200
        encoder_embedding_dim = 256
        n_time_steps = 150

        model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim)
        model = model.to(self.device).eval()

        memory = torch.rand(
            n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
        )
        decoder_inputs = torch.rand(
            n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device
        )
        memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)

        mel_outputs, gate_outputs, alignments = model(
            memory, decoder_inputs, memory_lengths
        )

        assert mel_outputs.size() == (n_batch, n_mels, n_time_steps)
        assert gate_outputs.size() == (n_batch, n_time_steps)
        assert alignments.size() == (n_batch, n_time_steps, n_seq)


def _get_tacotron2_model(n_mels):
    return Tacotron2(
        mask_padding=False,
        n_mels=n_mels,
        n_symbol=148,
        n_frames_per_step=1,
        symbol_embedding_dim=512,
        encoder_embedding_dim=512,
        encoder_n_convolution=3,
        encoder_kernel_size=5,
        decoder_rnn_dim=1024,
        decoder_max_step=2000,
        decoder_dropout=0.1,
        decoder_early_stopping=True,
        attention_rnn_dim=1024,
        attention_hidden_dim=128,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=256,
        postnet_n_convolution=5,
        postnet_kernel_size=5,
        postnet_embedding_dim=512,
        gate_threshold=0.5,
    )


class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
    def _get_inputs(
        self, n_mels, n_batch: int, max_mel_specgram_length: int, max_text_length: int
    ):
        text = torch.randint(
            0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device
        )
        text_lengths = max_text_length * torch.ones(
            (n_batch,), dtype=torch.int32, device=self.device
        )
        mel_specgram = torch.rand(
            n_batch,
            n_mels,
            max_mel_specgram_length,
            dtype=self.dtype,
            device=self.device,
        )
        mel_specgram_lengths = max_mel_specgram_length * torch.ones(
            (n_batch,), dtype=torch.int32, device=self.device
        )
        return text, text_lengths, mel_specgram, mel_specgram_lengths

    def test_tacotron2_torchscript_consistency(self):
        r"""Validate the TorchScript consistency of a Tacotron2."""
        n_batch = 16
        n_mels = 80
        max_mel_specgram_length = 300
        max_text_length = 100

        model = _get_tacotron2_model(n_mels).to(self.device).eval()
        inputs = self._get_inputs(
            n_mels, n_batch, max_mel_specgram_length, max_text_length
        )

        self._assert_torchscript_consistency(model, inputs)

    def test_tacotron2_output_shape(self):
        r"""Feed tensors with a specific shape to Tacotron2 and validate
        that it outputs a tensor with the expected shape.
        """
        n_batch = 16
        n_mels = 80
        max_mel_specgram_length = 300
        max_text_length = 100

        model = _get_tacotron2_model(n_mels).to(self.device).eval()
        inputs = self._get_inputs(
            n_mels, n_batch, max_mel_specgram_length, max_text_length
        )
        mel_out, mel_out_postnet, gate_outputs, alignments = model(*inputs)

        assert mel_out.size() == (n_batch, n_mels, max_mel_specgram_length)
        assert mel_out_postnet.size() == (n_batch, n_mels, max_mel_specgram_length)
        assert gate_outputs.size() == (n_batch, max_mel_specgram_length)
        assert alignments.size() == (n_batch, max_mel_specgram_length, max_text_length)

    def test_tacotron2_backward(self):
        r"""Make sure calling the backward function on Tacotron2's outputs does
        not error out. Following:
        https://github.com/pytorch/vision/blob/23b8760374a5aaed53c6e5fc83a7e83dbe3b85df/test/test_models.py#L255
        """
        n_batch = 16
        n_mels = 80
        max_mel_specgram_length = 300
        max_text_length = 100

        model = _get_tacotron2_model(n_mels).to(self.device)
        inputs = self._get_inputs(
            n_mels, n_batch, max_mel_specgram_length, max_text_length
        )
        mel_out, mel_out_postnet, gate_outputs, _ = model(*inputs)

        mel_out.sum().backward(retain_graph=True)
        mel_out_postnet.sum().backward(retain_graph=True)
        gate_outputs.sum().backward()
```

A reviewer noted on test_tacotron2_backward: if we decide to land a loss function in the library, we can use it here to mimic the expected use case more closely.
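For reference, the loss that comment alludes to is, in standard Tacotron2 training, an MSE term on both mel outputs plus a binary cross entropy on the stop-token gate. A hypothetical sketch, not part of this PR and not torchaudio API:

```python
# Hypothetical sketch of the standard Tacotron2 training loss; the function
# name and argument names are illustrative only.
import torch.nn.functional as F


def tacotron2_loss(mel_out, mel_out_postnet, gate_out, mel_target, gate_target):
    # MSE on the decoder output and on the postnet-refined output.
    mel_loss = F.mse_loss(mel_out, mel_target) + F.mse_loss(
        mel_out_postnet, mel_target
    )
    # gate_out holds per-frame stop-token logits; gate_target is 1.0 at and
    # after the last valid frame, 0.0 elsewhere.
    gate_loss = F.binary_cross_entropy_with_logits(gate_out, gate_target)
    return mel_loss + gate_loss
```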
Review discussion

Reviewer: Looks like docs/source/models.rst is not updated, but I cannot tell if the Tacotron2 class is public or private. If it's public, it has to be added to models.rst.

Author: We plan to make it public after we include the training pipeline in the examples. If this is the case, should I remove this for now?

Reviewer: In that case, it's fine; it's referenced from the docstring. You can update models.rst when you make the model public.