|
4 | 4 | import torch |
5 | 5 | import torchaudio |
6 | 6 | from torch.utils.data.dataset import random_split |
7 | | -from torchaudio.datasets import LJSPEECH |
| 7 | +from torchaudio.datasets import LJSPEECH, LIBRITTS |
8 | 8 | from torchaudio.transforms import MuLawEncoding |
9 | 9 |
|
10 | 10 | from processing import bits_to_normalized_waveform, normalized_waveform_to_bits |
@@ -48,12 +48,20 @@ def process_datapoint(self, item): |
48 | 48 | return item[0].squeeze(0), specgram |
49 | 49 |
|
50 | 50 |
|
51 | | -def split_process_ljspeech(args, transforms): |
52 | | - data = LJSPEECH(root=args.file_path, download=False) |
| 51 | +def split_process_dataset(args, transforms): |
| 52 | + if args.dataset == 'ljspeech': |
| 53 | + data = LJSPEECH(root=args.file_path, download=False) |
53 | 54 |
|
54 | | - val_length = int(len(data) * args.val_ratio) |
55 | | - lengths = [len(data) - val_length, val_length] |
56 | | - train_dataset, val_dataset = random_split(data, lengths) |
| 55 | + val_length = int(len(data) * args.val_ratio) |
| 56 | + lengths = [len(data) - val_length, val_length] |
| 57 | + train_dataset, val_dataset = random_split(data, lengths) |
| 58 | + |
| 59 | + elif args.dataset == 'libritts': |
| 60 | + train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False) |
| 61 | + val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False) |
| 62 | + |
| 63 | + else: |
| 64 | + raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}") |
57 | 65 |
|
58 | 66 | train_dataset = Processed(train_dataset, transforms) |
59 | 67 | val_dataset = Processed(val_dataset, transforms) |
|
0 commit comments