diff --git a/test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py b/test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py
index b140275844..2ded9896db 100644
--- a/test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py
+++ b/test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py
@@ -118,7 +118,7 @@ def _test_import_finetune(self, original, imported, config):
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
         ref = original.lm_head(x)
-        hyp = imported.encoder.readout(x)
+        hyp = imported.aux(x)
         self.assertEqual(ref, hyp)
         # The whole model without mask
         x = torch.randn(3, 1024)
@@ -195,8 +195,8 @@ def _test_recreate(self, imported, reloaded, config):
         self.assertEqual(ref, hyp)
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
-        ref = imported.encoder.readout(x)
-        hyp = reloaded.encoder.readout(x)
+        ref = imported.aux(x)
+        hyp = reloaded.aux(x)
         self.assertEqual(ref, hyp)
         # The whole model
         x = torch.randn(3, 1024)
@@ -208,7 +208,7 @@ def _test_recreate(self, imported, reloaded, config):
     def test_recreate_pretrain(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded = factory_func(num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
@@ -217,7 +217,7 @@ def test_recreate_pretrain(self, config, factory_func):
     def test_recreate_finetune(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded = factory_func(num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
diff --git a/torchaudio/models/wav2vec2/components.py b/torchaudio/models/wav2vec2/components.py
index 85df1c2e2b..58dcd16333 100644
--- a/torchaudio/models/wav2vec2/components.py
+++ b/torchaudio/models/wav2vec2/components.py
@@ -426,12 +426,10 @@ def __init__(
             self,
             feature_projection: Module,
             transformer: Module,
-            readout: Module,
     ):
         super().__init__()
         self.feature_projection = feature_projection
         self.transformer = transformer
-        self.readout = readout
 
     def _preprocess(
             self,
@@ -458,7 +456,6 @@ def forward(
     ) -> Tensor:
         x, mask = self._preprocess(features, lengths)
         x = self.transformer(x, attention_mask=mask)
-        x = self.readout(x)
         return x
 
     def extract_features(
@@ -561,7 +558,6 @@ def _get_encoder(
         dropout: float,
         layer_norm_first: bool,
         layer_drop: float,
-        num_out: int,
 ) -> Encoder:
     """
     Args:
@@ -720,8 +716,4 @@ def _get_encoder(
         layer_norm_first=not layer_norm_first,
         layer_drop=layer_drop,
     )
-    readout = nn.Linear(
-        in_features=embed_dim,
-        out_features=num_out,
-    )
-    return Encoder(feature_projection, transformer, readout)
+    return Encoder(feature_projection, transformer)
diff --git a/torchaudio/models/wav2vec2/model.py b/torchaudio/models/wav2vec2/model.py
index d45f6dc082..3d896956f9 100644
--- a/torchaudio/models/wav2vec2/model.py
+++ b/torchaudio/models/wav2vec2/model.py
@@ -20,15 +20,20 @@ class Wav2Vec2Model(Module):
         encoder (torch.nn.Module):
             Encoder that converts the audio features into the sequence of probability
             distribution (in negative log-likelihood) over labels.
+
+        aux (torch.nn.Module or None, optional):
+            Auxiliary module. If provided, the output from encoder is passed to this module.
     """
     def __init__(
            self,
            feature_extractor: Module,
            encoder: Module,
+           aux: Optional[Module] = None,
     ):
         super().__init__()
         self.feature_extractor = feature_extractor
         self.encoder = encoder
+        self.aux = aux
 
     @torch.jit.export
     def extract_features(
@@ -89,7 +94,10 @@ def forward(
                 Shape: ``(batch, )``.
         """
         x, lengths = self.feature_extractor(waveforms, lengths)
-        return self.encoder(x, lengths), lengths
+        x = self.encoder(x, lengths)
+        if self.aux is not None:
+            x = self.aux(x)
+        return x, lengths
 
 
 def _get_model(
@@ -108,7 +116,7 @@ def _get_model(
         encoder_dropout: float,
         encoder_layer_norm_first: bool,
         encoder_layer_drop: float,
-        encoder_num_out: int,
+        aux_num_out: int,
 ) -> Wav2Vec2Model:
     if extractor_conv_layer_config is None:
         extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
@@ -129,9 +137,12 @@ def _get_model(
         dropout=encoder_dropout,
         layer_norm_first=encoder_layer_norm_first,
         layer_drop=encoder_layer_drop,
-        num_out=encoder_num_out,
     )
-    return Wav2Vec2Model(feature_extractor, encoder)
+    aux = torch.nn.Linear(
+        in_features=encoder_embed_dim,
+        out_features=aux_num_out,
+    )
+    return Wav2Vec2Model(feature_extractor, encoder, aux)
 
 
 def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
@@ -172,7 +183,7 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.1,
         encoder_layer_norm_first=False,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )
 
 
@@ -214,7 +225,7 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.1,
         encoder_layer_norm_first=False,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )
 
 
@@ -256,5 +267,5 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.0,
         encoder_layer_norm_first=True,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )
diff --git a/torchaudio/models/wav2vec2/utils/import_fairseq.py b/torchaudio/models/wav2vec2/utils/import_fairseq.py
index 64f6389a3a..afbdd68648 100644
--- a/torchaudio/models/wav2vec2/utils/import_fairseq.py
+++ b/torchaudio/models/wav2vec2/utils/import_fairseq.py
@@ -46,7 +46,7 @@ def _parse_config(w2v_model, num_out):
         'encoder_dropout': encoder.layers[0].dropout3.p,
         'encoder_layer_norm_first': encoder.layer_norm_first,
         'encoder_layer_drop': encoder.layerdrop,
-        'encoder_num_out': num_out,
+        'aux_num_out': num_out,
     }
     return config
 
@@ -110,7 +110,7 @@ def _map_key(key):
     match = re.match(r"proj\.(weight|bias)", key)
     # Encoder - Readout layer
     if match:
-        return f"encoder.readout.{match.group(1)}"
+        return f"aux.{match.group(1)}"
     raise ValueError(f'Unexpected key: {key_}')
 
 
diff --git a/torchaudio/models/wav2vec2/utils/import_huggingface.py b/torchaudio/models/wav2vec2/utils/import_huggingface.py
index e9bfa5138b..fd66b1ff02 100644
--- a/torchaudio/models/wav2vec2/utils/import_huggingface.py
+++ b/torchaudio/models/wav2vec2/utils/import_huggingface.py
@@ -26,7 +26,7 @@ def _get_config(cfg):
         'encoder_dropout': cfg.hidden_dropout,
         'encoder_layer_norm_first': cfg.do_stable_layer_norm,
         'encoder_layer_drop': cfg.layerdrop,
-        'encoder_num_out': cfg.vocab_size,
+        'aux_num_out': cfg.vocab_size,
     }
     return config
 
@@ -42,7 +42,7 @@ def _build(config, original):
     imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
     imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
     if original.__class__.__name__ == 'Wav2Vec2ForCTC':
-        imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
+        imported.aux.load_state_dict(original.lm_head.state_dict())
     return imported
 
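
As a quick sanity check of the refactored interface, here is a minimal sketch that builds a base model through the updated factory and exercises the new `aux` readout end to end. It imports from `torchaudio.models.wav2vec2.model` (the module touched above) and assumes `lengths` remains optional in `forward`; the batch shape and the 32-label output size are illustrative values, not part of this patch.

```python
import torch

from torchaudio.models.wav2vec2.model import wav2vec2_base

# `num_out` now sizes the auxiliary linear readout (`model.aux`)
# instead of a readout module buried inside the encoder.
model = wav2vec2_base(num_out=32).eval()

# Dummy batch of raw waveforms: (batch, time). Shapes are illustrative only.
waveforms = torch.randn(3, 1024)

with torch.no_grad():
    logits, lengths = model(waveforms)

print(model.aux)       # Linear(..., out_features=32, ...)
print(logits.shape)    # (batch, frames, num_out); `lengths` is None when not provided
```

The same `model.aux` attribute is what the Hugging Face and fairseq importers now populate (from `lm_head` / `proj` weights), which is what the updated integration tests above assert.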