diff --git a/examples/pipeline_wav2letter/README.md b/examples/pipeline_wav2letter/README.md
new file mode 100644
index 0000000000..8117d19c91
--- /dev/null
+++ b/examples/pipeline_wav2letter/README.md
@@ -0,0 +1,45 @@
+This is an example pipeline for speech recognition using a greedy CTC decoder, along with the Wav2Letter model trained on LibriSpeech; see [Wav2Letter: an End-to-End ConvNet-based Speech Recognition System](https://arxiv.org/pdf/1609.03193.pdf). Both Wav2Letter and LibriSpeech are available in torchaudio.
+
+### Usage
+
+More information about each command-line parameter is available with the `--help` option. An example invocation is as follows.
+```
+python main.py \
+    --reduce-lr-valid \
+    --dataset-train train-clean-100 train-clean-360 train-other-500 \
+    --dataset-valid dev-clean \
+    --batch-size 128 \
+    --learning-rate .6 \
+    --momentum .8 \
+    --weight-decay .00001 \
+    --clip-grad 0. \
+    --gamma .99 \
+    --hop-length 160 \
+    --win-length 400 \
+    --n-bins 13 \
+    --normalize \
+    --optimizer adadelta \
+    --scheduler reduceonplateau \
+    --epochs 30
+```
+With these parameters, we obtain a character error rate of 13.8% on dev-clean after 30 epochs.
+
+### Output
+
+The information reported at each iteration and epoch (e.g. loss, character error rate, word error rate) is printed to standard output as one JSON object per line, e.g.
+```json
+{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 23317.0, "total chars": 23317.0, "cumulative cer over target length": 1.0, "wer over target length": 1.0, "cumulative wer": 4446.0, "total words": 4446.0, "cumulative wer over target length": 1.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 2453, "dataset length": 128.0, "iteration": 1.0, "loss": 8.712121963500977, "cumulative loss": 8.712121963500977, "average loss": 8.712121963500977, "iteration time": 41.46276903152466, "epoch time": 41.46276903152466}
+{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 46005.0, "total chars": 46005.0, "cumulative cer over target length": 1.0, "wer over target length": 1.0, "cumulative wer": 8762.0, "total words": 8762.0, "cumulative wer over target length": 1.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1703, "dataset length": 256.0, "iteration": 2.0, "loss": 8.918599128723145, "cumulative loss": 17.63072109222412, "average loss": 8.81536054611206, "iteration time": 1.2905676364898682, "epoch time": 42.753336668014526}
+{"name": "train", "epoch": 0, "cer over target length": 1.0, "cumulative cer": 70030.0, "total chars": 70030.0, "cumulative cer over target length": 1.0, "wer over target length": 1.0, "cumulative wer": 13348.0, "total words": 13348.0, "cumulative wer over target length": 1.0, "lr": 0.6, "batch size": 128, "n_channel": 13, "n_time": 1713, "dataset length": 384.0, "iteration": 3.0, "loss": 8.550191879272461, "cumulative loss": 26.180912971496582, "average loss": 8.726970990498861, "iteration time": 1.2109291553497314, "epoch time": 43.96426582336426}
+```
+One way to load the output into Python is to redirect standard output to a file and read it with `pandas.read_json(filename, lines=True)`.
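+For example, a minimal sketch, assuming the output was redirected to a file named `train.log` (the name is arbitrary):
+```python
+import pandas as pd
+
+metrics = pd.read_json("train.log", lines=True)
+
+# each line is one training iteration (name "train") or one validation pass (name "validation")
+train = metrics[metrics["name"] == "train"]
+print(train[["epoch", "average loss", "cer over target length"]].tail())
+```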
+
+### Structure of pipeline
+
+* `main.py` -- the entry point
+* `ctc_decoders.py` -- the greedy CTC decoder
+* `datasets.py` -- the function to split and process LibriSpeech, and a collate factory function
+* `languagemodels.py` -- a class to encode and decode strings
+* `metrics.py` -- the Levenshtein edit distance
+* `transforms.py` -- waveform and feature transforms (normalization, adding a channel dimension)
+* `utils.py` -- functions to log metrics, save checkpoints, and count parameters
diff --git a/examples/pipeline_wav2letter/ctc_decoders.py b/examples/pipeline_wav2letter/ctc_decoders.py
new file mode 100644
index 0000000000..b4f155d6fa
--- /dev/null
+++ b/examples/pipeline_wav2letter/ctc_decoders.py
@@ -0,0 +1,15 @@
+from torch import topk
+
+
+class GreedyDecoder:
+    def __call__(self, outputs):
+        """Greedy decoder. Returns the class index with the highest score at each time step.
+
+        Args:
+            outputs (torch.Tensor): shape (input length, batch size, number of classes (including blank))
+
+        Returns:
+            torch.Tensor: class indices per time step, shape (input length, batch size).
+        """
+        _, indices = topk(outputs, k=1, dim=-1)
+        return indices[..., 0]
diff --git a/examples/pipeline_wav2letter/datasets.py b/examples/pipeline_wav2letter/datasets.py
new file mode 100644
index 0000000000..79b05b2c5b
--- /dev/null
+++ b/examples/pipeline_wav2letter/datasets.py
@@ -0,0 +1,113 @@
+import torch
+from torchaudio.datasets import LIBRISPEECH
+
+
+class MapMemoryCache(torch.utils.data.Dataset):
+    """
+    Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
+    """
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+        self._cache = [None] * len(dataset)
+
+    def __getitem__(self, n):
+        if self._cache[n] is not None:
+            return self._cache[n]
+
+        item = self.dataset[n]
+        self._cache[n] = item
+
+        return item
+
+    def __len__(self):
+        return len(self.dataset)
+
+
+class Processed(torch.utils.data.Dataset):
+    def __init__(self, dataset, transforms, encode):
+        self.dataset = dataset
+        self.transforms = transforms
+        self.encode = encode
+
+    def __getitem__(self, key):
+        item = self.dataset[key]
+        return self.process_datapoint(item)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def process_datapoint(self, item):
+        transformed = item[0]
+        target = item[2].lower()
+
+        transformed = self.transforms(transformed)
+        transformed = transformed[0, ...].transpose(0, -1)
+
+        target = self.encode(target)
+        target = torch.tensor(target, dtype=torch.long, device=transformed.device)
+
+        return transformed, target
+
+
+def split_process_librispeech(
+    datasets, transforms, language_model, root, folder_in_archive,
+):
+    if not isinstance(transforms, list):
+        transforms = [transforms] * len(datasets)
+
+    def create(tags, transform, cache=True):
+
+        if isinstance(tags, str):
+            tags = [tags]
+
+        # Every tag in a group (e.g. all training subsets) uses that
+        # group's transform, so no tag is silently dropped when a group
+        # has more tags than there are transforms.
+        data = torch.utils.data.ConcatDataset(
+            [
+                Processed(
+                    LIBRISPEECH(
+                        root, tag, folder_in_archive=folder_in_archive, download=False,
+                    ),
+                    transform,
+                    language_model.encode,
+                )
+                for tag in tags
+            ]
+        )
+
+        if cache:
+            data = MapMemoryCache(data)
+        return data
+
+    # For performance, we cache all datasets
+    return tuple(
+        create(dataset, transform)
+        for dataset, transform in zip(datasets, transforms)
+    )
+
+
+def collate_factory(model_length_function, transforms=None):
+
+    if transforms is None:
+        transforms = torch.nn.Sequential()
+
+    def collate_fn(batch):
+
+        tensors = [transforms(b[0]) for b in batch if b]
+
+        tensors_lengths = torch.tensor(
+            [model_length_function(t) for t in tensors],
+            dtype=torch.long,
+            device=tensors[0].device,
+        )
+
+        tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
+        tensors = tensors.transpose(1, -1)
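+        # Each item has shape (time, n_features); pad_sequence stacks them
+        # into (batch, max_time, n_features), and the transpose above yields
+        # (batch, n_features, max_time), the channels-first layout Wav2Letter
+        # expects.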
+
+        targets = [b[1] for b in batch if b]
+        target_lengths = torch.tensor(
+            [target.shape[0] for target in targets],
+            dtype=torch.long,
+            device=tensors.device,
+        )
+        targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
+
+        return tensors, targets, tensors_lengths, target_lengths
+
+    return collate_fn
diff --git a/examples/pipeline_wav2letter/languagemodels.py b/examples/pipeline_wav2letter/languagemodels.py
new file mode 100644
index 0000000000..eee018ae69
--- /dev/null
+++ b/examples/pipeline_wav2letter/languagemodels.py
@@ -0,0 +1,38 @@
+import collections
+import itertools
+
+
+class LanguageModel:
+    def __init__(self, labels, char_blank, char_space):
+
+        self.char_space = char_space
+        self.char_blank = char_blank
+
+        labels = list(labels)
+        self.length = len(labels)
+        enumerated = list(enumerate(labels))
+        flipped = [(sub[1], sub[0]) for sub in enumerated]
+
+        d1 = collections.OrderedDict(enumerated)
+        d2 = collections.OrderedDict(flipped)
+        self.mapping = {**d1, **d2}
+
+    def encode(self, iterable):
+        if isinstance(iterable, list):
+            return [self.encode(i) for i in iterable]
+        else:
+            # The blank symbol is expected to be the first label, i.e. index 0.
+            return [self.mapping[i] for i in iterable]
+
+    def decode(self, tensor):
+        if len(tensor) > 0 and isinstance(tensor[0], list):
+            return [self.decode(t) for t in tensor]
+        else:
+            # Not the inverse of encode: as in CTC post-processing, collapse
+            # repeated symbols first, then remove blanks.
+            x = (self.mapping[i] for i in tensor)
+            x = "".join(i for i, _ in itertools.groupby(x))
+            x = x.replace(self.char_blank, "")
+            # x = x.strip()
+            return x
+
+    def __len__(self):
+        return self.length
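+
+
+if __name__ == "__main__":
+    # A small sanity check, mirroring the asserts in metrics.py; not used by
+    # the pipeline. The label set matches the one built in main.py.
+    import string
+
+    lm = LanguageModel("* '" + string.ascii_lowercase, char_blank="*", char_space=" ")
+    assert lm.decode(lm.encode("abc")) == "abc"
+    # decode collapses repeated symbols before removing blanks, so repeated
+    # characters only survive a round trip with a blank between them
+    assert lm.decode(lm.encode("hello")) == "helo"
+    assert lm.decode([lm.mapping["l"], lm.mapping["*"], lm.mapping["l"]]) == "ll"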
help="number of total epochs to run", + ) + parser.add_argument( + "--start-epoch", default=0, type=int, metavar="N", help="manual epoch number" + ) + parser.add_argument( + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency in epochs", + ) + parser.add_argument( + "--reduce-lr-valid", + action="store_true", + help="reduce learning rate based on validation loss", + ) + parser.add_argument( + "--normalize", action="store_true", help="normalize model input" + ) + parser.add_argument( + "--progress-bar", action="store_true", help="use progress bar while training" + ) + parser.add_argument( + "--decoder", + metavar="D", + default="greedy", + choices=["greedy"], + help="decoder to use", + ) + parser.add_argument( + "--batch-size", default=128, type=int, metavar="N", help="mini-batch size" + ) + parser.add_argument( + "--n-bins", + default=13, + type=int, + metavar="N", + help="number of bins in transforms", + ) + parser.add_argument( + "--optimizer", + metavar="OPT", + default="adadelta", + choices=["sgd", "adadelta", "adam", "adamw"], + help="optimizer to use", + ) + parser.add_argument( + "--scheduler", + metavar="S", + default="reduceonplateau", + choices=["exponential", "reduceonplateau"], + help="optimizer to use", + ) + parser.add_argument( + "--learning-rate", + default=0.6, + type=float, + metavar="LR", + help="initial learning rate", + ) + parser.add_argument( + "--gamma", + default=0.99, + type=float, + metavar="GAMMA", + help="learning rate exponential decay constant", + ) + parser.add_argument( + "--momentum", default=0.8, type=float, metavar="M", help="momentum" + ) + parser.add_argument( + "--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay" + ) + parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8) + parser.add_argument("--rho", metavar="RHO", type=float, default=0.95) + parser.add_argument("--clip-grad", metavar="NORM", type=float, default=0.0) + parser.add_argument( + "--dataset-root", + type=str, + help="specify dataset root folder", + ) + parser.add_argument( + "--dataset-folder-in-archive", + type=str, + help="specify dataset folder in archive", + ) + parser.add_argument( + "--dataset-train", + default=["train-clean-100"], + nargs="+", + type=str, + help="select which part of librispeech to train with", + ) + parser.add_argument( + "--dataset-valid", + default=["dev-clean"], + nargs="+", + type=str, + help="select which part of librispeech to validate with", + ) + parser.add_argument( + "--distributed", action="store_true", help="enable DistributedDataParallel" + ) + parser.add_argument("--seed", type=int, default=0, help="random seed") + parser.add_argument( + "--world-size", type=int, default=8, help="the world size to initiate DPP" + ) + parser.add_argument("--jit", action="store_true", help="if used, model is jitted") + + args = parser.parse_args() + logging.info(args) + return args + + +def setup_distributed(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + + +def model_length_function(tensor): + if tensor.shape[1] == 1: + # waveform mode + return int(tensor.shape[0]) // 160 // 2 + 1 + return int(tensor.shape[0]) // 2 + 1 + + +def compute_error_rates(outputs, targets, decoder, language_model, metric): + output = outputs.transpose(0, 1).to("cpu") + output = decoder(output) + + # Compute CER + + output = 
+
+
+def compute_error_rates(outputs, targets, decoder, language_model, metric):
+    output = outputs.transpose(0, 1).to("cpu")
+    output = decoder(output)
+
+    # Compute CER
+
+    output = language_model.decode(output.tolist())
+    target = language_model.decode(targets.tolist())
+
+    print_length = 20
+    for i in range(min(2, len(output))):
+        # Print a few examples
+        output_print = output[i].ljust(print_length)[:print_length]
+        target_print = target[i].ljust(print_length)[:print_length]
+        logging.info("Target: %s Output: %s", target_print, output_print)
+
+    # CER is the character edit distance over the number of target characters,
+    # e.g. target "abc" vs output "abd" is 1 edit over 3 characters
+    cers = [levenshtein_distance(t, o) for t, o in zip(target, output)]
+    cers = sum(cers)
+    n = sum(len(t) for t in target)
+    metric["cer over target length"] = cers / n
+    metric["cumulative cer"] += cers
+    metric["total chars"] += n
+    metric["cumulative cer over target length"] = (
+        metric["cumulative cer"] / metric["total chars"]
+    )
+
+    # Compute WER
+
+    output = [o.split(language_model.char_space) for o in output]
+    target = [t.split(language_model.char_space) for t in target]
+
+    wers = [levenshtein_distance(t, o) for t, o in zip(target, output)]
+    wers = sum(wers)
+    n = sum(len(t) for t in target)
+    metric["wer over target length"] = wers / n
+    metric["cumulative wer"] += wers
+    metric["total words"] += n
+    metric["cumulative wer over target length"] = (
+        metric["cumulative wer"] / metric["total words"]
+    )
+
+
+def train_one_epoch(
+    model,
+    criterion,
+    optimizer,
+    scheduler,
+    data_loader,
+    decoder,
+    language_model,
+    device,
+    epoch,
+    clip_grad,
+    disable_logger=False,
+    reduce_lr_on_plateau=False,
+):
+
+    model.train()
+
+    metric = MetricLogger("train", disable=disable_logger)
+    metric["epoch"] = epoch
+
+    for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
+        data_loader, maxsize=2
+    ):
+
+        start = time()
+        inputs = inputs.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        # keep batch first for data parallel
+        outputs = model(inputs).transpose(-1, -2).transpose(0, 1)
+
+        # CTC
+        # outputs: input length, batch size, number of classes (including blank)
+        # targets: batch size, max target length
+        # input_lengths: batch size
+        # target_lengths: batch size
+
+        loss = criterion(outputs, targets, tensors_lengths, target_lengths)
+
+        optimizer.zero_grad()
+        loss.backward()
+
+        if clip_grad > 0:
+            # item() keeps the logged value JSON-serializable
+            metric["gradient"] = torch.nn.utils.clip_grad_norm_(
+                model.parameters(), clip_grad
+            ).item()
+
+        optimizer.step()
+
+        compute_error_rates(outputs, targets, decoder, language_model, metric)
+
+        try:
+            metric["lr"] = scheduler.get_last_lr()[0]
+        except AttributeError:
+            metric["lr"] = optimizer.param_groups[0]["lr"]
+
+        metric["batch size"] = len(inputs)
+        metric["n_channel"] = inputs.shape[1]
+        metric["n_time"] = inputs.shape[-1]
+        metric["dataset length"] += metric["batch size"]
+        metric["iteration"] += 1
+        metric["loss"] = loss.item()
+        metric["cumulative loss"] += metric["loss"]
+        metric["average loss"] = metric["cumulative loss"] / metric["iteration"]
+        metric["iteration time"] = time() - start
+        metric["epoch time"] += metric["iteration time"]
+        metric()
+
+    if reduce_lr_on_plateau and isinstance(scheduler, ReduceLROnPlateau):
+        scheduler.step(metric["average loss"])
+    elif not isinstance(scheduler, ReduceLROnPlateau):
+        scheduler.step()
+
+
+def evaluate(
+    model,
+    criterion,
+    data_loader,
+    decoder,
+    language_model,
+    device,
+    epoch,
+    disable_logger=False,
+):
+
+    with torch.no_grad():
+
+        model.eval()
+        start = time()
+        metric = MetricLogger("validation", disable=disable_logger)
+        metric["epoch"] = epoch
+
+        for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
+            data_loader, maxsize=2
+        ):
+
+            inputs = inputs.to(device, non_blocking=True)
+            targets = targets.to(device, non_blocking=True)
+
+            # keep batch first for data parallel
+            outputs = model(inputs).transpose(-1, -2).transpose(0, 1)
+
+            # CTC
+            # outputs: input length, batch size, number of classes (including blank)
+            # targets: batch size, max target length
+            # input_lengths: batch size
+            # target_lengths: batch size
+
+            metric["cumulative loss"] += criterion(
+                outputs, targets, tensors_lengths, target_lengths
+            ).item()
+
+            metric["dataset length"] += len(inputs)
+            metric["iteration"] += 1
+
+            compute_error_rates(outputs, targets, decoder, language_model, metric)
+
+        metric["average loss"] = metric["cumulative loss"] / metric["iteration"]
+        metric["validation time"] = time() - start
+        metric()
+
+    return metric["average loss"]
+
+
+def main(rank, args):
+
+    # Distributed setup
+
+    if args.distributed:
+        setup_distributed(rank, args.world_size)
+
+    not_main_rank = args.distributed and rank != 0
+
+    logging.info("Start time: %s", datetime.now())
+
+    # Explicitly set seed to make sure models created in separate processes
+    # start from same random weights and biases
+    torch.manual_seed(args.seed)
+
+    # Empty CUDA cache
+    torch.cuda.empty_cache()
+
+    # Change backend for flac files
+    torchaudio.set_audio_backend("soundfile")
+
+    # Transforms
+
+    melkwargs = {
+        "n_fft": args.win_length,
+        "n_mels": args.n_bins,
+        "hop_length": args.hop_length,
+    }
+
+    sample_rate_original = 16000
+
+    if args.type == "mfcc":
+        transforms = torch.nn.Sequential(
+            torchaudio.transforms.MFCC(
+                sample_rate=sample_rate_original,
+                n_mfcc=args.n_bins,
+                melkwargs=melkwargs,
+            ),
+        )
+        num_features = args.n_bins
+    elif args.type == "waveform":
+        transforms = torch.nn.Sequential(UnsqueezeFirst())
+        num_features = 1
+    else:
+        raise ValueError("Model type not supported")
+
+    if args.normalize:
+        transforms = torch.nn.Sequential(transforms, Normalize())
+
+    augmentations = torch.nn.Sequential()
+    if args.freq_mask:
+        augmentations = torch.nn.Sequential(
+            augmentations,
+            torchaudio.transforms.FrequencyMasking(freq_mask_param=args.freq_mask),
+        )
+    if args.time_mask:
+        augmentations = torch.nn.Sequential(
+            augmentations,
+            torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
+        )
+
+    # Text preprocessing
+
+    char_blank = "*"
+    char_space = " "
+    char_apostrophe = "'"
+    labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
+    language_model = LanguageModel(labels, char_blank, char_space)
+
+    # Dataset
+
+    training, validation = split_process_librispeech(
+        [args.dataset_train, args.dataset_valid],
+        [transforms, transforms],
+        language_model,
+        root=args.dataset_root,
+        folder_in_archive=args.dataset_folder_in_archive,
+    )
+
+    # Decoder
+
+    if args.decoder == "greedy":
+        decoder = GreedyDecoder()
+    else:
+        raise ValueError("Selected decoder not supported")
+
+    # Model
+
+    model = Wav2Letter(
+        num_classes=language_model.length,
+        input_type=args.type,
+        num_features=num_features,
+    )
+
+    if args.jit:
+        model = torch.jit.script(model)
+
+    if args.distributed:
+        n = torch.cuda.device_count() // args.world_size
+        devices = list(range(rank * n, (rank + 1) * n))
+        model = model.to(devices[0])
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices)
+    else:
+        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
+        model = model.to(devices[0], non_blocking=True)
+        model = torch.nn.DataParallel(model)
+
+    n = count_parameters(model)
+    logging.info("Number of parameters: %s", n)
+
+    # Optimizer
+
+    if args.optimizer == "adadelta":
+        optimizer = Adadelta(
+            model.parameters(),
+            lr=args.learning_rate,
+            weight_decay=args.weight_decay,
+            eps=args.eps,
+            rho=args.rho,
+        )
+    elif args.optimizer == "sgd":
+        optimizer = SGD(
+            model.parameters(),
+            lr=args.learning_rate,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optimizer == "adam":
+        # Adam and AdamW do not take a momentum argument
+        optimizer = Adam(
+            model.parameters(),
+            lr=args.learning_rate,
+            weight_decay=args.weight_decay,
+        )
+    elif args.optimizer == "adamw":
+        optimizer = AdamW(
+            model.parameters(),
+            lr=args.learning_rate,
+            weight_decay=args.weight_decay,
+        )
+    else:
+        raise ValueError("Selected optimizer not supported")
+
+    if args.scheduler == "exponential":
+        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
+    elif args.scheduler == "reduceonplateau":
+        scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
+    else:
+        raise ValueError("Selected scheduler not supported")
+
+    criterion = torch.nn.CTCLoss(
+        blank=language_model.mapping[char_blank], zero_infinity=False
+    )
+
+    # Data Loader
+
+    collate_fn_train = collate_factory(model_length_function, augmentations)
+    collate_fn_valid = collate_factory(model_length_function)
+
+    loader_training_params = {
+        "num_workers": args.workers,
+        "pin_memory": True,
+        "shuffle": True,
+        "drop_last": True,
+    }
+    loader_validation_params = loader_training_params.copy()
+    loader_validation_params["shuffle"] = False
+
+    loader_training = DataLoader(
+        training,
+        batch_size=args.batch_size,
+        collate_fn=collate_fn_train,
+        **loader_training_params,
+    )
+    loader_validation = DataLoader(
+        validation,
+        batch_size=args.batch_size,
+        collate_fn=collate_fn_valid,
+        **loader_validation_params,
+    )
+
+    # Setup checkpoint
+
+    best_loss = 1.0
+
+    load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint)
+
+    if args.distributed:
+        torch.distributed.barrier()
+
+    if load_checkpoint:
+        logging.info("Checkpoint: loading %s", args.checkpoint)
+        checkpoint = torch.load(args.checkpoint)
+
+        args.start_epoch = checkpoint["epoch"]
+        best_loss = checkpoint["best_loss"]
+
+        model.load_state_dict(checkpoint["state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+
+        logging.info(
+            "Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
+        )
+    else:
+        logging.info("Checkpoint: not found")
+
+    save_checkpoint(
+        {
+            "epoch": args.start_epoch,
+            "state_dict": model.state_dict(),
+            "best_loss": best_loss,
+            "optimizer": optimizer.state_dict(),
+            "scheduler": scheduler.state_dict(),
+        },
+        False,
+        args.checkpoint,
+        not_main_rank,
+    )
+
+    if args.distributed:
+        torch.distributed.barrier()
+
+    torch.autograd.set_detect_anomaly(False)
+
+    for epoch in range(args.start_epoch, args.epochs):
+
+        logging.info("Epoch: %s", epoch)
+
+        train_one_epoch(
+            model,
+            criterion,
+            optimizer,
+            scheduler,
+            loader_training,
+            decoder,
+            language_model,
+            devices[0],
+            epoch,
+            args.clip_grad,
+            not_main_rank,
+            not args.reduce_lr_valid,
+        )
+
+        if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
+
+            loss = evaluate(
+                model,
+                criterion,
+                loader_validation,
+                decoder,
+                language_model,
+                devices[0],
+                epoch,
+                not_main_rank,
+            )
+
+            is_best = loss < best_loss
+            best_loss = min(loss, best_loss)
+            save_checkpoint(
+                {
+                    "epoch": epoch + 1,
+                    "state_dict": model.state_dict(),
+                    "best_loss": best_loss,
+                    "optimizer": optimizer.state_dict(),
+                    "scheduler": scheduler.state_dict(),
+                },
+                is_best,
+                args.checkpoint,
+                not_main_rank,
+            )
+
+            if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
+                scheduler.step(loss)
+
+    logging.info("End time: %s", datetime.now())
+
+    if args.distributed:
+        torch.distributed.destroy_process_group()
+
+
+def spawn_main(main, args):
+    if args.distributed:
+        torch.multiprocessing.spawn(
+            main, args=(args,), nprocs=args.world_size, join=True
+        )
+    else:
+        main(0, args)
+
+
+if __name__ == "__main__":
+
+    logging.basicConfig(level=logging.INFO)
+    args = parse_args()
+    spawn_main(main, args)
diff --git a/examples/pipeline_wav2letter/metrics.py b/examples/pipeline_wav2letter/metrics.py
new file mode 100644
index 0000000000..cba6595016
--- /dev/null
+++ b/examples/pipeline_wav2letter/metrics.py
@@ -0,0 +1,38 @@
+from typing import List, Union
+
+
+def levenshtein_distance(r: Union[str, List[str]], h: Union[str, List[str]]):
+    """
+    Calculate the Levenshtein distance between a reference r and a hypothesis h,
+    given either as strings (character edit distance) or as lists of words
+    (word edit distance).
+    """
+
+    # Initialisation
+    dold = list(range(len(h) + 1))
+    dnew = list(0 for _ in range(len(h) + 1))
+
+    # Computation
+    for i in range(1, len(r) + 1):
+        dnew[0] = i
+        for j in range(1, len(h) + 1):
+            if r[i - 1] == h[j - 1]:
+                dnew[j] = dold[j - 1]
+            else:
+                substitution = dold[j - 1] + 1
+                insertion = dnew[j - 1] + 1
+                deletion = dold[j] + 1
+                dnew[j] = min(substitution, insertion, deletion)
+
+        dnew, dold = dold, dnew
+
+    return dold[-1]
+
+
+if __name__ == "__main__":
+    assert levenshtein_distance("abc", "abc") == 0
+    assert levenshtein_distance("aaa", "aba") == 1
+    assert levenshtein_distance("aba", "aaa") == 1
+    assert levenshtein_distance("aa", "aaa") == 1
+    assert levenshtein_distance("aaa", "aa") == 1
+    assert levenshtein_distance("abc", "bcd") == 2
+    assert levenshtein_distance(["hello", "world"], ["hello", "world", "!"]) == 1
+    assert levenshtein_distance(["hello", "world"], ["world", "hello", "!"]) == 2
diff --git a/examples/pipeline_wav2letter/transforms.py b/examples/pipeline_wav2letter/transforms.py
new file mode 100644
index 0000000000..f1d9115c87
--- /dev/null
+++ b/examples/pipeline_wav2letter/transforms.py
@@ -0,0 +1,11 @@
+import torch
+
+
+class Normalize(torch.nn.Module):
+    def forward(self, tensor):
+        return (tensor - tensor.mean(-1, keepdim=True)) / tensor.std(-1, keepdim=True)
+
+
+class UnsqueezeFirst(torch.nn.Module):
+    def forward(self, tensor):
+        return tensor.unsqueeze(0)
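+
+
+if __name__ == "__main__":
+    # A small sanity check, mirroring the asserts in metrics.py; not used by
+    # the pipeline. A LibriSpeech waveform arrives as (1, time).
+    waveform = torch.randn(1, 16000)
+    assert UnsqueezeFirst()(waveform).shape == (1, 1, 16000)
+    normalized = Normalize()(waveform)
+    # zero mean (up to floating-point error) along the time axis
+    assert normalized.mean(-1).abs().max() < 1e-4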
+ """ + + if disable: + return + + if filename == "": + return + + tempfile = filename + ".temp" + + # Remove tempfile in case interuption during the copying from tempfile to filename + if os.path.isfile(tempfile): + os.remove(tempfile) + + torch.save(state, tempfile) + if os.path.isfile(tempfile): + os.rename(tempfile, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + logging.warning("Checkpoint: saved") + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad)