Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions examples/pipeline_wavernn/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torchaudio
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH
from torchaudio.datasets import LJSPEECH, LIBRITTS
from torchaudio.transforms import MuLawEncoding

from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
Expand Down Expand Up @@ -48,12 +48,20 @@ def process_datapoint(self, item):
return item[0].squeeze(0), specgram


def split_process_ljspeech(args, transforms):
data = LJSPEECH(root=args.file_path, download=False)
def split_process_dataset(args, transforms):
if args.dataset == 'ljspeech':
data = LJSPEECH(root=args.file_path, download=False)

val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)
val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)

elif args.dataset == 'libritts':
train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)

else:
raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")

train_dataset = Processed(train_dataset, transforms)
val_dataset = Processed(val_dataset, transforms)
Expand Down
11 changes: 9 additions & 2 deletions examples/pipeline_wavernn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models._wavernn import _WaveRNN

from datasets import collate_factory, split_process_ljspeech
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint
Expand Down Expand Up @@ -55,6 +55,13 @@ def parse_args():
metavar="N",
help="print frequency in epochs",
)
parser.add_argument(
"--dataset",
default="ljspeech",
choices=["ljspeech", "libritts"],
type=str,
help="select dataset to train with",
)
parser.add_argument(
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
)
Expand Down Expand Up @@ -269,7 +276,7 @@ def main(args):
NormalizeDB(min_level_db=args.min_level_db),
)

train_dataset, val_dataset = split_process_ljspeech(args, transforms)
train_dataset, val_dataset = split_process_dataset(args, transforms)

loader_training_params = {
"num_workers": args.workers,
Expand Down