
Commit 203cc7c

Add ConvTasNet main training script
1 parent 725f8b0 commit 203cc7c

9 files changed: +1034 -10 lines changed

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# Source Separation Example

This directory contains reference implementations for source separation. For details on each model, please check out the following:

- [Conv-TasNet](./conv_tasnet/README.md)

## Usage

### Overview

To train a model, you can use [`train.py`](./train.py). The script takes the form of
`train.py [parameters for distributed training] -- [parameters for model/training]`

```
python train.py \
    [--worker-id WORKER_ID] \
    [--device-id DEVICE_ID] \
    [--num-workers NUM_WORKERS] \
    [--sync-protocol SYNC_PROTOCOL] \
    -- \
    <model specific training parameters>

# For details of the distributed training parameters, run:
python train.py --help

# For details of the model/training parameters, run:
python train.py -- --help
```

If you would like to just try out the training script, run it without any parameters
for distributed training: `train.py -- --sample-rate 8000 --batch-size <BATCH_SIZE> --dataset-dir <DATASET_DIR> --save-dir <SAVE_DIR>`

This script runs training in the Distributed Data Parallel (DDP) framework and has two major
operation modes. The mode depends on whether the `--worker-id` argument is given.

1. (`--worker-id` is not given) Launches training worker subprocesses that perform the actual training.
2. (`--worker-id` is given) Performs training as one worker of a distributed training job.

When launching the script without any distributed training parameters (operation mode 1),
the script checks the number of GPUs available on the local system and spawns the same
number of training subprocesses (each running in operation mode 2). You can reduce the number of GPUs used with
`--num-workers`. If no GPU is available, only one subprocess is launched, and providing
`--num-workers` larger than 1 results in an error.

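For example, the following sketch (with `<DATASET_DIR>` and `<SAVE_DIR>` as placeholders) would spawn two training subprocesses on a machine with two or more GPUs:

```bash
# Sketch: let train.py spawn 2 local worker subprocesses (assumes >= 2 GPUs).
# --batch-size is per worker, so 2 workers x 8 = 16 total,
# matching the default Conv-TasNet recipe.
python train.py --num-workers 2 -- \
    --sample-rate 8000 \
    --batch-size 8 \
    --dataset-dir <DATASET_DIR> \
    --save-dir <SAVE_DIR>
```
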
When launching the script as a worker process of a distributed training job, you need to configure
how the workers coordinate.

- `--num-workers` is the total number of training processes being launched.
- `--worker-id` is the process rank (must be unique across all the processes).
- `--device-id` is the GPU device ID (should be unique within a node).
- `--sync-protocol` is how the worker processes communicate and synchronize.
  If the training is carried out on a single node, then the default `"env://"` should do.
  If the training processes span multiple nodes, then you need to provide a protocol that
  can communicate over the network. If you know where the master node is located, you can use
  `"env://"` in combination with the `MASTER_ADDR` and `MASTER_PORT` environment variables. If you do
  not know where the master node is located beforehand, you can use the `"file://..."` protocol to
  point to a file that all the worker processes can access. For other
  protocols, please refer to the official documentation.

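For illustration, the following sketch launches two worker processes manually on a single node with two GPUs (placeholders in angle brackets; depending on how `train.py` initializes `"env://"`, you may also need to export `MASTER_ADDR` and `MASTER_PORT`):

```bash
# Sketch: manually launch 2 workers (ranks 0 and 1) on one node.
# export MASTER_ADDR=localhost   # uncomment if required by your environment
# export MASTER_PORT=29500
python train.py --worker-id 0 --device-id 0 --num-workers 2 --sync-protocol "env://" -- \
    --sample-rate 8000 --batch-size 8 --dataset-dir <DATASET_DIR> --save-dir <SAVE_DIR> &
python train.py --worker-id 1 --device-id 1 --num-workers 2 --sync-protocol "env://" -- \
    --sample-rate 8000 --batch-size 8 --dataset-dir <DATASET_DIR> --save-dir <SAVE_DIR>
wait
```
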
### Distributed Training Notes

<details><summary>Quick overview of DDP (distributed data parallel)</summary>

DDP is a single-program, multiple-data training paradigm.
With DDP, the model is replicated in every process,
and every replica is fed a different set of input data samples.

- **Process**: A worker process (as in a Linux process). There are `P` processes per node.
- **Node**: A machine. There are `N` machines, each of which holds `P` processes.
- **World**: The network of nodes, composed of `N` nodes and `N * P` processes.
- **Rank**: Global process ID (unique across nodes), in `[0, N * P)`.
- **Local Rank**: Local process ID (unique only within a node), in `[0, P)`.

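The relationship between these IDs can be summarized as follows (an illustrative snippet, not part of `train.py`):

```bash
# Illustration: with N nodes and P processes per node,
#   world_size = N * P
#   rank       = node_index * P + local_rank
# e.g. Node 1, Local Rank 2, P=8  ->  Rank 10
N=3; P=8; node_index=1; local_rank=2
echo "world_size=$(( N * P )) rank=$(( node_index * P + local_rank ))"
```
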
```
Node 0 Node 1 Node N-1
┌────────────────────────┐┌─────────────────────────┐ ┌───────────────────────────┐
│╔══════════╗ ┌─────────┐││┌───────────┐ ┌─────────┐│ │┌─────────────┐ ┌─────────┐│
│║ Process ╟─┤ GPU: 0 ││││ Process ├─┤ GPU: 0 ││ ││ Process ├─┤ GPU: 0 ││
│║ Rank: 0 ║ └─────────┘│││ Rank:P │ └─────────┘│ ││ Rank:NP-P │ └─────────┘│
│╚══════════╝ ││└───────────┘ │ │└─────────────┘ │
│┌──────────┐ ┌─────────┐││┌───────────┐ ┌─────────┐│ │┌─────────────┐ ┌─────────┐│
││ Process ├─┤ GPU: 1 ││││ Process ├─┤ GPU: 1 ││ ││ Process ├─┤ GPU: 1 ││
││ Rank: 1 │ └─────────┘│││ Rank:P+1 │ └─────────┘│ ││ Rank:NP-P+1 │ └─────────┘│
│└──────────┘ ││└───────────┘ │ ... │└─────────────┘ │
│ ││ │ │ │
│ ... ││ ... │ │ ... │
│ ││ │ │ │
│┌──────────┐ ┌─────────┐││┌───────────┐ ┌─────────┐│ │┌─────────────┐ ┌─────────┐│
││ Process ├─┤ GPU:P-1 ││││ Process ├─┤ GPU:P-1 ││ ││ Process ├─┤ GPU:P-1 ││
││ Rank:P-1 │ └─────────┘│││ Rank:2P-1 │ └─────────┘│ ││ Rank:NP-1 │ └─────────┘│
│└──────────┘ ││└───────────┘ │ │└─────────────┘ │
└────────────────────────┘└─────────────────────────┘ └───────────────────────────┘
```

</details>

### SLURM

When launched as a SLURM job, the following environment variables correspond to the script's distributed training parameters:

- **SLURM_PROCID**: `--worker-id` (Rank)
- **SLURM_NTASKS** (or the legacy **SLURM_NPROCS**): the total number of processes (`--num-workers` == world size)
- **SLURM_LOCALID**: `--device-id` (Local Rank, mapped to the GPU index*)

* Even when the GPU resource is allocated with `--gpus-per-task=1`, if there are multiple
tasks allocated on the same node (and thus multiple GPUs of the node are allocated to the job),
each task can see all the GPUs allocated to the job. Therefore we need to use
`SLURM_LOCALID` to tell each task which GPU it should use.

<details><summary>Example scripts for running the training on a SLURM cluster</summary>

- **launch_job.sh**

```bash
#!/bin/bash

#SBATCH --job-name=source_separation

#SBATCH --output=/checkpoint/%u/jobs/%x/%j.out

#SBATCH --error=/checkpoint/%u/jobs/%x/%j.err

#SBATCH --nodes=1

#SBATCH --ntasks-per-node=8

#SBATCH --cpus-per-task=8

#SBATCH --mem-per-cpu=16G

#SBATCH --gpus-per-task=1

#srun env
srun wrapper.sh "$@"
```

- **wrapper.sh**

```bash
#!/bin/bash
num_speakers=2
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
save_dir="/checkpoint/${USER}/jobs/${SLURM_JOB_NAME}/${SLURM_JOB_ID}"
dataset_dir="/dataset/wsj0-mix/${num_speakers}speakers/wav8k/min"

if [ "${SLURM_JOB_NUM_NODES}" -gt 1 ]; then
  protocol="file:///checkpoint/${USER}/jobs/source_separation/${SLURM_JOB_ID}/sync"
else
  protocol="env://"
fi

mkdir -p "${save_dir}"

python -u \
  "${this_dir}/train.py" \
  --worker-id "${SLURM_PROCID}" \
  --num-workers "${SLURM_NTASKS}" \
  --device-id "${SLURM_LOCALID}" \
  --sync-protocol "${protocol}" \
  -- \
  --num-speakers "${num_speakers}" \
  --sample-rate 8000 \
  --dataset-dir "${dataset_dir}" \
  --save-dir "${save_dir}" \
  --batch-size $((16 / SLURM_NTASKS))
```

</details>
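
Assuming the two scripts above are saved next to `train.py` (the paths and resource values are illustrative and cluster-specific), the job could then be submitted with something like:

```bash
# Sketch: submit the batch job defined by launch_job.sh above,
# then follow its log (the path follows the #SBATCH --output directive).
sbatch launch_job.sh
tail -f /checkpoint/${USER}/jobs/source_separation/<JOB_ID>.out
```
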
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Conv-TasNet

This is a reference implementation of Conv-TasNet.

> Luo, Yi, and Nima Mesgarani. "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation." IEEE/ACM Transactions on Audio, Speech, and Language Processing 27.8 (2019): 1256-1266. Crossref. Web.

This implementation is based on [arXiv:1809.07454v3](https://arxiv.org/abs/1809.07454v3) and [the reference implementation](https://github.com/naplab/Conv-TasNet) provided by the authors.

For usage, please check out the [source separation README](../README.md).

## (Default) Training Configurations

The default training/model configuration follows the best non-causal configuration reported in the paper. (The causal configuration is not implemented.) An example invocation with these defaults is sketched after the list below.

- Sample rate: 8000 Hz
- Batch size: total of 16 over the distributed training workers
- Epochs: 100
- Initial learning rate: 1e-3
- Gradient clipping: maximum L2 norm of 5.0
- Optimizer: Adam
- Learning rate scheduling: halved after 3 epochs of no improvement in validation accuracy
- Objective function: SI-SNRi (The paper uses SI-SNR as the objective function, which differs only by a constant value, so it should practically yield the same result.)
- Reported metrics: SI-SNRi, SDRi
- Sample audio length: 4 seconds (first 4 seconds)
- Encoder/Decoder feature dimension (N): 512
- Encoder/Decoder convolution kernel size (L): 16
- TCN bottleneck/output feature dimension (B): 128
- TCN hidden feature dimension (H): 512
- TCN skip connection feature dimension (Sc): 128
- TCN convolution kernel size (P): 3
- Number of TCN convolution block layers (X): 8
- Number of TCN convolution blocks (R): 3

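For reference, a single-node run with these defaults might look like the following sketch (`<DATASET_DIR>` and `<SAVE_DIR>` are placeholders; the per-worker batch size assumes 8 local GPUs so that the total batch size is 16):

```bash
# Sketch: train with the default configuration above on one 8-GPU machine.
# train.py spawns one worker per GPU; per-worker batch size 2 x 8 workers = 16 total.
python ../train.py -- \
    --num-speakers 2 \
    --sample-rate 8000 \
    --batch-size 2 \
    --dataset-dir <DATASET_DIR> \
    --save-dir <SAVE_DIR>
```
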
## Evaluation

The following are the evaluation results of training the model on the WSJ0-2mix and WSJ0-3mix datasets.

### wsj0-mix 2speakers

| | SI-SNRi (dB) | SDRi (dB) | Epoch |
|:------------------:|-------------:|----------:|------:|
| Reference | 15.3 | 15.6 | |
| Validation dataset | 13.3 | 13.3 | 86 |
| Evaluation dataset | 11.3 | 11.3 | 86 |

### wsj0-mix 3speakers

| | SI-SNRi (dB) | SDRi (dB) | Epoch |
|:------------------:|-------------:|----------:|------:|
| Reference | 12.7 | 13.1 | |
| Validation dataset | 11.5 | 11.5 | 87 |
| Evaluation dataset | 8.7 | 8.6 | 87 |
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from . import (
    train,
    trainer,
)
