Skip to content

Commit f231b46

Browse files
authored
Fix tf distillation / pruning tuning & benchmark (#250)
1 parent ea5c0e9 commit f231b46

File tree

11 files changed

+186
-50
lines changed

11 files changed

+186
-50
lines changed

examples/optimization/tensorflow/huggingface/text-classification/distillation/README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ bash run_tuning.sh --topology=topology
2828
```
2929

3030
```
31-
bash run_benchmark.sh --topology=topology --mode=benchmark
31+
bash run_benchmark.sh --topology=topology --mode=benchmark --use_distillation_model=true
3232
```
3333
topology is "distilbert-base-uncased"
3434

examples/optimization/tensorflow/huggingface/text-classification/distillation/run_benchmark.sh

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ function init_params {
1414
batch_size=16
1515
tuned_checkpoint=saved_results
1616
topology="distilbert-base-uncased"
17+
mode="benchmark"
1718
for var in "$@"
1819
do
1920
case $var in
@@ -35,8 +36,8 @@ function init_params {
3536
--iters=*)
3637
iters=$(echo ${var} |cut -f2 -d=)
3738
;;
38-
--int8=*)
39-
int8=$(echo ${var} |cut -f2 -d=)
39+
--use_distillation_model=*)
40+
use_distillation_model=$(echo ${var} |cut -f2 -d=)
4041
;;
4142
--config=*)
4243
tuned_checkpoint=$(echo $var |cut -f2 -d=)
@@ -67,7 +68,11 @@ function run_benchmark {
6768

6869
if [ "${topology}" = "distilbert-base-uncased" ]; then
6970
TASK_NAME='sst2'
70-
model_name_or_path=${tuned_checkpoint}
71+
model_name_or_path=distilbert-base-uncased
72+
fi
73+
74+
if [[ ${use_distillation_model} == "true" ]]; then
75+
extra_cmd=$extra_cmd" --use_distillation_model"
7176
fi
7277

7378
python -u ./run_glue.py \

examples/optimization/tensorflow/huggingface/text-classification/distillation/run_glue.py

Lines changed: 65 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -284,7 +284,10 @@ class OptimizationArguments:
284284
)
285285
add_origin_loss: bool = field(
286286
default=False, metadata={"help": "Whether add the origin loss or not"})
287-
benchmark: bool = field(default=False, metadata={"help": "run benchmark."})
287+
benchmark: bool = field(default=False, metadata={"help": "Run benchmark."})
288+
use_distillation_model: bool = field(
289+
default=False,
290+
metadata={"help":"Whether to use pretrained distillation model."})
288291
accuracy_only: bool = field(
289292
default=False,
290293
metadata={
@@ -618,7 +621,7 @@ def compute_metrics(preds, label_ids):
618621
drop_remainder=drop_remainder,
619622
# `label_cols` is needed for user-defined losses, such as in this example
620623
# datasets v2.3.x need "labels", not "label"
621-
label_cols=["labels", "label"]
624+
label_cols=["labels"]
622625
if "label" in dataset.column_names else None,
623626
)
624627
tf_data[key] = data
@@ -682,10 +685,7 @@ def compute_metrics(preds, label_ids):
682685
distillation_config=distillation_conf,
683686
teacher_model=teacher_model,
684687
)
685-
distilled_model.save_pretrained(training_args.output_dir,
686-
saved_model=True)
687-
distilled_model.config.save_pretrained(training_args.output_dir)
688-
tokenizer.save_pretrained(training_args.output_dir)
688+
689689
return
690690

691691
# region Training and validation
@@ -731,17 +731,49 @@ def compute_metrics(preds, label_ids):
731731
raw_datasets = [datasets["validation"]]
732732

733733
total_time = 0
734+
num_examples = 0
735+
if optim_args.use_distillation_model:
736+
model = tf.saved_model.load(training_args.output_dir)
734737
for raw_dataset, tf_dataset, task in zip(raw_datasets, tf_datasets,
735738
tasks):
736-
num_examples = sum(1 for _ in tf_dataset.unbatch())
737-
start = time.time()
738-
eval_predictions = model.predict(tf_dataset)
739-
total_time += time.time() - start
740-
eval_metrics = compute_metrics(eval_predictions,
741-
raw_dataset["label"])
742-
print(f"Evaluation metrics ({task}) Accuracy: ", eval_metrics)
743-
print("Throughput: ", num_examples / total_time)
739+
num_examples += sum(
740+
1 for _ in (tf_dataset.unbatch()
741+
if hasattr(tf_dataset, "unbatch") else tf_dataset
742+
)
743+
)
744744

745+
if optim_args.use_distillation_model:
746+
preds: np.ndarray = None
747+
label_ids: np.ndarray = None
748+
infer = model.signatures[list(model.signatures.keys())[0]]
749+
for i, (inputs, labels) in enumerate(tf_dataset):
750+
for name in inputs:
751+
inputs[name] = tf.constant(inputs[name].numpy(), dtype=tf.int32)
752+
start = time.time()
753+
results = infer(**inputs)
754+
total_time += time.time() - start
755+
for val in results:
756+
if preds is None:
757+
preds = results[val].numpy()
758+
else:
759+
preds = np.append(preds, results[val].numpy(), axis=0)
760+
if label_ids is None:
761+
label_ids = labels.numpy()
762+
else:
763+
label_ids = np.append(label_ids, labels.numpy(), axis=0)
764+
eval_metrics = compute_metrics({"logits": preds}, label_ids)
765+
else:
766+
start = time.time()
767+
eval_predictions = model.predict(tf_dataset)
768+
total_time += time.time() - start
769+
eval_metrics = compute_metrics(eval_predictions, raw_dataset["label"])
770+
print(f"Evaluation metrics ({task}):")
771+
print(eval_metrics)
772+
logger.info("metric ({}) Accuracy: {}".format(task, eval_metrics["accuracy"]))
773+
logger.info(
774+
"Throughput: {} samples/sec".format(
775+
num_examples / total_time)
776+
)
745777
# endregion
746778

747779
# region Prediction
@@ -769,9 +801,26 @@ def compute_metrics(preds, label_ids):
769801
tf_datasets.append(tf_data["user_data"])
770802
raw_datasets.append(datasets["user_data"])
771803

804+
if optim_args.use_distillation_model:
805+
model = tf.saved_model.load(training_args.output_dir)
806+
772807
for raw_dataset, tf_dataset, task in zip(raw_datasets, tf_datasets,
773808
tasks):
774-
test_predictions = model.predict(tf_dataset)
809+
if optim_args.use_distillation_model:
810+
preds: np.ndarray = None
811+
infer = model.signatures[list(model.signatures.keys())[0]]
812+
for i, (inputs, labels) in enumerate(tf_dataset):
813+
for name in inputs:
814+
inputs[name] = tf.constant(inputs[name].numpy(), dtype=tf.int32)
815+
results = infer(**inputs)
816+
for val in results:
817+
if preds is None:
818+
preds = results[val].numpy()
819+
else:
820+
preds = np.append(preds, results[val].numpy(), axis=0)
821+
test_predictions = {"logits": preds}
822+
else:
823+
test_predictions = model.predict(tf_dataset)
775824
if "label" in raw_dataset:
776825
test_metrics = compute_metrics(test_predictions,
777826
raw_dataset["label"])
@@ -795,7 +844,7 @@ def compute_metrics(preds, label_ids):
795844
if is_regression:
796845
writer.write(f"{index}\t{item:3.3f}\n")
797846
else:
798-
item = model.config.id2label[item]
847+
item = config.id2label[item]
799848
writer.write(f"{index}\t{item}\n")
800849
# endregion
801850

examples/optimization/tensorflow/huggingface/text-classification/distillation/run_tuning.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ function init_params {
4040
# run_tuning
4141
function run_tuning {
4242
extra_cmd=''
43-
batch_size=16
43+
batch_size=64
4444
if [ "${topology}" = "distilbert-base-uncased" ]; then
4545
TASK_NAME='sst2'
4646
model_name_or_path=distilbert-base-uncased

examples/optimization/tensorflow/huggingface/text-classification/pruning/README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ bash run_tuning.sh --topology=topology
2424
```
2525

2626
```
27-
bash run_benchmark.sh --topology=topology --mode=benchmark
27+
bash run_benchmark.sh --topology=topology --mode=benchmark --use_pruned_model=true
2828
```
2929
topology is "distilbert_base_sst2"
3030

examples/optimization/tensorflow/huggingface/text-classification/pruning/run_benchmark.sh

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,10 @@ function main {
1111
# init params
1212
function init_params {
1313
iters=100
14-
batch_size=16
14+
batch_size=64
1515
tuned_checkpoint=saved_results
1616
topology="distilbert_base_sst2"
17+
mode="benchmark"
1718
for var in "$@"
1819
do
1920
case $var in
@@ -35,8 +36,8 @@ function init_params {
3536
--iters=*)
3637
iters=$(echo ${var} |cut -f2 -d=)
3738
;;
38-
--int8=*)
39-
int8=$(echo ${var} |cut -f2 -d=)
39+
--use_pruned_model=*)
40+
use_pruned_model=$(echo ${var} |cut -f2 -d=)
4041
;;
4142
--config=*)
4243
tuned_checkpoint=$(echo $var |cut -f2 -d=)
@@ -67,7 +68,11 @@ function run_benchmark {
6768

6869
if [ "${topology}" = "distilbert_base_sst2" ]; then
6970
TASK_NAME='sst2'
70-
model_name_or_path=${tuned_checkpoint}
71+
model_name_or_path=distilbert-base-uncased-finetuned-sst-2-english
72+
fi
73+
74+
if [[ ${use_pruned_model} == "true" ]]; then
75+
extra_cmd=$extra_cmd" --use_pruned_model"
7176
fi
7277

7378
python -u ./run_glue.py \

examples/optimization/tensorflow/huggingface/text-classification/pruning/run_glue.py

Lines changed: 67 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,10 @@ class OptimizationArguments:
210210
)
211211
benchmark: bool = field(
212212
default=False,
213-
metadata={"help": "run benchmark."})
213+
metadata={"help": "Run benchmark."})
214+
use_pruned_model: bool = field(
215+
default=False,
216+
metadata={"help":"Whether to use pretrained pruned model."})
214217
accuracy_only: bool = field(
215218
default=False,
216219
metadata={"help":"Whether to only test accuracy for model tuned by Neural Compressor."})
@@ -503,7 +506,7 @@ def compute_metrics(preds, label_ids):
503506
drop_remainder=drop_remainder,
504507
# `label_cols` is needed for user-defined losses, such as in this example
505508
# datasets v2.3.x need "labels", not "label"
506-
label_cols=["labels", "label"] if "label" in dataset.column_names else None,
509+
label_cols=["labels"] if "label" in dataset.column_names else None,
507510
)
508511
tf_data[key] = data
509512
# endregion
@@ -573,15 +576,49 @@ def compute_metrics(preds, label_ids):
573576
raw_datasets = [datasets["validation"]]
574577

575578
total_time = 0
576-
for raw_dataset, tf_dataset, task in zip(raw_datasets, tf_datasets, tasks):
577-
num_examples = sum(1 for _ in tf_dataset.unbatch())
578-
start = time.time()
579-
eval_predictions = model.predict(tf_dataset)
580-
total_time += time.time() - start
581-
eval_metrics = compute_metrics(eval_predictions, raw_dataset["label"])
582-
print(f"Evaluation metrics ({task}) Accuracy: ", eval_metrics)
583-
print("Throughput: ", num_examples / total_time)
584-
579+
num_examples = 0
580+
if optim_args.use_pruned_model:
581+
model = tf.saved_model.load(training_args.output_dir)
582+
for raw_dataset, tf_dataset, task in zip(raw_datasets, tf_datasets,
583+
tasks):
584+
num_examples += sum(
585+
1 for _ in (tf_dataset.unbatch()
586+
if hasattr(tf_dataset, "unbatch") else tf_dataset
587+
)
588+
)
589+
if optim_args.use_pruned_model:
590+
preds: np.ndarray = None
591+
label_ids: np.ndarray = None
592+
infer = model.signatures[list(model.signatures.keys())[0]]
593+
for i, (inputs, labels) in enumerate(tf_dataset):
594+
for name in inputs:
595+
inputs[name] = tf.constant(inputs[name].numpy(), dtype=tf.int32)
596+
start = time.time()
597+
results = infer(**inputs)
598+
total_time += time.time() - start
599+
for val in results:
600+
if preds is None:
601+
preds = results[val].numpy()
602+
else:
603+
preds = np.append(preds, results[val].numpy(), axis=0)
604+
if label_ids is None:
605+
label_ids = labels.numpy()
606+
else:
607+
label_ids = np.append(label_ids, labels.numpy(), axis=0)
608+
eval_metrics = compute_metrics({"logits": preds}, label_ids)
609+
else:
610+
start = time.time()
611+
eval_predictions = model.predict(tf_dataset)
612+
total_time += time.time() - start
613+
eval_metrics = compute_metrics(eval_predictions, raw_dataset["label"])
614+
print(f"Evaluation metrics ({task}):")
615+
print(eval_metrics)
616+
617+
logger.info("metric ({}) Accuracy: {}".format(task, eval_metrics["accuracy"]))
618+
logger.info(
619+
"Throughput: {} samples/sec".format(
620+
num_examples / total_time)
621+
)
585622
# endregion
586623

587624
# region Prediction
@@ -606,8 +643,25 @@ def compute_metrics(preds, label_ids):
606643
tf_datasets.append(tf_data["user_data"])
607644
raw_datasets.append(datasets["user_data"])
608645

646+
if optim_args.use_pruned_model:
647+
model = tf.saved_model.load(training_args.output_dir)
648+
609649
for raw_dataset, tf_dataset, task in zip(raw_datasets, tf_datasets, tasks):
610-
test_predictions = model.predict(tf_dataset)
650+
if optim_args.use_pruned_model:
651+
preds: np.ndarray = None
652+
infer = model.signatures[list(model.signatures.keys())[0]]
653+
for i, (inputs, labels) in enumerate(tf_dataset):
654+
for name in inputs:
655+
inputs[name] = tf.constant(inputs[name].numpy(), dtype=tf.int32)
656+
results = infer(**inputs)
657+
for val in results:
658+
if preds is None:
659+
preds = results[val].numpy()
660+
else:
661+
preds = np.append(preds, results[val].numpy(), axis=0)
662+
test_predictions = {"logits": preds}
663+
else:
664+
test_predictions = model.predict(tf_dataset)
611665
if "label" in raw_dataset:
612666
test_metrics = compute_metrics(test_predictions, raw_dataset["label"])
613667
print(f"Test metrics ({task}):")
@@ -626,7 +680,7 @@ def compute_metrics(preds, label_ids):
626680
if is_regression:
627681
writer.write(f"{index}\t{item:3.3f}\n")
628682
else:
629-
item = model.config.id2label[item]
683+
item = config.id2label[item]
630684
writer.write(f"{index}\t{item}\n")
631685
# endregion
632686

0 commit comments

Comments
 (0)