From 825925a1e7fdb9cad474a0cfb9924e4301b27ff3 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 21 Oct 2021 11:32:07 -0400 Subject: [PATCH 1/2] [BC-breaking] Remove unused dimension from pretrained Wav2Vec2 ASR The Wav2Vec2 ASR pretrained weights originated from fairseq have extra dimention that have nothing to do with the ASR task. This change removes it. --- test/integration_tests/conftest.py | 8 ++--- torchaudio/pipelines/_wav2vec2.py | 51 +++++++++++++++--------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/test/integration_tests/conftest.py b/test/integration_tests/conftest.py index 66adda5a5d..7f6a61f5da 100644 --- a/test/integration_tests/conftest.py +++ b/test/integration_tests/conftest.py @@ -4,8 +4,9 @@ class GreedyCTCDecoder(torch.nn.Module): - def __init__(self, labels): + def __init__(self, labels, blank: int = 0): super().__init__() + self.blank = blank self.labels = labels def forward(self, logits: torch.Tensor) -> str: @@ -21,9 +22,8 @@ def forward(self, logits: torch.Tensor) -> str: best_path = torch.unique_consecutive(best_path, dim=-1) hypothesis = [] for i in best_path: - char = self.labels[i] - if char not in ['', '']: - hypothesis.append(char) + if i != self.blank: + hypothesis.append(self.labels[i]) return ''.join(hypothesis) diff --git a/torchaudio/pipelines/_wav2vec2.py b/torchaudio/pipelines/_wav2vec2.py index 6958608455..6d5a02ddf6 100644 --- a/torchaudio/pipelines/_wav2vec2.py +++ b/torchaudio/pipelines/_wav2vec2.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Dict, Tuple, Any +import torch from torch.hub import load_state_dict_from_url from torchaudio.models import wav2vec2_model, Wav2Vec2Model @@ -68,6 +69,14 @@ def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model: url = f'https://download.pytorch.org/torchaudio/models/{self._path}' dl_kwargs = {} if dl_kwargs is None else dl_kwargs state_dict = load_state_dict_from_url(url, **dl_kwargs) + + if model.aux is not None: + # For ASR task, the parameter originated from fairseq has unrelated dimensions at index 1, 2, 3 + # It's originated from fairseq but not used, so we remove it here. + # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37 + for key in ['aux.weight', 'aux.bias']: + t = state_dict[key] + state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in (1, 2, 3)]) model.load_state_dict(state_dict) model.eval() return model @@ -102,7 +111,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): >>> # Check the corresponding labels of the output. >>> labels = bundle.get_labels() >>> print(labels) - ('', '', '', '', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') + ('', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') >>> >>> # Resample audio to the expected sampling rate >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) @@ -119,20 +128,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): def get_labels( self, *, - bos: str = '', - pad: str = '', - eos: str = '', - unk: str = '', + blank: str = '', ) -> Tuple[str]: """The output class labels (only applicable to fine-tuned bundles) - The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized. + The first is blank token, and it is customizable. Args: - bos (str, optional): Beginning of sentence token. (default: ``''``) - pad (str, optional): Padding token. (default: ``''``) - eos (str, optional): End of sentence token. (default: ``''``) - unk (str, optional): Token for unknown class. (default: ``''``) + blank (str, optional): Blank token. (default: ``''``) Returns: Tuple[str]: @@ -142,11 +145,9 @@ def get_labels( Example >>> import torchaudio >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels() - ('', '', '', '', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') + ('', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') """ # noqa: E501 - if self._labels is None: - raise ValueError('Pre-trained models do not have labels.') - return (bos, pad, eos, unk, *self._labels) + return (blank, *self._labels) def _get_labels(): @@ -252,7 +253,7 @@ def _get_labels(): 'encoder_dropout': 0.1, 'encoder_layer_norm_first': False, 'encoder_layer_drop': 0.05, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -298,7 +299,7 @@ def _get_labels(): 'encoder_dropout': 0.1, 'encoder_layer_norm_first': False, 'encoder_layer_drop': 0.05, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -344,7 +345,7 @@ def _get_labels(): "encoder_dropout": 0.1, "encoder_layer_norm_first": False, "encoder_layer_drop": 0.05, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -433,7 +434,7 @@ def _get_labels(): "encoder_dropout": 0.0, "encoder_layer_norm_first": False, "encoder_layer_drop": 0.2, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -479,7 +480,7 @@ def _get_labels(): "encoder_dropout": 0.0, "encoder_layer_norm_first": False, "encoder_layer_drop": 0.2, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -525,7 +526,7 @@ def _get_labels(): "encoder_dropout": 0.0, "encoder_layer_norm_first": False, "encoder_layer_drop": 0.2, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -614,7 +615,7 @@ def _get_labels(): "encoder_dropout": 0.0, "encoder_layer_norm_first": True, "encoder_layer_drop": 0.0, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -660,7 +661,7 @@ def _get_labels(): "encoder_dropout": 0.0, "encoder_layer_norm_first": True, "encoder_layer_drop": 0.0, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -706,7 +707,7 @@ def _get_labels(): "encoder_dropout": 0.0, "encoder_layer_norm_first": True, "encoder_layer_drop": 0.0, - "aux_num_out": 32, + "aux_num_out": 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -932,7 +933,7 @@ def _get_labels(): 'encoder_dropout': 0.0, 'encoder_layer_norm_first': True, 'encoder_layer_drop': 0.1, - 'aux_num_out': 32, + 'aux_num_out': 29, }, _labels=_get_labels(), _sample_rate=16000, @@ -979,7 +980,7 @@ def _get_labels(): 'encoder_dropout': 0.0, 'encoder_layer_norm_first': True, 'encoder_layer_drop': 0.1, - 'aux_num_out': 32, + 'aux_num_out': 29, }, _labels=_get_labels(), _sample_rate=16000, From a51c39507cd9ea4eb2e996e3756f4681ebc2db33 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 21 Oct 2021 14:25:17 -0400 Subject: [PATCH 2/2] Use '-' for blank --- torchaudio/pipelines/_wav2vec2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchaudio/pipelines/_wav2vec2.py b/torchaudio/pipelines/_wav2vec2.py index 6d5a02ddf6..bde558faa5 100644 --- a/torchaudio/pipelines/_wav2vec2.py +++ b/torchaudio/pipelines/_wav2vec2.py @@ -69,7 +69,7 @@ def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model: url = f'https://download.pytorch.org/torchaudio/models/{self._path}' dl_kwargs = {} if dl_kwargs is None else dl_kwargs state_dict = load_state_dict_from_url(url, **dl_kwargs) - + if model.aux is not None: # For ASR task, the parameter originated from fairseq has unrelated dimensions at index 1, 2, 3 # It's originated from fairseq but not used, so we remove it here. @@ -111,7 +111,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): >>> # Check the corresponding labels of the output. >>> labels = bundle.get_labels() >>> print(labels) - ('', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') + ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') >>> >>> # Resample audio to the expected sampling rate >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) @@ -128,14 +128,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): def get_labels( self, *, - blank: str = '', + blank: str = '-', ) -> Tuple[str]: """The output class labels (only applicable to fine-tuned bundles) The first is blank token, and it is customizable. Args: - blank (str, optional): Blank token. (default: ``''``) + blank (str, optional): Blank token. (default: ``'-'``) Returns: Tuple[str]: @@ -145,7 +145,7 @@ def get_labels( Example >>> import torchaudio >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels() - ('', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') + ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') """ # noqa: E501 return (blank, *self._labels)