Skip to content

Commit 1bee4df

Browse files
authored
support multi-node pruning in Tensorflow (#218)
1 parent 78d26f8 commit 1bee4df

File tree

10 files changed

+375
-156
lines changed

10 files changed

+375
-156
lines changed

examples/optimization/tensorflow/huggingface/text-classification/pruning/README.md

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,46 @@ bash run_tuning.sh --topology=topology
2626
```
2727
bash run_benchmark.sh --topology=topology --mode=benchmark
2828
```
29-
topology is "distilbert_base_sst2"
29+
topology is "distilbert_base_sst2"
30+
31+
32+
### Multi-node usage
33+
34+
We also support distributed data-parallel training across multiple nodes for pruning.
35+
36+
The default strategy is TensorFlow's `MultiWorkerMirroredStrategy`. With `task_type` set to "worker", you need to pass the following extra parameters to the script:
37+
38+
* `worker`: a comma-separated string of your worker addresses (for example `host1:port1,host2:port2`), with no spaces between the entries
39+
40+
* `task_index`: the rank of the current node; set it to 0 on the chief (leader) node and to 1, 2, 3, ... on the follower nodes
41+
42+
### Multi-node example
43+
44+
* On leader node
45+
46+
```
47+
bash run_tuning.sh --topology=distilbert_base_sst2 --worker="localhost:12345,localhost:23456" --task_index=0
48+
```
49+
50+
which is equivalent to
51+
52+
```
53+
python run_glue.py \
54+
--model_name_or_path distilbert-base-uncased-finetuned-sst-2-english \
55+
--task_name sst2 \
56+
--prune \
57+
--do_train \
58+
--do_eval \
59+
--output_dir ./tmp/sst2_output \
60+
--overwrite_output_dir \
61+
--worker "localhost:12345,localhost:23456" \
62+
--task_index 0
63+
```
64+
65+
* On follower node
66+
67+
```
68+
bash run_tuning.sh --topology=distilbert_base_sst2 --worker="localhost:12345,localhost:23456" --task_index=1
69+
```
70+
71+
Please replace the example worker address list with the addresses of your own nodes.

examples/optimization/tensorflow/huggingface/text-classification/pruning/run_glue.py

Lines changed: 160 additions & 122 deletions
Large diffs are not rendered by default.

examples/optimization/tensorflow/huggingface/text-classification/pruning/run_tuning.sh

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,15 @@ function init_params {
2525
--input_model=*)
2626
input_model=$(echo $var |cut -f2 -d=)
2727
;;
28-
--output_model=*)
29-
tuned_checkpoint=$(echo $var |cut -f2 -d=)
30-
;;
28+
--output_model=*)
29+
tuned_checkpoint=$(echo $var |cut -f2 -d=)
30+
;;
31+
--worker=*)
32+
worker=$(echo $var |cut -f2 -d=)
33+
;;
34+
--task_index=*)
35+
task_index=$(echo $var |cut -f2 -d=)
36+
;;
3137
*)
3238
echo "Error: No such parameter: ${var}"
3339
exit 1
@@ -40,24 +46,43 @@ function init_params {
4046
# run_tuning
4147
function run_tuning {
4248
extra_cmd=''
43-
batch_size=16
49+
batch_size=64
4450
if [ "${topology}" = "distilbert_base_sst2" ]; then
4551
TASK_NAME='sst2'
4652
model_name_or_path=distilbert-base-uncased-finetuned-sst-2-english
4753
fi
4854

49-
python -u ./run_glue.py \
50-
--model_name_or_path ${model_name_or_path} \
51-
--task_name ${TASK_NAME} \
52-
--target_sparsity_ratio 0.1 \
53-
--prune \
54-
--do_eval \
55-
--do_train \
56-
--per_device_eval_batch_size ${batch_size} \
57-
--output_dir ${tuned_checkpoint} \
58-
--overwrite_output_dir \
59-
--overwrite_cache \
60-
${extra_cmd}
55+
if [ "${worker}" = "" ]
56+
then
57+
python -u ./run_glue.py \
58+
--model_name_or_path ${model_name_or_path} \
59+
--task_name ${TASK_NAME} \
60+
--target_sparsity_ratio 0.1 \
61+
--prune \
62+
--do_eval \
63+
--do_train \
64+
--per_device_train_batch_size ${batch_size} \
65+
--per_device_eval_batch_size ${batch_size} \
66+
--output_dir ${tuned_checkpoint} \
67+
--overwrite_output_dir \
68+
--overwrite_cache
69+
else
70+
python -u ./run_glue.py \
71+
--model_name_or_path ${model_name_or_path} \
72+
--task_name ${TASK_NAME} \
73+
--target_sparsity_ratio 0.1 \
74+
--prune \
75+
--do_eval \
76+
--do_train \
77+
--per_device_train_batch_size ${batch_size} \
78+
--per_device_eval_batch_size ${batch_size} \
79+
--output_dir ${tuned_checkpoint} \
80+
--overwrite_output_dir \
81+
--overwrite_cache \
82+
--worker "${worker}" \
83+
--task_index ${task_index} \
84+
${extra_cmd}
85+
fi
6186
}
6287

6388
main "$@"

nlp_toolkit/optimization/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def framework(self):
363363

364364
@framework.setter
365365
def framework(self, framework):
366-
assert framework.lower() in ["pytorch", "pytorch_fx"], \
366+
assert framework.lower() in ["pytorch", "pytorch_fx", "tensorflow"], \
367367
"framework: {} is not support!".format(framework)
368368
self.inc_config.usr_cfg.model.framework = framework.lower()
369369

nlp_toolkit/optimization/optimizer_tf.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from transformers import PreTrainedModel
3535
from transformers.training_args_tf import TFTrainingArguments
3636
from typing import Callable, Optional, List
37-
from .utils.utility_tf import TFDataloader, TMPPATH
37+
from .utils.utility_tf import TFDataloader, TMPPATH, get_filepath
3838

3939
tf = LazyImport("tensorflow")
4040
logger = logging.getLogger(__name__)
@@ -50,6 +50,8 @@ def __init__(
5050
compute_metrics: Optional[Callable] = None,
5151
criterion = None,
5252
optimizer = None,
53+
task_type = None,
54+
task_id = None,
5355
):
5456
"""
5557
Args:
@@ -78,11 +80,14 @@ def __init__(
7880
self.compute_metrics = compute_metrics
7981
self.args = args
8082
self.optimizer = optimizer
83+
self.task_type = task_type
84+
self.task_id = task_id
8185
self.criterion = criterion if criterion is not None else \
8286
self.model.loss if hasattr(self.model, "loss") else None
83-
self.model.save_pretrained(TMPPATH, saved_model=True)
87+
self.model.save_pretrained(get_filepath(TMPPATH, self.task_type, self.task_id), saved_model=True)
8488
_, self.input_names, self.output_names = saved_model_session(
85-
os.path.join(TMPPATH,"saved_model/1"), input_tensor_names=[], output_tensor_names=[])
89+
os.path.join(get_filepath(TMPPATH, self.task_type, self.task_id), "saved_model/1"), input_tensor_names=[],
90+
output_tensor_names=[])
8691
self.eval_distributed = False
8792

8893
@property
@@ -298,7 +303,8 @@ def init_quantizer(
298303
self.metrics = self.quant_config.metrics
299304

300305
quantizer = Quantization(self.quant_config.inc_config)
301-
quantizer.model = common.Model(os.path.join(TMPPATH,"saved_model/1"), modelType="saved_model")
306+
quantizer.model = common.Model(
307+
os.path.join(get_filepath(TMPPATH, self.task_type, self.task_id),"saved_model/1"), modelType="saved_model")
302308

303309
self.quantizer = quantizer
304310
return quantizer
@@ -325,8 +331,7 @@ def _inc_quantize(
325331
batch_size=self.args.per_device_eval_batch_size)
326332
else: # pragma: no cover
327333
assert False, "Please pass calibration dataset to TFNoTrainerOptimizer.calib_dataloader"
328-
elif self.quant_config.approach == QuantizationMode.QUANTIZATIONAWARETRAINING.value:
329-
# pragma: no cover
334+
elif self.quant_config.approach == QuantizationMode.QUANTIZATIONAWARETRAINING.value: # pragma: no cover
330335
assert False, \
331336
"Unsupport quantization aware training for tensorflow framework"
332337

@@ -369,7 +374,7 @@ def init_pruner(
369374
"please pass a instance of PruningConfig to trainer.prune!"
370375

371376
pruner = Pruning(self.pruning_config.inc_config)
372-
pruner.model = os.path.join(TMPPATH,"saved_model/1")
377+
pruner.model = os.path.join(get_filepath(TMPPATH, self.task_type, self.task_id),"saved_model/1")
373378
pruner.model.model_type = "saved_model"
374379

375380
self.pruner = pruner
@@ -416,7 +421,11 @@ def prune(
416421

417422
opt_model = self.pruner.fit()
418423

419-
return self.model
424+
opt_model.save(self.args.output_dir)
425+
logger.info(
426+
"pruned model have saved to {}".format(self.args.output_dir)
427+
)
428+
return opt_model.model
420429

421430
def init_distiller(
422431
self,
@@ -506,4 +515,4 @@ def on_train_batch_end(self, batch, logs=None):
506515
callbacks=[PruningCb()])
507516

508517
self.pruner.model._sess = None
509-
input_model.save_pretrained(TMPPATH, saved_model=True)
518+
input_model.save_pretrained(get_filepath(TMPPATH, self.task_type, self.task_id), saved_model=True)

nlp_toolkit/optimization/utils/metrics.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2022 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
118
class Metric(object):
219
def __init__(self, name: str, greater_is_better: bool = True, is_relative: bool = True,
320
criterion: float = 0.01, weight_ratio: float = None):

nlp_toolkit/optimization/utils/objectives.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2022 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
118
class Objective(object):
219
def __init__(self, name: str, greater_is_better: bool = True, weight_ratio: float = None):
320
self.name = name

nlp_toolkit/optimization/utils/utility.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2022 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
118
import importlib
219
import os
320
from neural_compressor.utils.utility import LazyImport

nlp_toolkit/optimization/utils/utility_tf.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
1-
from collections import OrderedDict, UserDict
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2022 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
217

18+
from collections import OrderedDict, UserDict
19+
import json
20+
import os
321

422
TMPPATH = "tmp"
523
class TFDataloader(object):
@@ -31,4 +49,28 @@ def __iter__(self):
3149
labels = [label.numpy() for label in labels]
3250
else:
3351
labels = labels.numpy()
34-
yield inputs, labels
52+
yield inputs, labels
53+
54+
55+
def distributed_init(worker_addresses, type='worker', index=0):
56+
tf_config = {
57+
'cluster': {
58+
'worker': worker_addresses
59+
},
60+
'task': {'type': type, 'index': index}
61+
}
62+
os.environ['TF_CONFIG'] = json.dumps(tf_config)
63+
64+
def _is_chief(task_type, task_id):
65+
# here only consider the case in which TF_CONFIG task_type is set as worker
66+
# and task_id=0 represents the chief
67+
return (task_type == 'worker' and task_id == 0)
68+
69+
# get model folder path for the distributed environment
70+
def get_filepath(base_dirpath, task_type, task_id):
71+
if task_type is None: # single node
72+
return base_dirpath
73+
elif _is_chief(task_type, task_id):
74+
return os.path.join(base_dirpath, 'chief')
75+
else:
76+
return os.path.join(base_dirpath, 'worker_' + str(task_id))

tests/test_tf_pruning.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from nlp_toolkit.optimization.utils.utility_tf import get_filepath
12
import numpy as np
23
import os
34
import shutil
@@ -74,6 +75,19 @@ def tearDownClass(self):
7475
shutil.rmtree('./quantized_model', ignore_errors=True)
7576

7677
def test_tf_model_quant(self):
78+
# check whether it is possible to set distributed environment
79+
# only for coverage currently
80+
from nlp_toolkit.optimization.utils.utility_tf import distributed_init
81+
distributed_init(["localhost:12345","localhost:23456"], "worker", 0)
82+
self.assertTrue(os.environ['TF_CONFIG'] != None)
83+
del os.environ['TF_CONFIG']
84+
# check whether filepath can be set correctly if using distributed environment
85+
# only for coverage currently
86+
from nlp_toolkit.optimization.utils.utility_tf import get_filepath
87+
self.assertTrue(type(get_filepath("dummy", "worker", 0)) == str)
88+
self.assertTrue(type(get_filepath("dummy", "worker", 1)) == str)
89+
self.assertTrue(get_filepath("dummy", "worker", 0) != get_filepath("dummy", "worker", 1))
90+
7791
metric = load_metric("glue", "sst2")
7892
def compute_metrics(preds, label_ids):
7993
preds = preds["logits"]
@@ -99,12 +113,10 @@ def compute_metrics(preds, label_ids):
99113
epochs=int(1), pruner_config=pruner_config, metrics=tune_metric
100114
)
101115
p_model = self.optimizer.prune(pruning_config=pruning_conf)
102-
p_model.save_pretrained(self.args.output_dir, saved_model=True)
103-
loaded_model = tf.saved_model.load(os.path.join(self.args.output_dir, "saved_model/1"))
104-
116+
loaded_model = tf.saved_model.load(self.args.output_dir)
105117
p_model = self.optimizer.prune(pruning_config=pruning_conf,
106-
train_dataset=self.dummy_dataset,
107-
eval_dataset=self.dummy_dataset,)
118+
train_dataset=self.dummy_dataset,
119+
eval_dataset=self.dummy_dataset,)
108120

109121
def eval_func(model):
110122
return 1

0 commit comments

Comments
 (0)