From ac815f9726b2f445c86415e6cb965835d0ab923a Mon Sep 17 00:00:00 2001
From: Ji Chen
Date: Thu, 23 Jul 2020 09:56:54 -0700
Subject: [PATCH] Add libritts dataset option

---
 examples/pipeline_wavernn/datasets.py | 20 ++++++++++++++------
 examples/pipeline_wavernn/main.py     | 11 +++++++++--
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py
index 8d3068a229..58c2853f3d 100644
--- a/examples/pipeline_wavernn/datasets.py
+++ b/examples/pipeline_wavernn/datasets.py
@@ -4,7 +4,7 @@
 import torch
 import torchaudio
 from torch.utils.data.dataset import random_split
-from torchaudio.datasets import LJSPEECH
+from torchaudio.datasets import LJSPEECH, LIBRITTS
 from torchaudio.transforms import MuLawEncoding
 
 from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
@@ -48,12 +48,20 @@ def process_datapoint(self, item):
         return item[0].squeeze(0), specgram
 
 
-def split_process_ljspeech(args, transforms):
-    data = LJSPEECH(root=args.file_path, download=False)
+def split_process_dataset(args, transforms):
+    if args.dataset == 'ljspeech':
+        data = LJSPEECH(root=args.file_path, download=False)
 
-    val_length = int(len(data) * args.val_ratio)
-    lengths = [len(data) - val_length, val_length]
-    train_dataset, val_dataset = random_split(data, lengths)
+        val_length = int(len(data) * args.val_ratio)
+        lengths = [len(data) - val_length, val_length]
+        train_dataset, val_dataset = random_split(data, lengths)
+
+    elif args.dataset == 'libritts':
+        train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
+        val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
+
+    else:
+        raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")
 
     train_dataset = Processed(train_dataset, transforms)
     val_dataset = Processed(val_dataset, transforms)
diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py
index 032e9fc70e..6b6252d1ff 100644
--- a/examples/pipeline_wavernn/main.py
+++ b/examples/pipeline_wavernn/main.py
@@ -15,7 +15,7 @@
 from torchaudio.datasets.utils import bg_iterator
 from torchaudio.models._wavernn import _WaveRNN
 
-from datasets import collate_factory, split_process_ljspeech
+from datasets import collate_factory, split_process_dataset
 from losses import LongCrossEntropyLoss, MoLLoss
 from processing import LinearToMel, NormalizeDB
 from utils import MetricLogger, count_parameters, save_checkpoint
@@ -55,6 +55,13 @@ def parse_args():
         metavar="N",
         help="print frequency in epochs",
     )
+    parser.add_argument(
+        "--dataset",
+        default="ljspeech",
+        choices=["ljspeech", "libritts"],
+        type=str,
+        help="select dataset to train with",
+    )
     parser.add_argument(
         "--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
     )
@@ -269,7 +276,7 @@ def main(args):
         NormalizeDB(min_level_db=args.min_level_db),
     )
 
-    train_dataset, val_dataset = split_process_ljspeech(args, transforms)
+    train_dataset, val_dataset = split_process_dataset(args, transforms)
 
     loader_training_params = {
         "num_workers": args.workers,