11from dataclasses import dataclass
22from typing import Dict , Tuple , Any
33
4+ import torch
45from torch .hub import load_state_dict_from_url
56
67from torchaudio .models import wav2vec2_model , Wav2Vec2Model
@@ -68,6 +69,14 @@ def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model:
6869 url = f'https://download.pytorch.org/torchaudio/models/{ self ._path } '
6970 dl_kwargs = {} if dl_kwargs is None else dl_kwargs
7071 state_dict = load_state_dict_from_url (url , ** dl_kwargs )
72+
73+ if model .aux is not None :
74+ # For ASR task, the parameter originated from fairseq has unrelated dimensions at index 1, 2, 3
75+ # It's originated from fairseq but not used, so we remove it here.
76+ # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
77+ for key in ['aux.weight' , 'aux.bias' ]:
78+ t = state_dict [key ]
79+ state_dict [key ] = torch .stack ([t [i ] for i in range (t .size (0 )) if i not in (1 , 2 , 3 )])
7180 model .load_state_dict (state_dict )
7281 model .eval ()
7382 return model
@@ -102,7 +111,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
102111 >>> # Check the corresponding labels of the output.
103112 >>> labels = bundle.get_labels()
104113 >>> print(labels)
105- ('<s>', '<pad>', '</s>', '<unk> ', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
114+ ('- ', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
106115 >>>
107116 >>> # Resample audio to the expected sampling rate
108117 >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
@@ -119,20 +128,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
119128 def get_labels (
120129 self ,
121130 * ,
122- bos : str = '<s>' ,
123- pad : str = '<pad>' ,
124- eos : str = '</s>' ,
125- unk : str = '<unk>' ,
131+ blank : str = '-' ,
126132 ) -> Tuple [str ]:
127133 """The output class labels (only applicable to fine-tuned bundles)
128134
129- The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized .
135+ The first is blank token, and it is customizable .
130136
131137 Args:
132- bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
133- pad (str, optional): Padding token. (default: ``'<pad>'``)
134- eos (str, optional): End of sentence token. (default: ``'</s>'``)
135- unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
138+ blank (str, optional): Blank token. (default: ``'-'``)
136139
137140 Returns:
138141 Tuple[str]:
@@ -142,11 +145,9 @@ def get_labels(
142145 Example
143146 >>> import torchaudio
144147 >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
145- ('<s>', '<pad>', '</s>', '<unk> ', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
148+ ('- ', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
146149 """ # noqa: E501
147- if self ._labels is None :
148- raise ValueError ('Pre-trained models do not have labels.' )
149- return (bos , pad , eos , unk , * self ._labels )
150+ return (blank , * self ._labels )
150151
151152
152153def _get_labels ():
@@ -252,7 +253,7 @@ def _get_labels():
252253 'encoder_dropout' : 0.1 ,
253254 'encoder_layer_norm_first' : False ,
254255 'encoder_layer_drop' : 0.05 ,
255- "aux_num_out" : 32 ,
256+ "aux_num_out" : 29 ,
256257 },
257258 _labels = _get_labels (),
258259 _sample_rate = 16000 ,
@@ -298,7 +299,7 @@ def _get_labels():
298299 'encoder_dropout' : 0.1 ,
299300 'encoder_layer_norm_first' : False ,
300301 'encoder_layer_drop' : 0.05 ,
301- "aux_num_out" : 32 ,
302+ "aux_num_out" : 29 ,
302303 },
303304 _labels = _get_labels (),
304305 _sample_rate = 16000 ,
@@ -344,7 +345,7 @@ def _get_labels():
344345 "encoder_dropout" : 0.1 ,
345346 "encoder_layer_norm_first" : False ,
346347 "encoder_layer_drop" : 0.05 ,
347- "aux_num_out" : 32 ,
348+ "aux_num_out" : 29 ,
348349 },
349350 _labels = _get_labels (),
350351 _sample_rate = 16000 ,
@@ -433,7 +434,7 @@ def _get_labels():
433434 "encoder_dropout" : 0.0 ,
434435 "encoder_layer_norm_first" : False ,
435436 "encoder_layer_drop" : 0.2 ,
436- "aux_num_out" : 32 ,
437+ "aux_num_out" : 29 ,
437438 },
438439 _labels = _get_labels (),
439440 _sample_rate = 16000 ,
@@ -479,7 +480,7 @@ def _get_labels():
479480 "encoder_dropout" : 0.0 ,
480481 "encoder_layer_norm_first" : False ,
481482 "encoder_layer_drop" : 0.2 ,
482- "aux_num_out" : 32 ,
483+ "aux_num_out" : 29 ,
483484 },
484485 _labels = _get_labels (),
485486 _sample_rate = 16000 ,
@@ -525,7 +526,7 @@ def _get_labels():
525526 "encoder_dropout" : 0.0 ,
526527 "encoder_layer_norm_first" : False ,
527528 "encoder_layer_drop" : 0.2 ,
528- "aux_num_out" : 32 ,
529+ "aux_num_out" : 29 ,
529530 },
530531 _labels = _get_labels (),
531532 _sample_rate = 16000 ,
@@ -614,7 +615,7 @@ def _get_labels():
614615 "encoder_dropout" : 0.0 ,
615616 "encoder_layer_norm_first" : True ,
616617 "encoder_layer_drop" : 0.0 ,
617- "aux_num_out" : 32 ,
618+ "aux_num_out" : 29 ,
618619 },
619620 _labels = _get_labels (),
620621 _sample_rate = 16000 ,
@@ -660,7 +661,7 @@ def _get_labels():
660661 "encoder_dropout" : 0.0 ,
661662 "encoder_layer_norm_first" : True ,
662663 "encoder_layer_drop" : 0.0 ,
663- "aux_num_out" : 32 ,
664+ "aux_num_out" : 29 ,
664665 },
665666 _labels = _get_labels (),
666667 _sample_rate = 16000 ,
@@ -706,7 +707,7 @@ def _get_labels():
706707 "encoder_dropout" : 0.0 ,
707708 "encoder_layer_norm_first" : True ,
708709 "encoder_layer_drop" : 0.0 ,
709- "aux_num_out" : 32 ,
710+ "aux_num_out" : 29 ,
710711 },
711712 _labels = _get_labels (),
712713 _sample_rate = 16000 ,
@@ -932,7 +933,7 @@ def _get_labels():
932933 'encoder_dropout' : 0.0 ,
933934 'encoder_layer_norm_first' : True ,
934935 'encoder_layer_drop' : 0.1 ,
935- 'aux_num_out' : 32 ,
936+ 'aux_num_out' : 29 ,
936937 },
937938 _labels = _get_labels (),
938939 _sample_rate = 16000 ,
@@ -979,7 +980,7 @@ def _get_labels():
979980 'encoder_dropout' : 0.0 ,
980981 'encoder_layer_norm_first' : True ,
981982 'encoder_layer_drop' : 0.1 ,
982- 'aux_num_out' : 32 ,
983+ 'aux_num_out' : 29 ,
983984 },
984985 _labels = _get_labels (),
985986 _sample_rate = 16000 ,
0 commit comments