# Source Separation Example

This directory contains reference implementations for source separation. For the details of each model, please check out the following:

- [Conv-TasNet](./conv_tasnet/README.md)

## Usage

### Overview

To train a model, you can use [`train.py`](./train.py). This script takes the form of
`train.py [parameters for distributed training] -- [parameters for model/training]`

```
python train.py \
    [--worker-id WORKER_ID] \
    [--device-id DEVICE_ID] \
    [--num-workers NUM_WORKERS] \
    [--sync-protocol SYNC_PROTOCOL] \
    -- \
    <model specific training parameters>

# For details of the distributed training parameters, use:
python train.py --help

# For details of the model/training parameters, use:
python train.py -- --help
```

If you would like to just try out the training script, run it without any parameters
for distributed training: `train.py -- --sample-rate 8000 --batch-size <BATCH_SIZE> --dataset-dir <DATASET_DIR> --save-dir <SAVE_DIR>`
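
For example, a quick trial run could look like the following. This is a minimal sketch; the batch size and the dataset/output paths are placeholders that you need to adapt to your environment.

```bash
python train.py -- \
    --sample-rate 8000 \
    --batch-size 16 \
    --dataset-dir /path/to/wsj0-mix/2speakers/wav8k/min \
    --save-dir /path/to/save_dir
```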

This script runs training with the Distributed Data Parallel (DDP) framework and has two major
operation modes, depending on whether the `--worker-id` argument is given.

1. (`--worker-id` is not given) Launches training worker subprocesses that perform the actual training.
2. (`--worker-id` is given) Performs training as part of a distributed training job.

When launching the script without any distributed training parameters (operation mode 1),
the script checks the number of GPUs available on the local system and spawns the same
number of training subprocesses (running in operation mode 2). You can reduce the number of GPUs used with
`--num-workers`. If there is no GPU available, only one subprocess is launched, and providing
`--num-workers` larger than 1 results in an error.
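
For instance, to use only two of the available GPUs in operation mode 1, you could run something like the following (a sketch; the parameters after `--` are the same placeholders as above):

```bash
python train.py \
    --num-workers 2 \
    -- \
    --sample-rate 8000 \
    --batch-size <BATCH_SIZE> \
    --dataset-dir <DATASET_DIR> \
    --save-dir <SAVE_DIR>
```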

When launching the script as a worker process of a distributed training job, you need to configure
how the workers coordinate with each other (see the example after this list).

- `--num-workers` is the total number of training processes being launched.
- `--worker-id` is the process rank (must be unique across all the processes).
- `--device-id` is the GPU device ID (should be unique within a node).
- `--sync-protocol` specifies how the worker processes communicate and synchronize.
  If the training is carried out on a single node, then the default `"env://"` should do.
  If the training processes span multiple nodes, then you need to provide a protocol that
  can communicate over the network. If you know where the master node is located, you can use
  `"env://"` in combination with the `MASTER_ADDR` and `MASTER_PORT` environment variables. If you do
  not know where the master node is located beforehand, you can use the `"file://..."` protocol to
  point to a file that all the worker processes can access. For other
  protocols, please refer to the official documentation.
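
As a concrete illustration, the commands below launch a two-worker job on a single node, one process per GPU. This is a minimal sketch assuming a machine with two GPUs; the parameters after `--` are placeholders as above, and the `MASTER_ADDR`/`MASTER_PORT` values are arbitrary examples (they may not be needed on a single node, but are shown for completeness). For a multi-node run, you would instead point `--sync-protocol` at a file on shared storage (e.g. `"file:///path/to/shared/sync"`).

```bash
# Rendezvous information used by the "env://" protocol (illustrative values).
export MASTER_ADDR=localhost
export MASTER_PORT=29500

# Worker 0, using GPU 0 (launched in the background so both workers run concurrently).
python train.py \
    --worker-id 0 --device-id 0 --num-workers 2 --sync-protocol "env://" \
    -- --sample-rate 8000 --batch-size <BATCH_SIZE> --dataset-dir <DATASET_DIR> --save-dir <SAVE_DIR> &

# Worker 1, using GPU 1.
python train.py \
    --worker-id 1 --device-id 1 --num-workers 2 --sync-protocol "env://" \
    -- --sample-rate 8000 --batch-size <BATCH_SIZE> --dataset-dir <DATASET_DIR> --save-dir <SAVE_DIR> &

wait
```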

### Distributed Training Notes

<details><summary>Quick overview of DDP (distributed data parallel)</summary>

DDP is a single-program, multiple-data training paradigm.
With DDP, the model is replicated on every process,
and every model replica is fed a different set of input data samples.

- **Process**: A worker process (as in a Linux process). There are `P` processes per node.
- **Node**: A machine. There are `N` machines, each of which holds `P` processes.
- **World**: The network of nodes, composed of `N` nodes and `N * P` processes.
- **Rank**: Global process ID (unique across nodes), in `[0, N * P)`.
- **Local Rank**: Local process ID (unique only within a node), in `[0, P)`.
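
For example, with `N = 2` nodes and `P = 4` processes per node, the world size is `N * P = 8`, and the process with local rank `1` on node `1` has global rank `1 * P + 1 = 5`, as illustrated in the diagram below.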

```
          Node 0                    Node 1                          Node N-1
┌────────────────────────┐┌─────────────────────────┐     ┌───────────────────────────┐
│╔══════════╗ ┌─────────┐││┌───────────┐ ┌─────────┐│     │┌─────────────┐ ┌─────────┐│
│║ Process  ╟─┤ GPU: 0  ││││ Process   ├─┤ GPU: 0  ││     ││ Process     ├─┤ GPU: 0  ││
│║ Rank: 0  ║ └─────────┘│││ Rank:P    │ └─────────┘│     ││ Rank:NP-P   │ └─────────┘│
│╚══════════╝            ││└───────────┘            │     │└─────────────┘            │
│┌──────────┐ ┌─────────┐││┌───────────┐ ┌─────────┐│     │┌─────────────┐ ┌─────────┐│
││ Process  ├─┤ GPU: 1  ││││ Process   ├─┤ GPU: 1  ││     ││ Process     ├─┤ GPU: 1  ││
││ Rank: 1  │ └─────────┘│││ Rank:P+1  │ └─────────┘│     ││ Rank:NP-P+1 │ └─────────┘│
│└──────────┘            ││└───────────┘            │ ... │└─────────────┘            │
│                        ││                         │     │                           │
│           ...          ││           ...           │     │            ...            │
│                        ││                         │     │                           │
│┌──────────┐ ┌─────────┐││┌───────────┐ ┌─────────┐│     │┌─────────────┐ ┌─────────┐│
││ Process  ├─┤ GPU:P-1 ││││ Process   ├─┤ GPU:P-1 ││     ││ Process     ├─┤ GPU:P-1 ││
││ Rank:P-1 │ └─────────┘│││ Rank:2P-1 │ └─────────┘│     ││ Rank:NP-1   │ └─────────┘│
│└──────────┘            ││└───────────┘            │     │└─────────────┘            │
└────────────────────────┘└─────────────────────────┘     └───────────────────────────┘
```

</details>

### SLURM

When launched as a SLURM job, the following environment variables correspond to the training script parameters:

- **SLURM_PROCID**: `--worker-id` (Rank)
- **SLURM_NTASKS** (or the legacy **SLURM_NPROCS**): the total number of processes (`--num-workers`, i.e. the world size)
- **SLURM_LOCALID**: `--device-id` (Local Rank, to be mapped to a GPU index\*)

\* Even when GPU resources are allocated with `--gpus-per-task=1`, if multiple
tasks are allocated on the same node (and thus multiple GPUs of that node are allocated to the job),
each task can see all the GPUs allocated for the tasks. Therefore we need to use
`SLURM_LOCALID` to tell each task which GPU it should use.

<details><summary>Example scripts for running the training on a SLURM cluster</summary>

- **launch_job.sh**

```bash
#!/bin/bash

#SBATCH --job-name=source_separation
#SBATCH --output=/checkpoint/%u/jobs/%x/%j.out
#SBATCH --error=/checkpoint/%u/jobs/%x/%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=8
#SBATCH --mem-per-cpu=16G
#SBATCH --gpus-per-task=1

#srun env
srun wrapper.sh "$@"
```

- **wrapper.sh**

```bash
#!/bin/bash
num_speakers=2
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
save_dir="/checkpoint/${USER}/jobs/${SLURM_JOB_NAME}/${SLURM_JOB_ID}"
dataset_dir="/dataset/wsj0-mix/${num_speakers}speakers/wav8k/min"

# Use a shared file for synchronization when the job spans multiple nodes;
# otherwise the default "env://" protocol is sufficient.
if [ "${SLURM_JOB_NUM_NODES}" -gt 1 ]; then
  protocol="file:///checkpoint/${USER}/jobs/source_separation/${SLURM_JOB_ID}/sync"
else
  protocol="env://"
fi

mkdir -p "${save_dir}"

# Split the global batch size of 16 evenly across the worker processes.
python -u \
  "${this_dir}/train.py" \
  --worker-id "${SLURM_PROCID}" \
  --num-workers "${SLURM_NTASKS}" \
  --device-id "${SLURM_LOCALID}" \
  --sync-protocol "${protocol}" \
  -- \
  --num-speakers "${num_speakers}" \
  --sample-rate 8000 \
  --dataset-dir "${dataset_dir}" \
  --save-dir "${save_dir}" \
  --batch-size $((16 / SLURM_NTASKS))
```
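
With both scripts placed in the directory from which the job is submitted (and `wrapper.sh` made executable), the job can then be submitted with `sbatch`; exact paths depend on your cluster setup:

```bash
sbatch ./launch_job.sh
```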

</details>