Skip to content

Commit 870811c

Browse files
jimchen90 (Ji Chen) and a co-author authored
Add libritts dataset option (#818)
Co-authored-by: Ji Chen <[email protected]>
1 parent 1ecbc24 commit 870811c

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

examples/pipeline_wavernn/datasets.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torchaudio
66
from torch.utils.data.dataset import random_split
7-
from torchaudio.datasets import LJSPEECH
7+
from torchaudio.datasets import LJSPEECH, LIBRITTS
88
from torchaudio.transforms import MuLawEncoding
99

1010
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
@@ -48,12 +48,20 @@ def process_datapoint(self, item):
4848
return item[0].squeeze(0), specgram
4949

5050

51-
def split_process_ljspeech(args, transforms):
52-
data = LJSPEECH(root=args.file_path, download=False)
51+
def split_process_dataset(args, transforms):
52+
if args.dataset == 'ljspeech':
53+
data = LJSPEECH(root=args.file_path, download=False)
5354

54-
val_length = int(len(data) * args.val_ratio)
55-
lengths = [len(data) - val_length, val_length]
56-
train_dataset, val_dataset = random_split(data, lengths)
55+
val_length = int(len(data) * args.val_ratio)
56+
lengths = [len(data) - val_length, val_length]
57+
train_dataset, val_dataset = random_split(data, lengths)
58+
59+
elif args.dataset == 'libritts':
60+
train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
61+
val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
62+
63+
else:
64+
raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")
5765

5866
train_dataset = Processed(train_dataset, transforms)
5967
val_dataset = Processed(val_dataset, transforms)

examples/pipeline_wavernn/main.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torchaudio.datasets.utils import bg_iterator
1616
from torchaudio.models.wavernn import WaveRNN
1717

18-
from datasets import collate_factory, split_process_ljspeech
18+
from datasets import collate_factory, split_process_dataset
1919
from losses import LongCrossEntropyLoss, MoLLoss
2020
from processing import LinearToMel, NormalizeDB
2121
from utils import MetricLogger, count_parameters, save_checkpoint
@@ -55,6 +55,13 @@ def parse_args():
5555
metavar="N",
5656
help="print frequency in epochs",
5757
)
58+
parser.add_argument(
59+
"--dataset",
60+
default="ljspeech",
61+
choices=["ljspeech", "libritts"],
62+
type=str,
63+
help="select dataset to train with",
64+
)
5865
parser.add_argument(
5966
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
6067
)
@@ -269,7 +276,7 @@ def main(args):
269276
NormalizeDB(min_level_db=args.min_level_db),
270277
)
271278

272-
train_dataset, val_dataset = split_process_ljspeech(args, transforms)
279+
train_dataset, val_dataset = split_process_dataset(args, transforms)
273280

274281
loader_training_params = {
275282
"num_workers": args.workers,

0 commit comments

Comments (0)