Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
85ca1de
remove python annotation.
vincentqb Sep 15, 2020
919ea3c
feedback on language length and fork.
vincentqb Sep 15, 2020
b9d2ae9
factor into main-engine.
vincentqb Sep 15, 2020
8310c53
Revert "remove signal. remove custom wav2letter model."
vincentqb Sep 15, 2020
f97eb57
Revert "remove other decoders, and unused dataset class."
vincentqb Sep 15, 2020
9e23d2b
missing itertools.
vincentqb Sep 16, 2020
a1f4937
mono. pad_sequence.
vincentqb Sep 16, 2020
4f3c426
simplify language model.
vincentqb Sep 16, 2020
e614748
add docstring.
vincentqb Sep 16, 2020
6f21a44
update constructor.
vincentqb Sep 16, 2020
64aa916
rename parameters.
vincentqb Sep 16, 2020
649c9a2
use local wav2letter model.
vincentqb Sep 16, 2020
f176e64
update metric logger: saves some information.
vincentqb Sep 16, 2020
60ca4c8
remove random seed. raise error.
vincentqb Sep 16, 2020
c43d721
get optimizer and scheduler in separate function. use CTC sum reducti…
vincentqb Sep 16, 2020
fd59a07
logger now takes care of time.
vincentqb Sep 17, 2020
695391f
model input type rename, function constructor.
vincentqb Sep 17, 2020
d1eed2a
decouple save_checkpoint logic.
vincentqb Sep 17, 2020
e64296f
cosmetic.
vincentqb Sep 17, 2020
a0105ce
update to reflect recent changes.
vincentqb Sep 17, 2020
4c350f3
update error messages.
vincentqb Sep 17, 2020
2f62134
monotonic time.
vincentqb Sep 18, 2020
40db285
lint.
vincentqb Sep 18, 2020
f866e57
remove jit option.
vincentqb Sep 18, 2020
1f4252c
update metric logger. timestamp in logging.
vincentqb Sep 18, 2020
a2cb162
correct name.
vincentqb Sep 20, 2020
565f4ad
correct name.
vincentqb Sep 20, 2020
175cdf6
speech commands. typo in loss.
vincentqb Sep 22, 2020
38923e3
slurm script.
vincentqb Sep 22, 2020
9b41617
typo, reduction, clip_grad none.
vincentqb Oct 1, 2020
cbb02f6
lint. custom model with drouput and num_hiddent_channel is caussing i…
vincentqb Oct 2, 2020
374e08d
moving distributed parameters to command line. adding gloo backend fo…
vincentqb Oct 2, 2020
b83d638
rename parameter.
vincentqb Oct 2, 2020
6033926
add back iteration.
vincentqb Oct 6, 2020
20536d8
use all dataset for new training.
vincentqb Oct 13, 2020
154dc6c
update parameter in readme. add distributed.
vincentqb Oct 13, 2020
462fa20
remove signal. remove custom wav2letter model.
vincentqb Sep 16, 2020
91344b0
Revert "Revert "remove other decoders, and unused dataset class.""
vincentqb Sep 16, 2020
b9325fe
Revert "speech commands. typo in loss."
vincentqb Sep 22, 2020
2d527f9
Revert "slurm script."
vincentqb Sep 22, 2020
d7c30bb
remove iterablememorycache.
vincentqb Oct 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions examples/pipeline_wav2letter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,80 @@ This is an example pipeline for speech recognition using a greedy or Viterbi CTC
More information about each command line parameter is available with the `--help` option. An example can be invoked as follows.
```
python main.py \
--reduce-lr-valid \
--dataset-train train-clean-100 train-clean-360 train-other-500 \
--dataset-valid dev-clean \
--batch-size 128 \
--learning-rate .6 \
--momentum .8 \
--weight-decay .00001 \
--clip-grad 0. \
--gamma .99 \
--hop-length 160 \
--n-hidden-channels 2000 \
--hidden-channels 2000 \
--win-length 400 \
--n-bins 13 \
--bins 13 \
--normalize \
--optimizer adadelta \
--scheduler reduceonplateau \
--epochs 30
--reduce-lr-valid \
--optimizer adadelta \
--learning-rate .6 \
--momentum .8 \
--weight-decay .00001 \
--clip-grad 0. \
--max-epoch 40
```
With these default parameters, we get a character error rate of 13.8% on dev-clean after 30 epochs.

### Output

The information reported at each iteration and epoch (e.g. loss, character error rate, word error rate) is printed to standard output as one JSON object per line, e.g.
```python
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 23317.0, "total chars": 23317.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 4446.0, "total words": 4446.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 2453, "dataset length": 128.0, "iteration": 1.0, "loss": 8.712121963500977, "cumulative loss": 8.712121963500977, "average loss": 8.712121963500977, "iteration time": 41.46276903152466, "epoch time": 41.46276903152466}
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 46005.0, "total chars": 46005.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 8762.0, "total words": 8762.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1703, "dataset length": 256.0, "iteration": 2.0, "loss": 8.918599128723145, "cumulative loss": 17.63072109222412, "average loss": 8.81536054611206, "iteration time": 1.2905676364898682, "epoch time": 42.753336668014526}
{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 70030.0, "total chars": 70030.0, "cer": 0.0, "cumulative cer over target length": 0.0, "wer over target length": 1.0, "cumulative wer": 13348.0, "total words": 13348.0, "wer": 0.0, "cumulative wer over target length": 0.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1713, "dataset length": 384.0, "iteration": 3.0, "loss": 8.550191879272461, "cumulative loss": 26.180912971496582, "average loss": 8.726970990498861, "iteration time": 1.2109291553497314, "epoch time": 43.96426582336426}
```
{"name": "train", "elapsed time": 26.613844830542803, "iteration": 1, "epoch": 0, "batch size": 128, "cumulative batch size": 128.0, "cumulative loss": 199767.09375, "epoch loss": 1560.680419921875, "batch loss": 1560.680419921875, "total chars": 22797.0, "cumulative char errors": 22670.0, "batch cer": 0.9944290915471334, "epoch cer": 0.9944290915471334, "total words": 4349.0, "cumulative word errors": 4320.0, "batch wer": 0.9933318004138882, "epoch wer": 0.9933318004138882, "lr": 0.6, "channel size": 13, "time size": 1669}
{"name": "train", "elapsed time": 14944777.921057917, "iteration": 2, "epoch": 0, "batch size": 128, "cumulative batch size": 256.0, "cumulative loss": 396356.84375, "epoch loss": 1548.2689208984375, "batch loss": 1535.857421875, "total chars": 46234.0, "cumulative char errors": 46107.0, "batch cer": 1.0, "epoch cer": 0.9972531037764416, "total words": 8805.0, "cumulative word errors": 8776.0, "batch wer": 1.0, "epoch wer": 0.9967064168086315, "lr": 0.6, "channel size": 13, "time size": 1677}
{"name": "train", "elapsed time": 21.74324156017974, "iteration": 1, "epoch": 0, "batch size": 128, "cumulative batch size": 128.0, "cumulative loss": 205090.8125, "epoch loss": 1602.27197265625, "batch loss": 1602.27197265625, "total chars": 23428.0, "cumulative char errors": 23300.0, "batch cer": 0.994536452108588, "epoch cer": 0.994536452108588, "total words": 4455.0, "cumulative word errors": 4414.0, "batch wer": 0.9907968574635241, "epoch wer": 0.9907968574635241, "lr": 0.6, "channel size": 13, "time size": 1697}
{"name": "train", "elapsed time": 1572825.2700412387, "iteration": 2, "epoch": 0, "batch size": 128, "cumulative batch size": 256.0, "cumulative loss": 402968.03125, "epoch loss": 1574.0938720703125, "batch loss": 1545.915771484375, "total chars": 46845.0, "cumulative char errors": 46717.0, "batch cer": 1.0, "epoch cer": 0.9972675845874693, "total words": 8887.0, "cumulative word errors": 8846.0, "batch wer": 1.0, "epoch wer": 0.9953865196354226, "lr": 0.6, "channel size": 13, "time size": 1652}
{"name": "train", "elapsed time": 36.98494444228709, "iteration": 3, "epoch": 0, "batch size": 128, "cumulative batch size": 384.0, "cumulative loss": 568859.890625, "epoch loss": 1481.4059651692708, "batch loss": 1347.6800537109375, "total chars": 69573.0, "cumulative char errors": 69446.0, "batch cer": 1.0, "epoch cer": 0.9981745792189499, "total words": 13265.0, "cumulative word errors": 13236.0, "batch wer": 1.0, "epoch wer": 0.9978137957029778, "lr": 0.6, "channel size": 13, "time size": 1690}
...
{"name": "validation", "elapsed time": 141.4699967801571, "iteration": 1, "epoch": 0, "batch size": 15, "cumulative batch size": 2703.0, "cumulative loss": 783672.287109375, "epoch loss": 289.9268542764983, "batch loss": 378.550390625, "total chars": 282969.0, "cumulative char errors": 198943.0, "batch cer": 0.6935483870967742, "epoch cer": 0.7030558117673668, "total words": 54402.0, "cumulative word errors": 63139.0, "batch wer": 1.2208121827411167, "epoch wer": 1.1606007132090732}
...
```
One way to import the output in python with pandas is by saving the standard output to a file, and then using `pandas.read_json(filename, lines=True)`.

## Structure of pipeline

* `main.py` -- the entry point
* `ctc_decoders.py` -- the greedy CTC decoder
* `datasets.py` -- the function to split and process librispeech, a collate factory function
* `languagemodels.py` -- a class to encode and decode strings
* `metrics.py` -- the levenshtein edit distance
* `main.py` -- command line entry point
* `engine.py` -- preprocessing, training, and validation code
* `ctc_decoders.py` -- greedy CTC decoder
* `datasets.py` -- function to split and process librispeech, a collate factory function
* `languagemodels.py` -- class to encode and decode strings
* `metrics.py` -- levenshtein edit distance
* `utils.py` -- functions to log metrics, save checkpoint, and count parameters

## Distributed

The option `--distributed` enables distributed mode. For example with SLURM, one could use the following file.
```
#SBATCH --job-name=torchaudiomodel
#SBATCH --open-mode=append
#SBATCH --nodes=2
#SBATCH --gres=gpu:8

export MASTER_ADDR=${SLURM_JOB_NODELIST:0:9}${SLURM_JOB_NODELIST:10:4}
export MASTER_PORT=29500

python main.py \
--dataset-train train-clean-100 train-clean-360 train-other-500 \
--dataset-valid dev-clean \
--batch-size 128 \
--hop-length 160 \
--hidden-channels 2000 \
--win-length 400 \
--bins 13 \
--normalize \
--scheduler reduceonplateau \
--reduce-lr-valid \
--optimizer adadelta \
--learning-rate .6 \
--momentum .8 \
--weight-decay .00001 \
--clip-grad 0. \
--max-epoch 40 \
--distributed \
--distributed-master-addr ${MASTER_ADDR} \
--distributed-master-port ${MASTER_PORT}
```
84 changes: 71 additions & 13 deletions examples/pipeline_wav2letter/datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,61 @@
import itertools
from typing import List

import torch
from torch import Tensor
from torchaudio.datasets import LIBRISPEECH


def pad_sequence(sequences, padding_value=0.0):
    # type: (List[Tensor], float) -> Tensor
    r"""Pad a list of variable length Tensors along their last dimension.

    ``pad_sequence`` stacks a list of Tensors along a new leading dimension
    and pads them with ``padding_value`` to equal length. If the input is a
    list of sequences with size ``* x L``, then the output has size
    ``B x * x T``.

    `B` is the batch size. It is equal to the number of elements in ``sequences``.
    `T` is the length of the longest sequence.
    `L` is the length of an individual sequence.
    `*` is any number of leading dimensions, including none.

    Unlike ``torch.nn.utils.rnn.pad_sequence``, the variable length dimension
    is the last one, and the batch dimension always comes first.

    Example:
        >>> a = torch.ones(300, 25)
        >>> b = torch.ones(300, 22)
        >>> c = torch.ones(300, 15)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([3, 300, 25])

    Note:
        This function assumes that the leading dimensions and the type of
        all the Tensors in ``sequences`` are the same; both are taken from
        ``sequences[0]``.

    Arguments:
        sequences (list[Tensor]): non-empty list of variable length sequences.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``B x * x T``

    Raises:
        IndexError: if ``sequences`` is empty.
    """
    # TODO(vincentqb): benchmark against transposing each (2-D) item, calling
    # torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
    # and permuting the result back with ``.permute(0, 2, 1)``.

    # The leading dimensions, dtype and device are taken from the first
    # element; all other elements are assumed to agree with it.
    first = sequences[0]
    leading_dims = first.size()[:-1]
    max_len = max(s.size(-1) for s in sequences)
    out_dims = (len(sequences),) + leading_dims + (max_len,)

    out_tensor = first.new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        # Slice assignment copies the values into the output, so the result
        # never aliases any of the input tensors.
        out_tensor[i, ..., :tensor.size(-1)] = tensor

    return out_tensor


class MapMemoryCache(torch.utils.data.Dataset):
"""
Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
Expand All @@ -12,13 +66,9 @@ def __init__(self, dataset):
self._cache = [None] * len(dataset)

def __getitem__(self, n):
if self._cache[n] is not None:
return self._cache[n]

item = self.dataset[n]
self._cache[n] = item

return item
if self._cache[n] is None:
self._cache[n] = self.dataset[n]
return self._cache[n]

def __len__(self):
return len(self.dataset)
Expand All @@ -38,11 +88,16 @@ def __len__(self):
return len(self.dataset)

def process_datapoint(self, item):
"""
Consume a LibriSpeech data point tuple:
(waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id).
- Transforms are applied to waveform. Output tensor shape (freq, time).
- target gets transformed into lower case, and encoded into a one dimensional long tensor.
"""
transformed = item[0]
target = item[2].lower()

transformed = self.transforms(transformed)
transformed = transformed[0, ...].transpose(0, -1)

target = self.encode(target)
target = torch.tensor(target, dtype=torch.long, device=transformed.device)
Expand Down Expand Up @@ -89,24 +144,27 @@ def collate_factory(model_length_function, transforms=None):

def collate_fn(batch):

tensors = [transforms(b[0]) for b in batch if b]
tensors = [transforms(b[0]) for b in batch] # apply transforms to waveforms

tensors_lengths = torch.tensor(
[model_length_function(t) for t in tensors],
dtype=torch.long,
device=tensors[0].device,
)

tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
tensors = tensors.transpose(1, -1)
# tensors = [b.transpose(1, -1) for b in batch]
# tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
# tensors = tensors.transpose(1, -1)
tensors = pad_sequence(tensors)

targets = [b[1] for b in batch if b]
targets = [b[1] for b in batch] # extract target utterance
target_lengths = torch.tensor(
[target.shape[0] for target in targets],
dtype=torch.long,
device=tensors.device,
)
targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
# targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
targets = pad_sequence(targets)

return tensors, targets, tensors_lengths, target_lengths

Expand Down
Loading