
Commit 25403d0

Add training script
1 parent 52a18a9 commit 25403d0

6 files changed: +657 -0

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# Source Separation Example

## Usage

### Overview

To train a model, you can use [`train.py`](./train.py). This script takes arguments in the form of

`[parameters for distributed training] -- [parameters for model/training]`

If you would like to just try out the training script, run it without any parameters
for distributed training.

```
python train.py \
    [--worker-id WORKER_ID] \
    [--device-id DEVICE_ID] \
    [--num-workers NUM_WORKERS] \
    [--sync-protocol SYNC_PROTOCOL] \
    -- \
    <model specific training parameters>

# For the details of the distributed training parameters, use:
python train.py --help

# For the details of the model/training parameters, use:
python train.py -- --help
```
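
For example, the following is a minimal single-node invocation without any distributed training parameters. The dataset and save paths are placeholders; replace them with your own.

```
python train.py -- \
    --sample-rate 8000 \
    --dataset-dir /path/to/wsj0-mix/2speakers/wav8k/min \
    --save-dir /path/to/checkpoints
```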

This script runs training with the Distributed Data Parallel (DDP) framework and has two major
operation modes. The behavior depends on whether the `--worker-id` argument is given.

1. (`--worker-id` is not given) Launches worker subprocesses that perform the actual training.
2. (`--worker-id` is given) Performs the training as a part of distributed training.

When launching the script without any distributed training parameters (operation mode 1),
it checks the number of GPUs available on the local system and spawns the same
number of training subprocesses (running in operation mode 2). You can reduce the number of GPUs used with
`--num-workers`. If no GPU is available, only one subprocess is launched, and providing
`--num-workers` larger than 1 results in an error.

When launching the script as a worker process of a distributed training job, you need to configure
the coordination of the workers.

- `--num-workers` is the number of training processes being launched.
- `--worker-id` is the process rank (must be unique across all the processes).
- `--device-id` is the GPU device ID (should be unique within a node).
- `--sync-protocol` is how the worker processes communicate and synchronize.
  If the training is carried out on a single node, the default `"env://"` should do.
  If the training processes span multiple nodes, a path to a file to which all the
  training processes have access has to be provided with the `"file://..."` protocol.
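
As a sketch of operation mode 2, the workers on a single node with two GPUs could be launched manually like this (one command per terminal; note that, depending on how the process group is initialized, the `env://` protocol may also require `MASTER_ADDR`/`MASTER_PORT` to be set in the environment):

```
# worker 0
python train.py --worker-id 0 --device-id 0 --num-workers 2 --sync-protocol "env://" -- \
    <model specific training parameters>

# worker 1
python train.py --worker-id 1 --device-id 1 --num-workers 2 --sync-protocol "env://" -- \
    <model specific training parameters>
```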

### Distributed Training Notes

<details><summary>Quick overview of DDP (distributed data parallel)</summary>

DDP is a single-program multiple-data training paradigm.
With DDP, the model is replicated on every process,
and every model replica is fed a different set of input data samples.

- Process: a worker process (as in an OS process). There are P processes per node.
- Node: a machine. There are N nodes, each of which hosts P processes.
- World: the network of nodes, composed of N nodes and N * P processes.
- Rank: global process ID (unique across nodes), in the range [0, N * P).
- Local Rank: local process ID (unique only within a node), in the range [0, P).

```
          Node 0                     Node 1                         Node N-1
┌────────────────────────┐┌─────────────────────────┐     ┌───────────────────────────┐
│╔══════════╗ ┌─────────┐││┌───────────┐ ┌─────────┐│     │┌─────────────┐ ┌─────────┐│
│║ Process  ╟─┤ GPU: 0  ││││ Process   ├─┤ GPU: 0  ││     ││ Process     ├─┤ GPU: 0  ││
│║ Rank: 0  ║ └─────────┘│││ Rank:P    │ └─────────┘│     ││ Rank:NP-P   │ └─────────┘│
│╚══════════╝            ││└───────────┘            │     │└─────────────┘            │
│┌──────────┐ ┌─────────┐││┌───────────┐ ┌─────────┐│     │┌─────────────┐ ┌─────────┐│
││ Process  ├─┤ GPU: 1  ││││ Process   ├─┤ GPU: 1  ││     ││ Process     ├─┤ GPU: 1  ││
││ Rank: 1  │ └─────────┘│││ Rank:P+1  │ └─────────┘│     ││ Rank:NP-P+1 │ └─────────┘│
│└──────────┘            ││└───────────┘            │ ... │└─────────────┘            │
│                        ││                         │     │                           │
│          ...           ││           ...           │     │            ...            │
│                        ││                         │     │                           │
│┌──────────┐ ┌─────────┐││┌───────────┐ ┌─────────┐│     │┌─────────────┐ ┌─────────┐│
││ Process  ├─┤ GPU:P-1 ││││ Process   ├─┤ GPU:P-1 ││     ││ Process     ├─┤ GPU:P-1 ││
││ Rank:P-1 │ └─────────┘│││ Rank:2P-1 │ └─────────┘│     ││ Rank:NP-1   │ └─────────┘│
│└──────────┘            ││└───────────┘            │     │└─────────────┘            │
└────────────────────────┘└─────────────────────────┘     └───────────────────────────┘
```

</details>
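
For reference, here is a minimal sketch of how the bookkeeping above typically maps onto `torch.distributed`. It is illustrative only: the actual process-group setup in this example is handled by `dist_utils` (not shown in this diff), and the function and argument names (`init_worker`, `node_index`, `procs_per_node`, `num_nodes`) are hypothetical.

```python
import torch
import torch.distributed as dist


def init_worker(node_index: int, local_rank: int, procs_per_node: int, num_nodes: int):
    """Illustrative sketch: register one worker process with the DDP world."""
    world_size = num_nodes * procs_per_node          # N * P processes in total
    rank = node_index * procs_per_node + local_rank  # global Rank in [0, N * P)

    # "env://" reads MASTER_ADDR / MASTER_PORT from the environment;
    # for multi-node jobs, a shared "file://..." path can be used instead.
    dist.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )
    if torch.cuda.is_available():
        # The Local Rank is what maps a process to one GPU on its node.
        torch.cuda.set_device(local_rank)
```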

### SLURM

When launched as a SLURM job, the following environment variables correspond to the script parameters:

- `SLURM_PROCID`: `--worker-id` (Rank)
- `SLURM_NTASKS` (or the legacy `SLURM_NPROCS`): the total number of processes (`--num-workers`, i.e. the world size)
- `SLURM_LOCALID`: Local Rank (to be mapped to the GPU index\*)

\* Even when the GPU resource is allocated with `--gpus-per-task=1`, if there are multiple
tasks allocated on the same node (and thus multiple GPUs of the node are allocated to the job),
each task can see all the GPUs allocated for those tasks. Therefore we need to use
`SLURM_LOCALID` to tell each task which GPU it should use, as the wrapper script below does.

<details><summary>Example scripts for running the training on a SLURM cluster</summary>

- **launch_job.sh**

```bash
#!/bin/bash

#SBATCH --job-name=source_separation

#SBATCH --output=/checkpoint/%u/jobs/%x/%j.out

#SBATCH --error=/checkpoint/%u/jobs/%x/%j.err

#SBATCH --nodes=1

#SBATCH --ntasks-per-node=8

#SBATCH --cpus-per-task=8

#SBATCH --mem-per-cpu=16G

#SBATCH --gpus-per-task=1

#srun env
srun wrapper.sh $@
```

- **wrapper.sh**

```bash
#!/bin/bash
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
save_dir="/checkpoint/${USER}/jobs/${SLURM_JOB_NAME}/${SLURM_JOB_ID}"
dataset_dir="/dataset/wsj0-mix/2speakers/wav8k/min"

if [ "${SLURM_JOB_NUM_NODES}" -gt 1 ]; then
  protocol="file:///checkpoint/${USER}/jobs/source_separation/${SLURM_JOB_ID}/sync"
else
  protocol="env://"
fi

mkdir -p "${save_dir}"

python -u \
  "${this_dir}/train.py" \
  --worker-id "${SLURM_PROCID}" \
  --num-workers "${SLURM_NTASKS}" \
  --device-id "${SLURM_LOCALID}" \
  --sync-protocol "${protocol}" \
  -- \
  --sample-rate 8000 \
  --batch-size $((32 / SLURM_NTASKS)) \
  --dataset-dir "${dataset_dir}" \
  --save-dir "${save_dir}"
```

</details>
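
The job can then be submitted with `sbatch`; for example, assuming both scripts are in the current directory and `wrapper.sh` is executable:

```bash
sbatch ./launch_job.sh
```
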
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from . import (
    model,
    metrics,
    trainer,
)
Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
#!/usr/bin/env python3
"""Train Conv-TasNet"""
import pathlib
import argparse

import torch.utils.data

import conv_tasnet
import dataset_utils
import dist_utils

_LG = dist_utils.getLogger(__name__)


def _parse_args(args):
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    group = parser.add_argument_group("model")
    group.add_argument(
        "--num-speakers", default=2, type=int, help="The number of speakers."
    )
    group = parser.add_argument_group("dataset")
    group.add_argument(
        "--sample-rate",
        required=True,
        type=int,
        help="Sample rate of audio files in the given dataset.",
    )
    group.add_argument(
        "--dataset", default="wsj0mix",
        choices=["wsj0mix"]
    )
    group.add_argument(
        "--dataset-dir",
        required=True,
        type=pathlib.Path,
        help=(
            "Directory where dataset is found. "
            'If the dataset type is "wsj0mix", then this is the directory where '
            '"cv", "tt" and "tr" subdirectories are found.'
        ),
    )
    group = parser.add_argument_group("save")
    group.add_argument(
        "--save-dir",
        required=True,
        type=pathlib.Path,
        help=(
            "Directory where the checkpoints are saved. "
            "Though only worker 0 saves the checkpoint data, all the worker processes must "
            "have access to the directory."
        ),
    )
    group = parser.add_argument_group("dataloader")
    group.add_argument(
        "--batch-size", default=32, type=int,
    )
    group = parser.add_argument_group("training")
    group.add_argument(
        "--epochs", default=100, type=int, help="The number of epochs to train."
    )
    group.add_argument(
        "--learning-rate", default=1e-3, type=float, help="Initial learning rate."
    )
    group.add_argument(
        "--grad-clip", default=5.0, type=float, help="Gradient clip value (l2 norm)."
    )
    group.add_argument(
        "--resume",
        help="Previous checkpoint file from which the training is resumed."
    )
    return parser.parse_args(args)


def train(args):
    args = _parse_args(args)
    _LG.info("%s", args)

    args.save_dir.mkdir(parents=True, exist_ok=True)

    start_epoch = 1
    if args.resume:
        checkpoint = torch.load(args.resume)
        if args.sample_rate != checkpoint['sample_rate']:
            raise ValueError(
                f"The provided sample rate ({args.sample_rate}) does not match "
                f"the sample rate from the checkpoint ({checkpoint['sample_rate']})."
            )
        if args.num_speakers != checkpoint['num_speakers']:
            raise ValueError(
                f"The provided number of speakers ({args.num_speakers}) does not match "
                f"the number of speakers from the checkpoint ({checkpoint['num_speakers']})."
            )
        start_epoch = checkpoint['epoch']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _LG.info("Using: %s", device)

    model = conv_tasnet.model.ConvTasNet(
        num_speakers=args.num_speakers, enc_kernel_size=args.sample_rate * 2 // 1000
    )
    model.to(device)

    # Wrap the model so that gradients are synchronized across the worker processes.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.resume:
        _LG.info("Loading parameters from the checkpoint...")
        model.module.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    else:
        dist_utils.synchronize_params(str(args.save_dir / "tmp.pt"), model, optimizer)

    _LG.info_on_master("Model:\n%s", model)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=3
    )

    train_loss_func = conv_tasnet.metrics.PITLoss(
        loss_func=conv_tasnet.metrics.neg_si_snr
    )

    train_dataset, eval_dataset = dataset_utils.get_dataset(
        args.dataset, args.dataset_dir, args.num_speakers
    )
    collate_fn = dataset_utils.get_collate_fn(args.dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=torch.utils.data.distributed.DistributedSampler(train_dataset),
        collate_fn=collate_fn,
    )
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        sampler=torch.utils.data.distributed.DistributedSampler(eval_dataset),
        collate_fn=collate_fn,
    )

    trainer = conv_tasnet.trainer.Trainer(
        model,
        optimizer,
        train_loader,
        eval_loader,
        train_loss_func,
        args.grad_clip,
        device,
    )

    _LG.info_on_master("Running %s epochs", args.epochs)
    for epoch in range(start_epoch, start_epoch + args.epochs):
        _LG.info_on_master("-" * 70)
        _LG.info_on_master("Epoch: %s", epoch)
        _LG.info_on_master("Learning rate: %s", optimizer.param_groups[0]["lr"])
        _LG.info_on_master("-" * 70)

        trainer.train_one_epoch()
        loss = trainer.evaluate()
        lr_scheduler.step(loss)

        save_path = args.save_dir / f"epoch_{epoch}.pt"
        dist_utils.save_on_master(
            {
                "model": model.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "num_speakers": args.num_speakers,
                "sample_rate": args.sample_rate,
                "epoch": epoch,
            },
            save_path,
        )
