Commit ec4837d

[BC-breaking] Remove unused dimension from pretrained Wav2Vec2 ASR (#1914)
* [BC-breaking] Remove unused dimension from pretrained Wav2Vec2 ASR

  The Wav2Vec2 ASR pretrained weights that originate from fairseq have extra dimensions that have nothing to do with the ASR task
  (https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L18-L37),
  and these dimensions are masked during the loss computation
  (https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L128).
  This change removes them.

* Use '-' for the blank token representation.
1 parent ec12505 · commit ec4837d
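
For context: fairseq's Dictionary puts four special symbols ('<s>', '<pad>', '</s>', '<unk>') at indices 0-3 ahead of the character vocabulary, and in the wav2vec2 CTC setup index 0 ends up serving as the blank while indices 1-3 are masked out of the loss and never predicted. The sketch below is illustrative only (not the library code; the hidden size is made up) and shows how dropping those three rows turns the 32-way output projection into the 29-way projection this commit ships, with the blank at index 0:

import torch

# Illustrative shapes: 32 classes = 4 fairseq special tokens + 28 characters; 768 is a made-up hidden size.
aux_weight = torch.randn(32, 768)
aux_bias = torch.randn(32)

# Rows 1, 2, 3 ('<pad>', '</s>', '<unk>') are never emitted, so drop them.
# Row 0 ('<s>') is kept and reused as the CTC blank ('-' in the new label set).
keep = [i for i in range(aux_weight.size(0)) if i not in (1, 2, 3)]
trimmed_weight = aux_weight[keep]  # shape: (29, 768)
trimmed_bias = aux_bias[keep]      # shape: (29,)

print(trimmed_weight.shape, trimmed_bias.shape)  # torch.Size([29, 768]) torch.Size([29])

The commit below does the same trimming with torch.stack when loading the state dict.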

File tree:

  test/integration_tests/conftest.py
  torchaudio/pipelines/_wav2vec2.py

2 files changed: +30 −29


test/integration_tests/conftest.py

Lines changed: 4 additions & 4 deletions

@@ -4,8 +4,9 @@
 
 
 class GreedyCTCDecoder(torch.nn.Module):
-    def __init__(self, labels):
+    def __init__(self, labels, blank: int = 0):
         super().__init__()
+        self.blank = blank
         self.labels = labels
 
     def forward(self, logits: torch.Tensor) -> str:
@@ -21,9 +22,8 @@ def forward(self, logits: torch.Tensor) -> str:
         best_path = torch.unique_consecutive(best_path, dim=-1)
         hypothesis = []
         for i in best_path:
-            char = self.labels[i]
-            if char not in ['<s>', '<pad>']:
-                hypothesis.append(char)
+            if i != self.blank:
+                hypothesis.append(self.labels[i])
         return ''.join(hypothesis)
 
 
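With the blank token at index 0, the test's greedy decoder no longer needs to special-case fairseq's '<s>' and '<pad>' symbols. A minimal end-to-end usage sketch, assuming the GreedyCTCDecoder defined above is in scope and 'speech.wav' is an illustrative input file:

import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()
labels = bundle.get_labels()  # ('-', '|', 'E', 'T', ...): blank token comes first

waveform, sample_rate = torchaudio.load('speech.wav')  # illustrative input
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

with torch.inference_mode():
    emissions, _ = model(waveform)  # (batch, time, num_labels)

decoder = GreedyCTCDecoder(labels)  # blank defaults to index 0
transcript = decoder(emissions[0])  # characters joined directly, with '|' as the word boundary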

torchaudio/pipelines/_wav2vec2.py

Lines changed: 26 additions & 25 deletions

@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from typing import Dict, Tuple, Any
 
+import torch
 from torch.hub import load_state_dict_from_url
 
 from torchaudio.models import wav2vec2_model, Wav2Vec2Model
@@ -68,6 +69,14 @@ def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model:
         url = f'https://download.pytorch.org/torchaudio/models/{self._path}'
         dl_kwargs = {} if dl_kwargs is None else dl_kwargs
         state_dict = load_state_dict_from_url(url, **dl_kwargs)
+
+        if model.aux is not None:
+            # For ASR task, the parameter originated from fairseq has unrelated dimensions at index 1, 2, 3
+            # It's originated from fairseq but not used, so we remove it here.
+            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+            for key in ['aux.weight', 'aux.bias']:
+                t = state_dict[key]
+                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in (1, 2, 3)])
         model.load_state_dict(state_dict)
         model.eval()
         return model
@@ -102,7 +111,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         >>> # Check the corresponding labels of the output.
         >>> labels = bundle.get_labels()
         >>> print(labels)
-        ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+        ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         >>>
         >>> # Resample audio to the expected sampling rate
         >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
@@ -119,20 +128,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
     def get_labels(
             self,
             *,
-            bos: str = '<s>',
-            pad: str = '<pad>',
-            eos: str = '</s>',
-            unk: str = '<unk>',
+            blank: str = '-',
     ) -> Tuple[str]:
         """The output class labels (only applicable to fine-tuned bundles)
 
-        The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized.
+        The first is blank token, and it is customizable.
 
         Args:
-            bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
-            pad (str, optional): Padding token. (default: ``'<pad>'``)
-            eos (str, optional): End of sentence token. (default: ``'</s>'``)
-            unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
+            blank (str, optional): Blank token. (default: ``'-'``)
 
         Returns:
             Tuple[str]:
@@ -142,11 +145,9 @@ def get_labels(
         Example
             >>> import torchaudio
             >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
-            ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+            ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         """ # noqa: E501
-        if self._labels is None:
-            raise ValueError('Pre-trained models do not have labels.')
-        return (bos, pad, eos, unk, *self._labels)
+        return (blank, *self._labels)
 
 
 def _get_labels():
@@ -252,7 +253,7 @@ def _get_labels():
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -298,7 +299,7 @@ def _get_labels():
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -344,7 +345,7 @@ def _get_labels():
         "encoder_dropout": 0.1,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -433,7 +434,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -479,7 +480,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -525,7 +526,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -614,7 +615,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -660,7 +661,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -706,7 +707,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -932,7 +933,7 @@ def _get_labels():
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -979,7 +980,7 @@ def _get_labels():
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
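
The aux_num_out change from 32 to 29 is consistent with the trimmed label set: one blank plus '|', the apostrophe, and 26 letters. A quick sanity-check sketch (illustrative; in torchaudio's Wav2Vec2Model, aux is the final linear projection):

import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
labels = bundle.get_labels()
assert labels[0] == '-'   # the blank token now leads the tuple
assert len(labels) == 29  # matches the new aux_num_out

model = bundle.get_model()
assert model.aux.out_features == 29  # the three unused fairseq rows are gone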
