
Commit 684b95f

Update things around training

- Update batch size to 16
- Fix validation dataset collate and preprocessing

1 parent 0e8c7b7 commit 684b95f

6 files changed (+61 lines, −45 lines)


examples/source_separation/README.md

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ This directory contains reference implementations for source separations. For th
 ### Overview
 
 To traing a model, you can use [`train.py`](./train.py). This script takes the form of
-`trin.py [parameters for distributed training] -- [parameters for model/training]`
+`train.py [parameters for distributed training] -- [parameters for model/training]`
 
 ```
 python train.py \
@@ -162,7 +162,7 @@ python -u \
     --sample-rate 8000 \
     --dataset-dir "${dataset_dir}" \
     --save-dir "${save_dir}" \
-    --batch-size $((32 / SLURM_NTASKS))
+    --batch-size $((16 / SLURM_NTASKS))
 ```
 
 </details>
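
The `--batch-size $((16 / SLURM_NTASKS))` argument above splits the new total batch size of 16 across the SLURM tasks using shell integer arithmetic. A quick sketch of the same computation in Python, with a hypothetical 8-task job (the value 8 is an example, not taken from this commit):

```python
# Mirrors the shell arithmetic $((16 / SLURM_NTASKS)).
total_batch_size = 16
world_size = 8  # hypothetical SLURM_NTASKS, e.g. 8 GPUs in one node
per_worker = total_batch_size // world_size
print(per_worker)  # 2 samples per process, 16 across the whole job
```

Note that with more than 16 tasks the integer division drops to 0, so the task count has to divide the total batch size sensibly.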

examples/source_separation/conv_tasnet/README.md

Lines changed: 11 additions & 19 deletions
@@ -13,7 +13,7 @@ For the usage, please checkout the [source separation README](../README.md).
 The default training/model configurations follows the best non-causal implementation from the paper. (causal configuration is not implemented.)
 
 - Sample rate: 8000 Hz
-- Batch size: total 32 over distributed training workers
+- Batch size: total 16 over distributed training workers
 - Epochs: 100
 - Initial learning rate: 1e-3
 - Gradient clipping: maximum L2 norm of 5.0
@@ -31,30 +31,22 @@ The default training/model configurations follows the best non-causal implementa
 - The number of TCN convolution block layers (X): 8
 - The number of TCN convolution blocks (R): 3
 
-## Training
-
-The training takes about 5 mins per epoch with 8 V100 GPUs in a single node.
-
 ## Evaluation
 
 The following is the evaluation result of training the model on WSJ0-2mix and WSJ0-3mix datasets.
 
 ### wsj0-mix 2speakers
 
-|                   | SI-SNRi (dB) | SDRi (dB) | Epoch |
-|:-----------------:|-------------:|----------:|------:|
-| Reference         |         15.3 |      15.6 |       |
-| "min" Validation  |        12.63 |     12.63 |   100 |
-| "min" Evaluation  |        10.59 |     10.58 |   100 |
-| "max" Validation  |        12.72 |     12.72 |   100 |
-| "max" Evaluation  |        11.00 |     11.00 |   100 |
+|                    | SI-SNRi (dB) | SDRi (dB) | Epoch |
+|:------------------:|-------------:|----------:|------:|
+| Reference          |         15.3 |      15.6 |       |
+| Validation dataset |         13.3 |      13.3 |    86 |
+| Evaluation dataset |         11.3 |      11.3 |    86 |
 
 ### wsj0-mix 3speakers
 
-|                   | SI-SNRi (dB) | SDRi (dB) | Epoch |
-|:-----------------:|-------------:|----------:|------:|
-| Reference         |         12.7 |      13.1 |       |
-| "min" Validation  |        10.75 |     10.75 |    99 |
-| "min" Evaluation  |         8.39 |      8.38 |    99 |
-| "max" Validation  |        10.87 |     10.86 |    99 |
-| "max" Evaluation  |         8.23 |      8.20 |    99 |
+|                    | SI-SNRi (dB) | SDRi (dB) | Epoch |
+|:------------------:|-------------:|----------:|------:|
+| Reference          |         12.7 |      13.1 |       |
+| Validation dataset |         11.5 |      11.5 |    87 |
+| Evaluation dataset |          8.7 |       8.6 |    87 |
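
Among the defaults listed above, gradient clipping to a maximum L2 norm of 5.0 presumably maps onto PyTorch's standard `torch.nn.utils.clip_grad_norm_` utility. A minimal self-contained sketch under that assumption (the `Linear` model is a stand-in for illustration, not the repository's ConvTasNet):

```python
import torch

model = torch.nn.Linear(8, 8)  # hypothetical stand-in for the ConvTasNet model
loss = model(torch.randn(2, 8)).sum()
loss.backward()
# Rescale gradients in-place so their combined L2 norm is at most 5.0,
# matching the "Gradient clipping: maximum L2 norm of 5.0" default above.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
```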

examples/source_separation/conv_tasnet/train.py

Lines changed: 22 additions & 8 deletions
@@ -14,7 +14,7 @@
 
 
 def _parse_args(args):
-    default_batch_size = 32 // torch.distributed.get_world_size()
+    default_batch_size = 16 // torch.distributed.get_world_size()
 
     parser = argparse.ArgumentParser(description=__doc__,)
     parser.add_argument("--debug", action="store_true", help="Enable debug behavior.")
@@ -61,7 +61,7 @@ def _parse_args(args):
         "--batch-size",
         default=default_batch_size,
         type=int,
-        help=f"Batch size. (default: {default_batch_size} ( == 32 // batch_size))",
+        help=f"Batch size. (default: {default_batch_size} (== 16 // world_size))",
     )
     group = parser.add_argument_group("Training Options")
     group.add_argument(
@@ -133,34 +133,48 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_
     train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
         dataset_type, dataset_dir, num_speakers, sample_rate,
     )
-    collate_fn = dataset_utils.get_collate_fn(
-        dataset_type, sample_rate=sample_rate, duration=4
+    train_collate_fn = dataset_utils.get_collate_fn(
+        dataset_type, mode='train', sample_rate=sample_rate, duration=4
     )
 
+    test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test')
+
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=batch_size,
         sampler=torch.utils.data.distributed.DistributedSampler(train_dataset),
-        collate_fn=collate_fn,
+        collate_fn=train_collate_fn,
         pin_memory=True,
     )
     valid_loader = torch.utils.data.DataLoader(
         valid_dataset,
         batch_size=batch_size,
         sampler=torch.utils.data.distributed.DistributedSampler(valid_dataset),
-        collate_fn=collate_fn,
+        collate_fn=test_collate_fn,
         pin_memory=True,
     )
     eval_loader = torch.utils.data.DataLoader(
         eval_dataset,
         batch_size=batch_size,
         sampler=torch.utils.data.distributed.DistributedSampler(eval_dataset),
-        collate_fn=collate_fn,
+        collate_fn=test_collate_fn,
         pin_memory=True,
     )
     return train_loader, valid_loader, eval_loader
 
 
+def _write_header(log_path, args):
+    rows = [
+        [f"# torch: {torch.__version__}", ],
+        [f"# torchaudio: {torchaudio.__version__}", ]
+    ]
+    rows.append(["# arguments"])
+    for key, item in vars(args).items():
+        rows.append([f"# {key}: {item}"])
+
+    dist_utils.write_csv_on_master(log_path, *rows)
+
+
 def train(args):
     args = _parse_args(args)
     _LG.info("%s", args)
@@ -237,7 +251,7 @@ def train(args):
     )
 
     log_path = args.save_dir / f"log.csv"
-    dist_utils.write_csv_on_master(log_path, [f"# {args}", ])
+    _write_header(log_path, args)
     dist_utils.write_csv_on_master(
         log_path,
         [
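
The substantive fix here is that the validation and evaluation loaders now get a `mode='test'` collate function instead of the fixed-duration training one. The actual `dataset_utils.get_collate_fn` implementation is not part of this diff; the sketch below is only an assumption about the shape such a factory could take, with the `Batch` container and the cropping logic invented for illustration:

```python
from collections import namedtuple

import torch

Batch = namedtuple("Batch", ["mix", "src"])  # hypothetical sample container


def get_collate_fn(mode, sample_rate=8000, duration=4):
    """Sketch of a mode-aware collate factory; not the actual dataset_utils code."""
    if mode == "train":
        num_frames = sample_rate * duration

        def collate(samples):
            # Crop each waveform to a fixed duration so the tensors stack into
            # one batch (real code would also handle shorter clips, e.g. by padding).
            mix = torch.stack([s.mix[..., :num_frames] for s in samples])
            src = torch.stack([s.src[..., :num_frames] for s in samples])
            return Batch(mix, src)

        return collate
    # "test" mode: keep full-length samples untouched; the evaluation loop
    # in trainer.py then iterates over them one at a time.
    return lambda samples: samples
```

The design point is that training needs equal-length tensors for batching, while validation and evaluation should score each full-length utterance.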

examples/source_separation/conv_tasnet/trainer.py

Lines changed: 16 additions & 14 deletions
@@ -134,20 +134,22 @@ def _test(self, loader):
         total_si_snri = 0.0
         total_sdri = 0.0
 
-        for batch in loader:
-            mixed = batch.mix.to(self.device)
-            sources = batch.src.to(self.device)
-
-            estimate = self.model(mixed)
-            si_snri, sdri = si_sdr_improvement(estimate, sources, mixed)
-            si_snri = si_snri.sum()
-            sdri = sdri.sum()
-
-            dist.all_reduce(si_snri, dist.ReduceOp.SUM)
-            dist.all_reduce(sdri, dist.ReduceOp.SUM)
-
-            total_si_snri += si_snri.item()
-            total_sdri += sdri.item()
+        for samples in loader:
+            # Due to the possible length difference, we run evaluation sample-wise
+            for sample in samples:
+                mixed = sample.mix.to(self.device)
+                sources = sample.src.to(self.device)
+
+                estimate = self.model(mixed)
+                si_snri, sdri = si_sdr_improvement(estimate, sources, mixed)
+                si_snri = si_snri.sum()
+                sdri = sdri.sum()
+
+                dist.all_reduce(si_snri, dist.ReduceOp.SUM)
+                dist.all_reduce(sdri, dist.ReduceOp.SUM)
+
+                total_si_snri += si_snri.item()
+                total_sdri += sdri.item()
 
             if self.debug:
                 break
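
This rewrite matches the new test collate behavior: full-length test utterances differ in duration, so they cannot be stacked into one batch tensor and are scored one at a time. A small demonstration of the underlying constraint (the lengths below are made up):

```python
import torch

a = torch.randn(1, 32000)  # e.g. a 4 s mixture at 8 kHz
b = torch.randn(1, 44000)  # e.g. a 5.5 s mixture at 8 kHz
try:
    torch.stack([a, b])    # batching full-length test clips fails
except RuntimeError as err:
    print("cannot batch:", err)
```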

examples/source_separation/train.py

Lines changed: 7 additions & 0 deletions
@@ -83,6 +83,11 @@ def _parse_args(args=None):
             'access, using `"file://..."` protocol. (default: "env://")'
         ),
     )
+    group.add_argument(
+        "--random-seed",
+        type=int,
+        help="Set random seed value. (default: None)",
+    )
     parser.add_argument(
         "rest", nargs=argparse.REMAINDER, help="Model-specific arguments."
     )
@@ -118,6 +123,8 @@ def _main(cli_args):
         backend='nccl' if torch.cuda.is_available() else 'gloo',
         init_method=args.sync_protocol,
     )
+    if args.random_seed is not None:
+        torch.manual_seed(args.random_seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(args.device_id)
         _LG.info("CUDA device set to %s", args.device_id)
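
The new flag simply routes into `torch.manual_seed`, which resets PyTorch's global RNG so that repeated runs draw identical random numbers; a minimal check:

```python
import torch

torch.manual_seed(0)
a = torch.randn(3)
torch.manual_seed(0)
b = torch.randn(3)
assert torch.equal(a, b)  # identical draws after re-seeding
```

Because the seed is applied after `init_process_group` in each process, every distributed worker seeds its RNG with the same value.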

examples/source_separation/utils/dist_utils.py

Lines changed: 3 additions & 2 deletions
@@ -58,11 +58,12 @@ def save_on_master(path, obj):
         torch.save(obj, path)
 
 
-def write_csv_on_master(path, items):
+def write_csv_on_master(path, *rows):
     if dist.get_rank() == 0:
         with open(path, "a", newline="") as fileobj:
             writer = csv.writer(fileobj)
-            writer.writerow(items)
+            for row in rows:
+                writer.writerow(row)
 
 
 def synchronize_params(path, device, *modules):
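
With the variadic signature, every positional argument after `path` becomes one CSV row, which is what lets `_write_header` in train.py emit its whole header in a single call. The sketch below replays the same pattern without the rank-0 guard so it runs outside a distributed job (`log.csv` is an arbitrary example path):

```python
import csv


def write_csv(path, *rows):
    # Append one CSV row per positional argument, as the updated
    # write_csv_on_master does after its dist.get_rank() == 0 check.
    with open(path, "a", newline="") as fileobj:
        writer = csv.writer(fileobj)
        for row in rows:
            writer.writerow(row)


write_csv("log.csv", ["# arguments"], ["# batch-size: 2"], ["epoch", "train_loss"])
```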
