8 changes: 4 additions & 4 deletions test/integration_tests/conftest.py

@@ -4,8 +4,9 @@


 class GreedyCTCDecoder(torch.nn.Module):
-    def __init__(self, labels):
+    def __init__(self, labels, blank: int = 0):
         super().__init__()
+        self.blank = blank
         self.labels = labels

     def forward(self, logits: torch.Tensor) -> str:
@@ -21,9 +22,8 @@ def forward(self, logits: torch.Tensor) -> str:
         best_path = torch.unique_consecutive(best_path, dim=-1)
         hypothesis = []
         for i in best_path:
-            char = self.labels[i]
-            if char not in ['<s>', '<pad>']:
-                hypothesis.append(char)
+            if i != self.blank:
+                hypothesis.append(self.labels[i])
         return ''.join(hypothesis)

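The decoder now skips frames by blank index rather than by matching special-token strings, which lines up with the blank-first label sets introduced below. A minimal usage sketch, assuming the `GreedyCTCDecoder` class above is in scope, a fine-tuned bundle such as `torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`, and a placeholder input file `speech.wav`:

```python
import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()

# 'speech.wav' is a placeholder path; resample to the rate the bundle expects.
waveform, sample_rate = torchaudio.load('speech.wav')
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

# The blank token sits at index 0 of the bundle's labels, so the default works.
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
with torch.no_grad():
    emissions, _ = model(waveform)
print(decoder(emissions[0]))
```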
51 changes: 26 additions & 25 deletions torchaudio/pipelines/_wav2vec2.py

@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from typing import Dict, Tuple, Any

+import torch
 from torch.hub import load_state_dict_from_url

 from torchaudio.models import wav2vec2_model, Wav2Vec2Model
@@ -68,6 +69,14 @@ def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model:
         url = f'https://download.pytorch.org/torchaudio/models/{self._path}'
         dl_kwargs = {} if dl_kwargs is None else dl_kwargs
         state_dict = load_state_dict_from_url(url, **dl_kwargs)
+
+        if model.aux is not None:
+            # For the ASR task, checkpoints that originated from fairseq carry unrelated rows
+            # at indices 1, 2 and 3 (fairseq's special tokens); they are never used, so we remove them here.
+            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+            for key in ['aux.weight', 'aux.bias']:
+                t = state_dict[key]
+                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in (1, 2, 3)])
         model.load_state_dict(state_dict)
         model.eval()
         return model
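A toy check (not part of the patch) of what this filtering does to the checkpoint tensors: dropping rows 1, 2 and 3 turns fairseq's 32-row output projection into the 29-row projection the updated configurations below expect. The width 768 is an arbitrary stand-in for the encoder dimension.

```python
import torch

t = torch.randn(32, 768)  # stand-in for a fairseq 'aux.weight'
kept = torch.stack([t[i] for i in range(t.size(0)) if i not in (1, 2, 3)])

assert kept.shape == (29, 768)
assert torch.equal(kept[0], t[0])  # row 0, the blank, is preserved
assert torch.equal(kept[1], t[4])  # former row 4 ('|') now follows the blank
```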
@@ -102,7 +111,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         >>> # Check the corresponding labels of the output.
         >>> labels = bundle.get_labels()
         >>> print(labels)
-        ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+        ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         >>>
         >>> # Resample audio to the expected sampling rate
         >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
@@ -119,20 +128,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
     def get_labels(
             self,
             *,
-            bos: str = '<s>',
-            pad: str = '<pad>',
-            eos: str = '</s>',
-            unk: str = '<unk>',
+            blank: str = '-',
     ) -> Tuple[str]:
         """The output class labels (only applicable to fine-tuned bundles)

-        The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized.
+        The first token is the blank token, and it is customizable.

         Args:
-            bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
-            pad (str, optional): Padding token. (default: ``'<pad>'``)
-            eos (str, optional): End of sentence token. (default: ``'</s>'``)
-            unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
+            blank (str, optional): Blank token. (default: ``'-'``)

         Returns:
             Tuple[str]:
@@ -142,11 +145,9 @@ def get_labels(
         Example
             >>> import torchaudio
             >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
-            ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+            ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         """  # noqa: E501
-        if self._labels is None:
-            raise ValueError('Pre-trained models do not have labels.')
-        return (bos, pad, eos, unk, *self._labels)
+        return (blank, *self._labels)


 def _get_labels():
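With the fairseq tokens gone and index 0 repurposed as the CTC blank, every fine-tuned bundle now exposes 29 classes: the blank plus 28 characters ('|', the 26 letters, and the apostrophe). A quick sanity check, assuming the `HUBERT_ASR_LARGE` bundle from the docstring above:

```python
import torchaudio

labels = torchaudio.models.HUBERT_ASR_LARGE.get_labels(blank='_')
assert labels[0] == '_'   # the blank token is customizable
assert len(labels) == 29  # 1 blank + 28 characters, matching aux_num_out below
```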
@@ -252,7 +253,7 @@ def _get_labels():
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -298,7 +299,7 @@ def _get_labels():
         'encoder_dropout': 0.1,
         'encoder_layer_norm_first': False,
         'encoder_layer_drop': 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -344,7 +345,7 @@ def _get_labels():
         "encoder_dropout": 0.1,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.05,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -433,7 +434,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -479,7 +480,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -525,7 +526,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": False,
         "encoder_layer_drop": 0.2,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -614,7 +615,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -660,7 +661,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -706,7 +707,7 @@ def _get_labels():
         "encoder_dropout": 0.0,
         "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.0,
-        "aux_num_out": 32,
+        "aux_num_out": 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -932,7 +933,7 @@ def _get_labels():
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
@@ -979,7 +980,7 @@ def _get_labels():
         'encoder_dropout': 0.0,
         'encoder_layer_norm_first': True,
         'encoder_layer_drop': 0.1,
-        'aux_num_out': 32,
+        'aux_num_out': 29,
     },
     _labels=_get_labels(),
     _sample_rate=16000,
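Taken together, the eleven configuration updates follow from the label change: the old fairseq head had 32 outputs (4 special tokens plus 28 characters), and removing `<pad>`, `</s>` and `<unk>` while repurposing index 0 as the blank leaves 32 - 3 = 29. The hosted checkpoints still store 32-row `aux` tensors; the row filtering added to `get_model()` above is what reconciles them with `aux_num_out=29` at load time.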