diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md
new file mode 100644
index 0000000000..865f06c181
--- /dev/null
+++ b/examples/pipeline_wavernn/README.md
@@ -0,0 +1,36 @@
+This is an example vocoder pipeline using the WaveRNN model trained with LJSpeech. The WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was
+introduced in "Efficient Neural Audio Synthesis". Both the WaveRNN model and the LJSpeech dataset are available in torchaudio.
+
+### Usage
+
+An example training run can be invoked as follows.
+```
+python main.py \
+    --batch-size 256 \
+    --learning-rate 1e-4 \
+    --n-freq 80 \
+    --loss 'crossentropy' \
+    --n-bits 8
+```
+
+### Output
+
+The information reported at each iteration and epoch (e.g. loss) is printed to standard output as one JSON object per line. Here is an example Python function to parse the output once it has been redirected to a file.
+```python
+def read_json(filename):
+    """
+    Convert the standard output saved to filename into a pandas DataFrame for analysis.
+    """
+
+    import pandas
+    import json
+
+    # json requires double quotes, so replace the single quotes
+    with open(filename, "r") as f:
+        data = f.read()
+
+    data = data.replace("'", '"')
+
+    data = [json.loads(l) for l in data.splitlines()]
+    return pandas.DataFrame(data)
+```
diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py
new file mode 100644
index 0000000000..8d3068a229
--- /dev/null
+++ b/examples/pipeline_wavernn/datasets.py
@@ -0,0 +1,115 @@
+import os
+import random
+
+import torch
+import torchaudio
+from torch.utils.data.dataset import random_split
+from torchaudio.datasets import LJSPEECH
+from torchaudio.transforms import MuLawEncoding
+
+from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
+
+
+class MapMemoryCache(torch.utils.data.Dataset):
+    r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
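+    Subsequent accesses to the same index return the cached item instead of recomputing it.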
+ """ + + def __init__(self, dataset): + self.dataset = dataset + self._cache = [None] * len(dataset) + + def __getitem__(self, n): + if self._cache[n] is not None: + return self._cache[n] + + item = self.dataset[n] + self._cache[n] = item + + return item + + def __len__(self): + return len(self.dataset) + + +class Processed(torch.utils.data.Dataset): + def __init__(self, dataset, transforms): + self.dataset = dataset + self.transforms = transforms + + def __getitem__(self, key): + item = self.dataset[key] + return self.process_datapoint(item) + + def __len__(self): + return len(self.dataset) + + def process_datapoint(self, item): + specgram = self.transforms(item[0]) + return item[0].squeeze(0), specgram + + +def split_process_ljspeech(args, transforms): + data = LJSPEECH(root=args.file_path, download=False) + + val_length = int(len(data) * args.val_ratio) + lengths = [len(data) - val_length, val_length] + train_dataset, val_dataset = random_split(data, lengths) + + train_dataset = Processed(train_dataset, transforms) + val_dataset = Processed(val_dataset, transforms) + + train_dataset = MapMemoryCache(train_dataset) + val_dataset = MapMemoryCache(val_dataset) + + return train_dataset, val_dataset + + +def collate_factory(args): + def raw_collate(batch): + + pad = (args.kernel_size - 1) // 2 + + # input waveform length + wave_length = args.hop_length * args.seq_len_factor + # input spectrogram length + spec_length = args.seq_len_factor + pad * 2 + + # max start postion in spectrogram + max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch] + + # random start postion in spectrogram + spec_offsets = [random.randint(0, offset) for offset in max_offsets] + # random start postion in waveform + wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets] + + waveform_combine = [ + x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1] + for i, x in enumerate(batch) + ] + specgram = [ + x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length] + for i, x in enumerate(batch) + ] + + specgram = torch.stack(specgram) + waveform_combine = torch.stack(waveform_combine) + + waveform = waveform_combine[:, :wave_length] + target = waveform_combine[:, 1:] + + # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy' + if args.loss == "crossentropy": + + if args.mulaw: + mulaw_encode = MuLawEncoding(2 ** args.n_bits) + waveform = mulaw_encode(waveform) + target = mulaw_encode(target) + + waveform = bits_to_normalized_waveform(waveform, args.n_bits) + + else: + target = normalized_waveform_to_bits(target, args.n_bits) + + return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) + + return raw_collate diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py new file mode 100644 index 0000000000..a4494b05fb --- /dev/null +++ b/examples/pipeline_wavernn/losses.py @@ -0,0 +1,119 @@ +import math + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +class LongCrossEntropyLoss(nn.Module): + r""" CrossEntropy loss + """ + + def __init__(self): + super(LongCrossEntropyLoss, self).__init__() + + def forward(self, output, target): + output = output.transpose(1, 2) + target = target.long() + + criterion = nn.CrossEntropyLoss() + return criterion(output, target) + + +class MoLLoss(nn.Module): + r""" Discretized mixture of logistic distributions loss + + Adapted from wavenet vocoder + (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py) + Explanation of loss 
(https://github.com/Rayhane-mamah/Tacotron-2/issues/155) + + Args: + y_hat (Tensor): Predicted output (n_batch x n_time x n_channel) + y (Tensor): Target (n_batch x n_time x 1) + num_classes (int): Number of classes + log_scale_min (float): Log scale minimum value + reduce (bool): If True, the losses are averaged or summed for each minibatch + + Returns + Tensor: loss + """ + + def __init__(self, num_classes=65536, log_scale_min=None, reduce=True): + super(MoLLoss, self).__init__() + self.num_classes = num_classes + self.log_scale_min = log_scale_min + self.reduce = reduce + + def forward(self, y_hat, y): + y = y.unsqueeze(-1) + + if self.log_scale_min is None: + self.log_scale_min = math.log(1e-14) + + assert y_hat.dim() == 3 + assert y_hat.size(-1) % 3 == 0 + + nr_mix = y_hat.size(-1) // 3 + + # unpack parameters (n_batch, n_time, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix: 2 * nr_mix] + log_scales = torch.clamp( + y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min + ) + + # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log( + torch.clamp(cdf_delta, min=1e-12) + ) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - math.log((self.num_classes - 1) / 2) + ) + inner_cond = (y > 0.999).float() + inner_out = ( + inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + ) + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if self.reduce: + return -torch.mean(_log_sum_exp(log_probs)) + else: + return -_log_sum_exp(log_probs).unsqueeze(-1) + + +def _log_sum_exp(x): + r""" Numerically stable log_sum_exp implementation that prevents overflow + """ + + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py new file mode 100644 index 0000000000..032e9fc70e --- /dev/null +++ b/examples/pipeline_wavernn/main.py @@ -0,0 +1,390 @@ +import argparse +import logging +import os +import signal +from collections import defaultdict +from datetime import datetime +from time import time +from typing import List + +import torch +import torchaudio +from torch import nn as nn +from torch.optim import Adam +from torch.utils.data import DataLoader +from torchaudio.datasets.utils import bg_iterator +from torchaudio.models._wavernn import _WaveRNN + +from datasets import collate_factory, split_process_ljspeech +from losses import 
LongCrossEntropyLoss, MoLLoss +from processing import LinearToMel, NormalizeDB +from utils import MetricLogger, count_parameters, save_checkpoint + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers", + ) + parser.add_argument( + "--checkpoint", + default="", + type=str, + metavar="PATH", + help="path to latest checkpoint", + ) + parser.add_argument( + "--epochs", + default=8000, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--start-epoch", default=0, type=int, metavar="N", help="manual epoch number" + ) + parser.add_argument( + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency in epochs", + ) + parser.add_argument( + "--batch-size", default=256, type=int, metavar="N", help="mini-batch size" + ) + parser.add_argument( + "--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate", + ) + parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0) + parser.add_argument( + "--mulaw", + default=True, + action="store_true", + help="if used, waveform is mulaw encoded", + ) + parser.add_argument( + "--jit", default=False, action="store_true", help="if used, model is jitted" + ) + parser.add_argument( + "--upsample-scales", + default=[5, 5, 11], + type=List[int], + help="the list of upsample scales", + ) + parser.add_argument( + "--n-bits", default=8, type=int, help="the bits of output waveform", + ) + parser.add_argument( + "--sample-rate", + default=22050, + type=int, + help="the rate of audio dimensions (samples per second)", + ) + parser.add_argument( + "--hop-length", + default=275, + type=int, + help="the number of samples between the starts of consecutive frames", + ) + parser.add_argument( + "--win-length", default=1100, type=int, help="the length of the STFT window", + ) + parser.add_argument( + "--f-min", default=40.0, type=float, help="the minimum frequency", + ) + parser.add_argument( + "--min-level-db", + default=-100, + type=float, + help="the minimum db value for spectrogam normalization", + ) + parser.add_argument( + "--n-res-block", default=10, type=int, help="the number of ResBlock in stack", + ) + parser.add_argument( + "--n-rnn", default=512, type=int, help="the dimension of RNN layer", + ) + parser.add_argument( + "--n-fc", default=512, type=int, help="the dimension of fully connected layer", + ) + parser.add_argument( + "--kernel-size", + default=5, + type=int, + help="the number of kernel size in the first Conv1d layer", + ) + parser.add_argument( + "--n-freq", default=80, type=int, help="the number of spectrogram bins to use", + ) + parser.add_argument( + "--n-hidden-melresnet", + default=128, + type=int, + help="the number of hidden dimensions of resblock in melresnet", + ) + parser.add_argument( + "--n-output-melresnet", default=128, type=int, help="the output dimension of melresnet", + ) + parser.add_argument( + "--n-fft", default=2048, type=int, help="the number of Fourier bins", + ) + parser.add_argument( + "--loss", + default="crossentropy", + choices=["crossentropy", "mol"], + type=str, + help="the type of loss", + ) + parser.add_argument( + "--seq-len-factor", + default=5, + type=int, + help="the length of each waveform to process per batch = hop_length * seq_len_factor", + ) + parser.add_argument( + "--val-ratio", + default=0.1, + type=float, + help="the ratio of waveforms for validation", + ) + parser.add_argument( + 
"--file-path", default="", type=str, help="the path of audio files", + ) + + args = parser.parse_args() + return args + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch): + + model.train() + + sums = defaultdict(lambda: 0.0) + start1 = time() + + metric = MetricLogger("train_iteration") + metric["epoch"] = epoch + + for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): + + start2 = time() + + waveform = waveform.to(device) + specgram = specgram.to(device) + target = target.to(device) + + output = model(waveform, specgram) + output, target = output.squeeze(1), target.squeeze(1) + + loss = criterion(output, target) + loss_item = loss.item() + sums["loss"] += loss_item + metric["loss"] = loss_item + + optimizer.zero_grad() + loss.backward() + + if args.clip_grad > 0: + gradient = torch.nn.utils.clip_grad_norm_( + model.parameters(), args.clip_grad + ) + sums["gradient"] += gradient.item() + metric["gradient"] = gradient.item() + + optimizer.step() + + metric["iteration"] = sums["iteration"] + metric["time"] = time() - start2 + metric() + sums["iteration"] += 1 + + avg_loss = sums["loss"] / len(data_loader) + + metric = MetricLogger("train_epoch") + metric["epoch"] = epoch + metric["loss"] = sums["loss"] / len(data_loader) + metric["gradient"] = avg_loss + metric["time"] = time() - start1 + metric() + + +def validate(model, criterion, data_loader, device, epoch): + + with torch.no_grad(): + + model.eval() + sums = defaultdict(lambda: 0.0) + start = time() + + for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): + + waveform = waveform.to(device) + specgram = specgram.to(device) + target = target.to(device) + + output = model(waveform, specgram) + output, target = output.squeeze(1), target.squeeze(1) + + loss = criterion(output, target) + sums["loss"] += loss.item() + + avg_loss = sums["loss"] / len(data_loader) + + metric = MetricLogger("validation") + metric["epoch"] = epoch + metric["loss"] = avg_loss + metric["time"] = time() - start + metric() + + return avg_loss + + +def main(args): + + devices = ["cuda" if torch.cuda.is_available() else "cpu"] + + logging.info("Start time: {}".format(str(datetime.now()))) + + melkwargs = { + "n_fft": args.n_fft, + "power": 1, + "hop_length": args.hop_length, + "win_length": args.win_length, + } + + transforms = torch.nn.Sequential( + torchaudio.transforms.Spectrogram(**melkwargs), + LinearToMel( + sample_rate=args.sample_rate, + n_fft=args.n_fft, + n_mels=args.n_freq, + fmin=args.f_min, + ), + NormalizeDB(min_level_db=args.min_level_db), + ) + + train_dataset, val_dataset = split_process_ljspeech(args, transforms) + + loader_training_params = { + "num_workers": args.workers, + "pin_memory": False, + "shuffle": True, + "drop_last": False, + } + loader_validation_params = loader_training_params.copy() + loader_validation_params["shuffle"] = False + + collate_fn = collate_factory(args) + + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + **loader_training_params, + ) + val_loader = DataLoader( + val_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + **loader_validation_params, + ) + + n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30 + + model = _WaveRNN( + upsample_scales=args.upsample_scales, + n_classes=n_classes, + hop_length=args.hop_length, + n_res_block=args.n_res_block, + n_rnn=args.n_rnn, + n_fc=args.n_fc, + kernel_size=args.kernel_size, + n_freq=args.n_freq, + n_hidden=args.n_hidden_melresnet, + 
n_output=args.n_output_melresnet, + ) + + if args.jit: + model = torch.jit.script(model) + + model = torch.nn.DataParallel(model) + model = model.to(devices[0], non_blocking=True) + + n = count_parameters(model) + logging.info(f"Number of parameters: {n}") + + # Optimizer + optimizer_params = { + "lr": args.learning_rate, + } + + optimizer = Adam(model.parameters(), **optimizer_params) + + criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss() + + best_loss = 10.0 + + if args.checkpoint and os.path.isfile(args.checkpoint): + logging.info(f"Checkpoint: loading '{args.checkpoint}'") + checkpoint = torch.load(args.checkpoint) + + args.start_epoch = checkpoint["epoch"] + best_loss = checkpoint["best_loss"] + + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + logging.info( + f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}" + ) + else: + logging.info("Checkpoint: not found") + + save_checkpoint( + { + "epoch": args.start_epoch, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + }, + False, + args.checkpoint, + ) + + for epoch in range(args.start_epoch, args.epochs): + + train_one_epoch( + model, criterion, optimizer, train_loader, devices[0], epoch, + ) + + if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1: + + sum_loss = validate(model, criterion, val_loader, devices[0], epoch) + + is_best = sum_loss < best_loss + best_loss = min(sum_loss, best_loss) + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + }, + is_best, + args.checkpoint, + ) + + logging.info(f"End time: {datetime.now()}") + + +if __name__ == "__main__": + + logging.basicConfig(level=logging.INFO) + args = parse_args() + main(args) diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py new file mode 100644 index 0000000000..b22d60dae4 --- /dev/null +++ b/examples/pipeline_wavernn/processing.py @@ -0,0 +1,58 @@ +import librosa +import torch +import torch.nn as nn + + +# TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved +class LinearToMel(nn.Module): + def __init__(self, sample_rate, n_fft, n_mels, fmin, htk=False, norm="slaney"): + super().__init__() + self.sample_rate = sample_rate + self.n_fft = n_fft + self.n_mels = n_mels + self.fmin = fmin + self.htk = htk + self.norm = norm + + def forward(self, specgram): + specgram = librosa.feature.melspectrogram( + S=specgram.squeeze(0).numpy(), + sr=self.sample_rate, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.fmin, + htk=self.htk, + norm=self.norm, + ) + return torch.from_numpy(specgram) + + +class NormalizeDB(nn.Module): + r"""Normalize the spectrogram with a minimum db value + """ + + def __init__(self, min_level_db): + super().__init__() + self.min_level_db = min_level_db + + def forward(self, specgram): + specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) + return torch.clamp( + (self.min_level_db - specgram) / self.min_level_db, min=0, max=1 + ) + + +def normalized_waveform_to_bits(waveform, bits): + r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] + """ + + assert abs(waveform).max() <= 1.0 + waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 + return torch.clamp(waveform, 0, 2 ** bits - 1).int() + + +def bits_to_normalized_waveform(label, bits): + r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] + """ + 
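+    # inverse of normalized_waveform_to_bits: map integer labels linearly back to [-1.0, 1.0]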
+    return 2 * label / (2 ** bits - 1.0) - 1.0
diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py
new file mode 100644
index 0000000000..e924c9f512
--- /dev/null
+++ b/examples/pipeline_wavernn/utils.py
@@ -0,0 +1,61 @@
+import logging
+import os
+import shutil
+from collections import defaultdict, deque
+
+import torch
+
+
+class MetricLogger:
+    r"""Logger for model metrics
+    """
+
+    def __init__(self, group, print_freq=1):
+        self.print_freq = print_freq
+        self._iter = 0
+        self.data = defaultdict(lambda: deque(maxlen=self.print_freq))
+        self.data["group"].append(group)
+
+    def __setitem__(self, key, value):
+        self.data[key].append(value)
+
+    def _get_last(self):
+        return {k: v[-1] for k, v in self.data.items()}
+
+    def __str__(self):
+        return str(self._get_last())
+
+    def __call__(self):
+        self._iter = (self._iter + 1) % self.print_freq
+        if not self._iter:
+            print(self, flush=True)
+
+
+def save_checkpoint(state, is_best, filename):
+    r"""Save the model to a temporary file first,
+    then rename it to filename, so that an interrupted
+    torch.save() cannot corrupt an existing checkpoint.
+    """
+
+    if filename == "":
+        return
+
+    tempfile = filename + ".temp"
+
+    # Remove any stale tempfile left over from a previously interrupted save
+    if os.path.isfile(tempfile):
+        os.remove(tempfile)
+
+    torch.save(state, tempfile)
+    if os.path.isfile(tempfile):
+        os.rename(tempfile, filename)
+    if is_best:
+        shutil.copyfile(filename, "model_best.pth.tar")
+    logging.info("Checkpoint: saved")
+
+
+def count_parameters(model):
+    r"""Count the total number of trainable parameters in the model
+    """
+
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
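+
+
+# Example usage (illustration only, not part of the pipeline): MetricLogger prints one
+# dict per line, which is the format that read_json in the README parses into a DataFrame.
+#
+#     metric = MetricLogger("example")
+#     metric["loss"] = 0.5
+#     metric()  # prints {'group': 'example', 'loss': 0.5}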