 
 
 def _parse_args(args):
-    default_batch_size = 32 // torch.distributed.get_world_size()
+    default_batch_size = 16 // torch.distributed.get_world_size()
 
     parser = argparse.ArgumentParser(description=__doc__,)
     parser.add_argument("--debug", action="store_true", help="Enable debug behavior.")
@@ -61,7 +61,7 @@ def _parse_args(args):
         "--batch-size",
         default=default_batch_size,
         type=int,
-        help=f"Batch size. (default: {default_batch_size} ( == 32 // batch_size))",
+        help=f"Batch size. (default: {default_batch_size} (== 16 // world_size))",
     )
     group = parser.add_argument_group("Training Options")
     group.add_argument(
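
For context, the new default treats 16 as a total batch budget split evenly across distributed workers, so each process gets `16 // world_size` samples per step. A quick illustration of that arithmetic (the world sizes below are just examples, not values from this change):

```python
# Illustration of the default computed above:
# per-process batch size = 16 // world size of the distributed job.
for world_size in (1, 2, 4, 8):
    print(world_size, 16 // world_size)  # -> 16, 8, 4, 2
```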
@@ -133,34 +133,48 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_
     train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
         dataset_type, dataset_dir, num_speakers, sample_rate,
     )
-    collate_fn = dataset_utils.get_collate_fn(
-        dataset_type, sample_rate=sample_rate, duration=4
+    train_collate_fn = dataset_utils.get_collate_fn(
+        dataset_type, mode='train', sample_rate=sample_rate, duration=4
     )
 
+    test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test')
+
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=batch_size,
         sampler=torch.utils.data.distributed.DistributedSampler(train_dataset),
-        collate_fn=collate_fn,
+        collate_fn=train_collate_fn,
         pin_memory=True,
     )
     valid_loader = torch.utils.data.DataLoader(
         valid_dataset,
         batch_size=batch_size,
         sampler=torch.utils.data.distributed.DistributedSampler(valid_dataset),
-        collate_fn=collate_fn,
+        collate_fn=test_collate_fn,
         pin_memory=True,
     )
     eval_loader = torch.utils.data.DataLoader(
         eval_dataset,
         batch_size=batch_size,
         sampler=torch.utils.data.distributed.DistributedSampler(eval_dataset),
-        collate_fn=collate_fn,
+        collate_fn=test_collate_fn,
         pin_memory=True,
     )
     return train_loader, valid_loader, eval_loader
 
 
+def _write_header(log_path, args):
+    rows = [
+        [f"# torch: {torch.__version__}", ],
+        [f"# torchaudio: {torchaudio.__version__}", ]
+    ]
+    rows.append(["# arguments"])
+    for key, item in vars(args).items():
+        rows.append([f"# {key}: {item}"])
+
+    dist_utils.write_csv_on_master(log_path, *rows)
+
+
 def train(args):
     args = _parse_args(args)
     _LG.info("%s", args)
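
The collate change suggests that `dataset_utils.get_collate_fn` now dispatches on a `mode` argument: training batches are cropped to a fixed `duration` so they can be stacked densely, while validation and evaluation keep full-length utterances. The sketch below only illustrates that assumed split; the item layout (`(mixture, sources)` tensors) and the padding strategy are assumptions, not the example's actual `dataset_utils` code.

```python
import random
from typing import List, Tuple

import torch
import torch.nn.functional as F

# Assumed item layout: (mixture, sources) waveforms of shape [time] and
# [num_sources, time]. The real dataset items may differ.
Item = Tuple[torch.Tensor, torch.Tensor]


def get_collate_fn(dataset_type, mode, sample_rate=None, duration=None):
    """Sketch of a mode-aware collate factory; `dataset_type` is ignored here."""
    if mode == 'train':
        num_frames = int(sample_rate * duration)

        def collate(batch: List[Item]):
            mixes, srcs = [], []
            for mix, src in batch:
                if mix.size(-1) >= num_frames:
                    # Random fixed-length crop so examples stack into one tensor.
                    start = random.randint(0, mix.size(-1) - num_frames)
                    mix = mix[..., start:start + num_frames]
                    src = src[..., start:start + num_frames]
                else:
                    # Zero-pad short clips up to the target length.
                    pad = num_frames - mix.size(-1)
                    mix, src = F.pad(mix, (0, pad)), F.pad(src, (0, pad))
                mixes.append(mix)
                srcs.append(src)
            return torch.stack(mixes), torch.stack(srcs)

        return collate

    def collate(batch: List[Item]):
        # 'test' mode: keep full-length audio, padding only to the longest
        # example in the batch instead of cropping.
        max_len = max(mix.size(-1) for mix, _ in batch)
        mixes = torch.stack([F.pad(m, (0, max_len - m.size(-1))) for m, _ in batch])
        srcs = torch.stack([F.pad(s, (0, max_len - s.size(-1))) for _, s in batch])
        return mixes, srcs

    return collate
```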
@@ -237,7 +251,7 @@ def train(args):
     )
 
     log_path = args.save_dir / f"log.csv"
-    dist_utils.write_csv_on_master(log_path, [f"# {args}", ])
+    _write_header(log_path, args)
     dist_utils.write_csv_on_master(
         log_path,
         [
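
`_write_header` writes the environment and the parsed arguments as commented rows at the top of `log.csv`, replacing the single `# Namespace(...)` row written before. `dist_utils.write_csv_on_master` is the example's own helper; the sketch below captures its assumed behavior (append rows only from the rank-0 process), not its actual implementation.

```python
import csv

import torch.distributed as dist


def write_csv_on_master(path, *rows):
    # Assumed behavior: in a distributed job, only rank 0 appends to the CSV
    # so workers do not write duplicate or interleaved rows.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    with open(path, "a", newline="") as fileobj:
        writer = csv.writer(fileobj)
        for row in rows:
            writer.writerow(row)
```

With the new header, the first rows of `log.csv` would look roughly like `# torch: <version>`, `# torchaudio: <version>`, `# arguments`, followed by one `# <key>: <value>` row per parsed argument.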