Example pipeline with wav2letter #632
New file: the example's README (`@@ -0,0 +1,45 @@`)
This is an example pipeline for speech recognition using a greedy or Viterbi CTC decoder, along with the Wav2Letter model trained on LibriSpeech; see [Wav2Letter: an End-to-End ConvNet-based Speech Recognition System](https://arxiv.org/pdf/1609.03193.pdf). Wav2Letter and LibriSpeech are both available in torchaudio.

### Usage

More information about each command line parameter is available with the `--help` option. An example can be invoked as follows.
```
python main.py \
    --reduce-lr-valid \
    --dataset-train train-clean-100 train-clean-360 train-other-500 \
    --dataset-valid dev-clean \
    --batch-size 128 \
    --learning-rate .6 \
    --momentum .8 \
    --weight-decay .00001 \
    --clip-grad 0. \
    --gamma .99 \
    --hop-length 160 \
    --n-hidden-channels 2000 \
    --win-length 400 \
    --n-bins 13 \
    --normalize \
    --optimizer adadelta \
    --scheduler reduceonplateau \
    --epochs 30
```
With these default parameters, we get a character error rate of 13.8% on dev-clean after 30 epochs.

### Output

The information reported at each iteration and epoch (e.g. loss, character error rate, word error rate) is printed to standard output in the form of one JSON object per line, e.g.

> **Review:** Isn't it […]
>
> **Reply:** oops :) thx for pointing this out

```
| {"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 23317.0, "total chars": 23317.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 4446.0, "total words": 4446.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 2453, "dataset length": 128.0, "iteration": 1.0, "loss": 8.712121963500977, "cumulative loss": 8.712121963500977, "average loss": 8.712121963500977, "iteration time": 41.46276903152466, "epoch time": 41.46276903152466} | ||
| {"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 46005.0, "total chars": 46005.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 8762.0, "total words": 8762.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1703, "dataset length": 256.0, "iteration": 2.0, "loss": 8.918599128723145, "cumulative loss": 17.63072109222412, "average loss": 8.81536054611206, "iteration time": 1.2905676364898682, "epoch time": 42.753336668014526} | ||
| {"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 70030.0, "total chars": 70030.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 13348.0, "total words": 13348.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1713, "dataset length": 384.0, "iteration": 3.0, "loss": 8.550191879272461, "cumulative loss": 26.180912971496582, "average loss": 8.726970990498861, "iteration time": 1.2109291553497314, "epoch time": 43.96426582336426} | ||
```

One way to load the output into Python with pandas is to save the standard output to a file and then use `pandas.read_json(filename, lines=True)`.
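For instance, assuming the output above was redirected to a file named `train.log` (the filename is illustrative):

```python
import pandas as pd

# Each line of the log is one JSON object, so lines=True parses the whole file.
log = pd.read_json("train.log", lines=True)

# e.g. the running average training loss at the end of each epoch
train = log[log["name"] == "train"]
print(train.groupby("epoch")["average loss"].last())
```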
## Structure of pipeline

* `main.py` -- the entry point
* `ctc_decoders.py` -- the greedy CTC decoder
* `datasets.py` -- the functions to split and process LibriSpeech, and a collate factory function
* `languagemodels.py` -- a class to encode and decode strings
* `metrics.py` -- the Levenshtein edit distance
* `utils.py` -- functions to log metrics, save checkpoints, and count parameters
New file: `ctc_decoders.py` (`@@ -0,0 +1,15 @@`)
```python
from torch import topk


class GreedyDecoder:
    def __call__(self, outputs):
        """Greedy Decoder. Returns highest probability of class labels for each timestep

        Args:
            outputs (torch.Tensor): shape (input length, batch size, number of classes (including blank))

        Returns:
            torch.Tensor: class labels per time step.
        """
        _, indices = topk(outputs, k=1, dim=-1)
        return indices[..., 0]
```

> **Review (on `class GreedyDecoder`):** You could generalize this file to be called "decoders.py" and also fold in things such as compute_error_rates
>
> **Review:** This class is stateless. Can it be a function?
>
> **Reply:** It could be functional, corresponding to a transform, but really it's a step towards our beamsearch work
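For illustration, a quick sketch of how the decoder is called (the shapes here are made up; 29 classes stands in for 28 labels plus blank):

```python
import torch

decoder = GreedyDecoder()

# Fake network output: 100 time steps, batch of 2, 29 classes.
outputs = torch.randn(100, 2, 29)

indices = decoder(outputs)
print(indices.shape)  # torch.Size([100, 2]): one class label per time step
```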
New file: `datasets.py` (`@@ -0,0 +1,113 @@`)
```python
import torch
from torchaudio.datasets import LIBRISPEECH


class MapMemoryCache(torch.utils.data.Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        if self._cache[n] is not None:
            return self._cache[n]

        item = self.dataset[n]
        self._cache[n] = item

        return item

    def __len__(self):
        return len(self.dataset)
```

> **Review (on `__getitem__`):** This can be simplified.
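One possible reading of that simplification (a behavior-preserving sketch, assuming cached items are never `None` themselves):

```python
def __getitem__(self, n):
    # Fill the cache on first access, then always serve from it.
    if self._cache[n] is None:
        self._cache[n] = self.dataset[n]
    return self._cache[n]
```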
```python
class Processed(torch.utils.data.Dataset):
    def __init__(self, dataset, transforms, encode):
        self.dataset = dataset
        self.transforms = transforms
        self.encode = encode

    def __getitem__(self, key):
        item = self.dataset[key]
        return self.process_datapoint(item)

    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
        transformed = item[0]
        target = item[2].lower()

        transformed = self.transforms(transformed)
        transformed = transformed[0, ...].transpose(0, -1)

        target = self.encode(target)
        target = torch.tensor(target, dtype=torch.long, device=transformed.device)

        return transformed, target
```

> **Review (on `process_datapoint`):** This operation is not generic and requires a specific item type, and since it uses index slicing it is very difficult to understand what it does. Please add a docstring.
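A sketch of what the requested docstring could say, inferred from the code above (the item layout is that of torchaudio's LIBRISPEECH, whose items are `(waveform, sample_rate, utterance, ...)`):

```python
def process_datapoint(self, item):
    """Turn one LIBRISPEECH item into a (features, target) pair.

    ``item`` is a LIBRISPEECH tuple whose first element is the waveform,
    shaped (channel, time), and whose third element is the transcript
    string. The waveform is run through ``self.transforms``, the first
    channel is selected and transposed so that time comes first, and the
    lowercased transcript is encoded into a LongTensor of label indices
    on the same device as the features.
    """
    ...
```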
```python
def split_process_librispeech(
    datasets, transforms, language_model, root, folder_in_archive,
):
    def create(tags, cache=True):

        if isinstance(tags, str):
            tags = [tags]
        if isinstance(transforms, list):
            transform_list = transforms
        else:
            transform_list = [transforms]

        data = torch.utils.data.ConcatDataset(
            [
                Processed(
                    LIBRISPEECH(
                        root, tag, folder_in_archive=folder_in_archive, download=False,
                    ),
                    transform,
                    language_model.encode,
                )
                for tag, transform in zip(tags, transform_list)
            ]
        )

        data = MapMemoryCache(data)
        return data

    # For performance, we cache all datasets
    return tuple(create(dataset) for dataset in datasets)
```

> **Review (on the type handling in `create`):** Since this is example code and all the helper functions exist to make the main code simpler, making the helper functions more specific helps with maintainability. Instead of allowing multiple types, it's simpler to allow only one type and do the equivalent type conversion in client code.
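Along those lines, a hypothetical `create` that accepts only lists (callers would do the `str` to `[str]` and single-transform to list conversions themselves; this is a sketch of the suggestion, not the merged code):

```python
def create(tags, transform_list, cache=True):
    # tags and transform_list are assumed to already be lists of equal length.
    data = torch.utils.data.ConcatDataset(
        [
            Processed(
                LIBRISPEECH(
                    root, tag, folder_in_archive=folder_in_archive, download=False,
                ),
                transform,
                language_model.encode,
            )
            for tag, transform in zip(tags, transform_list)
        ]
    )
    return MapMemoryCache(data) if cache else data
```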
```python
def collate_factory(model_length_function, transforms=None):

    if transforms is None:
        transforms = torch.nn.Sequential()

    def collate_fn(batch):

        tensors = [transforms(b[0]) for b in batch if b]

        tensors_lengths = torch.tensor(
            [model_length_function(t) for t in tensors],
            dtype=torch.long,
            device=tensors[0].device,
        )

        tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
        tensors = tensors.transpose(1, -1)

        targets = [b[1] for b in batch if b]
        target_lengths = torch.tensor(
            [target.shape[0] for target in targets],
            dtype=torch.long,
            device=tensors.device,
        )
        targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)

        return tensors, targets, tensors_lengths, target_lengths

    return collate_fn
```

> **Review (on the default `transforms`):** I'm not a fan of the declarative `nn.Sequential` approach and would write a custom function whose pointer I'd pass around, but I can see it being nice to aggregate transforms based on a sequence of decisions.
>
> **Review (on `tensors = [transforms(b[0]) for b in batch if b]`):** It is very difficult to understand what is being transformed here. Why is there a case that one item in a batch (denoted as […])?
>
> **Review (on `pad_sequence`):** A wrapped / generalized version of this could form a useful torchaudio function.
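Stepping back, a sketch of how the factory is wired into a `DataLoader`. The fake dataset and the length function here are stand-ins: the real data comes from `split_process_librispeech`, and the real length function in `main.py` (not shown in this diff) may also account for the model's striding.

```python
import torch

def length_fn(tensor):
    # Stand-in: number of time frames in a (time, n_bins) feature tensor.
    return tensor.shape[0]

# A tiny fake dataset standing in for the processed LibriSpeech datasets:
# each item is (features of shape (time, n_bins), target label indices).
fake_data = [
    (torch.randn(50, 13), torch.tensor([1, 2, 3])),
    (torch.randn(70, 13), torch.tensor([4, 5])),
]

loader = torch.utils.data.DataLoader(
    fake_data, batch_size=2, collate_fn=collate_factory(length_fn)
)

tensors, targets, tensor_lengths, target_lengths = next(iter(loader))
print(tensors.shape)  # torch.Size([2, 13, 70]): (batch, n_bins, padded time)
```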
New file: `languagemodels.py` (`@@ -0,0 +1,38 @@`)
```python
import collections
import itertools


class LanguageModel:
    def __init__(self, labels, char_blank, char_space):

        self.char_space = char_space
        self.char_blank = char_blank

        labels = [l for l in labels]
        self.length = len(labels)

        enumerated = list(enumerate(labels))
        flipped = [(sub[1], sub[0]) for sub in enumerated]

        d1 = collections.OrderedDict(enumerated)
        d2 = collections.OrderedDict(flipped)
        self.mapping = {**d1, **d2}

    def encode(self, iterable):
        if isinstance(iterable, list):
            return [self.encode(i) for i in iterable]
        else:
            return [self.mapping[i] + self.mapping[self.char_blank] for i in iterable]

    def decode(self, tensor):
        if len(tensor) > 0 and isinstance(tensor[0], list):
            return [self.decode(t) for t in tensor]
        else:
            # not idempotent, since clean string
            x = (self.mapping[i] for i in tensor)
            x = "".join(i for i, _ in itertools.groupby(x))
            x = x.replace(self.char_blank, "")
            # x = x.strip()
            return x

    def __len__(self):
        return self.length
```

> **Review (on `labels = [l for l in labels]`):** Cannot be […]?
>
> **Reply:** Good catch. Yes, it's just a string.
>
> **Review (on `self.length`):** Having […]
>
> **Review (on `encode`):** What if I pass an iterable that yields lists? What's the base-case type here? Maybe that's an easier case to branch on. Also, as a very minor nit, I actually like using returns to avoid "else". So you could write […]
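A round-trip sketch under assumed inputs (the actual label set and blank character live in `main.py` and are not part of this diff; putting the blank at index 0 makes the `+ self.mapping[self.char_blank]` offset in `encode` a no-op):

```python
import string

char_blank = "*"
char_space = " "
labels = char_blank + char_space + string.ascii_lowercase

lm = LanguageModel(labels, char_blank, char_space)

encoded = lm.encode("hello")
print(encoded)             # [9, 6, 13, 13, 16]: indices into labels
print(lm.decode(encoded))  # "helo": adjacent repeats collapse, as in CTC decoding
```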
> **Review:** I would write the log to a separate file alongside the saved model; otherwise users have to redirect all the time, which is not very convenient.
>
> **Reply:** i must admit i do like the standard output a lot -- but i can see users preferring writing to a file, so i'll add the option to choose :)
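A sketch of the option the author mentions (the flag name `--log-file` is hypothetical):

```python
import argparse
import json
import sys

parser = argparse.ArgumentParser()
parser.add_argument(
    "--log-file",
    default=None,
    help="write per-iteration JSON lines here instead of standard output",
)
args = parser.parse_args()

log_stream = open(args.log_file, "w") if args.log_file else sys.stdout

# Wherever a metrics record is reported:
record = {"name": "train", "epoch": 0}  # illustrative
print(json.dumps(record), file=log_stream, flush=True)
```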