
Commit 9c27422

Example pipeline with wav2letter (#632)
* example pipeline, initial commit. * removing notebook conversion artifacts. * remove extra comments. lint. * addressing some feedback. * main function. * defining args in function. * refactor. * lint. * checkpoint. * clean version to start with. * adding more parameters. * lint. * cleaning full version. * check for not None. * cleaning. * back -l 160 * black. * fix runtime error. * removing some print statements. * add help to command line. add progress bar option. * grouping librispeech-specific transform in subclass. * typo. * fix concatenation. * typo. * black. tqdm. * missing transpose. * renaming variables. * sum cer and wer * clip norm. * second signal handler removed. * cosmetic. * default to no checkpoint. * remove non_blocking. * adadelta works better than sgd. * anomaly detection. * moving dataset to separate file. * lint. * move to separate module: languagemodel, decoder, metric. * flush=True. * renaming decoder. * CTC Decoders. * flush=True. * pass length for viterbi decoder. * progress bar. relative path. * generalize transition matrix to n-gram. progress bar. * choice of decoder. * collate func. * remove signal handling. * adding distributed. * lint. * normalize w/r to length of dataset, and w/r to total number characters. * relative cer/wer. * clip grad parameter. momentum back but not yet used. * Switch to SGD. * choice of optimizer. * scheduler. * move to utils file. * metric log, and utils file. * rename metric_logger. * stderr and stdout. simpler metric logger. * replace by logging. * adding time measurement in metric logger. * fix duplicate name. remove tqdm. keep track of epoch instead and iteration instead. * rename main file. and add readme. * refactor distributed. * swap example and output in readme. * remove time from logger. * check non-empty tensor input. * typo in variable name and log update. * typo. * compute cer/wer in training too. * typo. * add back slurm signal capture to resubmit job. * update levinstein distance. * adding tests for levenstein distance. * record error rate during iteration. * metric logger using setitem. * moving signal break to end of loop and return loss so far. * typo. * add citation. * change default to best run. * adding other experiment with decoders. * remove other decoders than greedy. * Revert "remove other decoders than greedy." This reverts commit fb11437. * changing name of folfder. * remove other decoders, and unused dataset class. * rename functions to align with other pipeline. * pick which parts to train with. * adding specaugment to validation. note that caching prevents randomization from happening in validation. * updating readme. * typo in metric logging. * Revert "typo in metric logging." This reverts commit acac245. * Revert "Revert "typo in metric logging."" This reverts commit 2c80d96. * update metric logger. * simplify metric logger implementation. * use json dumps instead. * group metric together. * move function. * lint. * quick summary of files in folder. * pass clip_grad explictly. * typo in default dataset name. * option to disable logger. * ergonomics for distributed. * reminder about signal handler. * minor refactor of main in main. * replace by not_main_rank. * raising error if parameter not supported. * move model before invoking DDP. * changing log level. using python 2 style string for logging. * dynamic augmentations. * update metric log. batch cer/wer metric. correct typo in time. adding other dimensions in metric. * save learning rate even if function not available. * add type option to model. 
* add adamw. * reduce lr on validation step or training step. * specify hop-length and win-length. * normalize option. * rename parameter. * add dropout and tweak to number of channels. * copy model in pipeline folder for experimentation. * fix scheduler stepping. * fix input_type and num_features. * waveform mode changes shape more. * adding best character error rate with current implementation of model with mfcc. * comment update. * remove signal. remove custom wav2letter model. * remove comment. * simpler import with pandas.
1 parent 95d9f2d commit 9c27422

8 files changed (+983, -0 lines changed)

README
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
This is an example pipeline for speech recognition using a greedy or Viterbi CTC decoder and the Wav2Letter model trained on LibriSpeech; see [Wav2Letter: an End-to-End ConvNet-based Speech Recognition System](https://arxiv.org/pdf/1609.03193.pdf). Both Wav2Letter and LibriSpeech are available in torchaudio.

### Usage

More information about each command-line parameter is available with the `--help` option. An example run can be invoked as follows.
```
python main.py \
    --reduce-lr-valid \
    --dataset-train train-clean-100 train-clean-360 train-other-500 \
    --dataset-valid dev-clean \
    --batch-size 128 \
    --learning-rate .6 \
    --momentum .8 \
    --weight-decay .00001 \
    --clip-grad 0. \
    --gamma .99 \
    --hop-length 160 \
    --n-hidden-channels 2000 \
    --win-length 400 \
    --n-bins 13 \
    --normalize \
    --optimizer adadelta \
    --scheduler reduceonplateau \
    --epochs 30
```

With these default parameters, we get a character error rate of 13.8% on dev-clean after 30 epochs.

### Output

The information reported at each iteration and epoch (e.g. loss, character error rate, word error rate) is printed to standard output as one JSON object per line, e.g.
```python
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 23317.0, "total chars": 23317.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 4446.0, "total words": 4446.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 2453, "dataset length": 128.0, "iteration": 1.0, "loss": 8.712121963500977, "cumulative loss": 8.712121963500977, "average loss": 8.712121963500977, "iteration time": 41.46276903152466, "epoch time": 41.46276903152466}
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 46005.0, "total chars": 46005.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 8762.0, "total words": 8762.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1703, "dataset length": 256.0, "iteration": 2.0, "loss": 8.918599128723145, "cumulative loss": 17.63072109222412, "average loss": 8.81536054611206, "iteration time": 1.2905676364898682, "epoch time": 42.753336668014526}
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 70030.0, "total chars": 70030.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 13348.0, "total words": 13348.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1713, "dataset length": 384.0, "iteration": 3.0, "loss": 8.550191879272461, "cumulative loss": 26.180912971496582, "average loss": 8.726970990498861, "iteration time": 1.2109291553497314, "epoch time": 43.96426582336426}
```

One way to load the output into Python with pandas is to save the standard output to a file and then use `pandas.read_json(filename, lines=True)`.
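For instance, a minimal sketch, assuming the run's standard output was redirected to a file (the file name `train.log` is an arbitrary assumption, not something `main.py` creates):

```python
import pandas as pd

# Assumes stdout was redirected to "train.log" and contains only the JSON lines.
logs = pd.read_json("train.log", lines=True)

# Keep only the training entries and inspect a few columns.
train = logs[logs["name"] == "train"]
print(train[["epoch", "iteration", "average loss", "cer over target length"]].tail())
```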

## Structure of pipeline

* `main.py` -- the entry point
* `ctc_decoders.py` -- the greedy CTC decoder
* `datasets.py` -- the function to split and process LibriSpeech, and a collate factory function
* `languagemodels.py` -- a class to encode and decode strings
* `metrics.py` -- the Levenshtein edit distance
* `utils.py` -- functions to log metrics, save checkpoints, and count parameters

ctc_decoders.py
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
from torch import topk


class GreedyDecoder:
    def __call__(self, outputs):
        """Greedy Decoder. Returns highest probability of class labels for each timestep

        Args:
            outputs (torch.Tensor): shape (input length, batch size, number of classes (including blank))

        Returns:
            torch.Tensor: class labels per time step.
        """
        _, indices = topk(outputs, k=1, dim=-1)
        return indices[..., 0]
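
A minimal usage sketch (not part of the commit); the tensor sizes below are arbitrary and only follow the shape convention described in the docstring:

```python
import torch

# Arbitrary example sizes: 100 time steps, batch of 2, 28 classes
# (e.g. 26 letters + space + blank).
emissions = torch.randn(100, 2, 28)

decoder = GreedyDecoder()
labels = decoder(emissions)

print(labels.shape)  # torch.Size([100, 2]): one class index per time step and batch item
```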

datasets.py
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
import torch
from torchaudio.datasets import LIBRISPEECH


class MapMemoryCache(torch.utils.data.Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        if self._cache[n] is not None:
            return self._cache[n]

        item = self.dataset[n]
        self._cache[n] = item

        return item

    def __len__(self):
        return len(self.dataset)


class Processed(torch.utils.data.Dataset):
    def __init__(self, dataset, transforms, encode):
        self.dataset = dataset
        self.transforms = transforms
        self.encode = encode

    def __getitem__(self, key):
        item = self.dataset[key]
        return self.process_datapoint(item)

    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
        transformed = item[0]
        target = item[2].lower()

        transformed = self.transforms(transformed)
        transformed = transformed[0, ...].transpose(0, -1)

        target = self.encode(target)
        target = torch.tensor(target, dtype=torch.long, device=transformed.device)

        return transformed, target


def split_process_librispeech(
    datasets, transforms, language_model, root, folder_in_archive,
):
    def create(tags, cache=True):

        if isinstance(tags, str):
            tags = [tags]
        if isinstance(transforms, list):
            transform_list = transforms
        else:
            transform_list = [transforms]

        data = torch.utils.data.ConcatDataset(
            [
                Processed(
                    LIBRISPEECH(
                        root, tag, folder_in_archive=folder_in_archive, download=False,
                    ),
                    transform,
                    language_model.encode,
                )
                for tag, transform in zip(tags, transform_list)
            ]
        )

        data = MapMemoryCache(data)
        return data

    # For performance, we cache all datasets
    return tuple(create(dataset) for dataset in datasets)


def collate_factory(model_length_function, transforms=None):

    if transforms is None:
        transforms = torch.nn.Sequential()

    def collate_fn(batch):

        tensors = [transforms(b[0]) for b in batch if b]

        tensors_lengths = torch.tensor(
            [model_length_function(t) for t in tensors],
            dtype=torch.long,
            device=tensors[0].device,
        )

        tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
        tensors = tensors.transpose(1, -1)

        targets = [b[1] for b in batch if b]
        target_lengths = torch.tensor(
            [target.shape[0] for target in targets],
            dtype=torch.long,
            device=tensors.device,
        )
        targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)

        return tensors, targets, tensors_lengths, target_lengths

    return collate_fn
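
A hedged sketch of how these helpers might be wired together with a `DataLoader`; the label set, feature transform, length function, and paths below are illustrative assumptions, not the values used by `main.py`:

```python
import torchaudio
from torch.utils.data import DataLoader

from datasets import collate_factory, split_process_librispeech
from languagemodels import LanguageModel

# Assumed label set and special characters, for illustration only.
language_model = LanguageModel(labels="* abcdefghijklmnopqrstuvwxyz'", char_blank="*", char_space=" ")

# Assumed feature transform; main.py builds its own from command-line flags.
transform = torchaudio.transforms.MFCC(n_mfcc=13)

train, valid = split_process_librispeech(
    ["train-clean-100", "dev-clean"],
    [transform, transform],
    language_model,
    root="./data",  # assumed local path containing LibriSpeech
    folder_in_archive="LibriSpeech",
)

# The length function maps a (time, features) tensor to the model's output
# length; the identity on the time axis is used here purely for illustration.
collate_fn = collate_factory(model_length_function=lambda t: t.shape[0])

loader = DataLoader(train, batch_size=128, collate_fn=collate_fn, shuffle=True)

features, targets, input_lengths, target_lengths = next(iter(loader))
```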

languagemodels.py
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
import collections
import itertools


class LanguageModel:
    def __init__(self, labels, char_blank, char_space):

        self.char_space = char_space
        self.char_blank = char_blank

        labels = [l for l in labels]
        self.length = len(labels)
        enumerated = list(enumerate(labels))
        flipped = [(sub[1], sub[0]) for sub in enumerated]

        d1 = collections.OrderedDict(enumerated)
        d2 = collections.OrderedDict(flipped)
        self.mapping = {**d1, **d2}

    def encode(self, iterable):
        if isinstance(iterable, list):
            return [self.encode(i) for i in iterable]
        else:
            return [self.mapping[i] + self.mapping[self.char_blank] for i in iterable]

    def decode(self, tensor):
        if len(tensor) > 0 and isinstance(tensor[0], list):
            return [self.decode(t) for t in tensor]
        else:
            # not idempotent, since clean string
            x = (self.mapping[i] for i in tensor)
            x = "".join(i for i, _ in itertools.groupby(x))
            x = x.replace(self.char_blank, "")
            # x = x.strip()
            return x

    def __len__(self):
        return self.length
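
A small round-trip sketch (the label string and special characters below are assumptions; `main.py` defines its own):

```python
import string

# Assumed character set: blank first, then space, lowercase letters, apostrophe.
char_blank = "*"
char_space = " "
labels = char_blank + char_space + string.ascii_lowercase + "'"

language_model = LanguageModel(labels, char_blank, char_space)

encoded = language_model.encode("hello world")  # list of integer indices
decoded = language_model.decode(encoded)

# decode collapses consecutive repeats and strips blanks, so the round trip
# is not exact: "hello world" comes back as "helo world".
print(decoded)
```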
