Commit 1f4127f

Smooth Quant for Tensorflow backend (#830)
Signed-off-by: Lv, Liang1 <[email protected]>
1 parent d47ea8e commit 1f4127f

File tree

24 files changed: +1547 / -17 lines changed
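
The change that recurs through the file diffs below is turning on the `smooth_quant` recipe of Intel® Neural Compressor's `PostTrainingQuantConfig` and passing it to `quantization.fit`. A minimal sketch of that call pattern, with `graph`, `calib_dataloader`, and `eval_func` as placeholders for whatever each example supplies:

```python
# Sketch of the SmoothQuant path this commit enables (Neural Compressor 2.x API).
# `graph`, `calib_dataloader`, and `eval_func` are placeholders taken from the example scripts.
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig

config = PostTrainingQuantConfig(
    calibration_sampling_size=[500],
    quant_level=1,
    recipes={"smooth_quant": True, "smooth_quant_args": {"alpha": 0.6}},
)
q_model = quantization.fit(model=graph, conf=config,
                           calib_dataloader=calib_dataloader,
                           eval_func=eval_func)
```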


examples/.config/model_params_tensorflow.json

Lines changed: 21 additions & 0 deletions
```diff
@@ -1766,6 +1766,27 @@
       "main_script": "run_inference.py",
       "batch_size": 128
     },
+    "distilbert_base_sq": {
+      "model_src_dir": "nlp/distilbert_base/quantization/ptq",
+      "dataset_location": "/tf_dataset2/datasets/sst2_validation_dataset",
+      "input_model": "/tf_dataset2/models/tensorflow/distilbert_base/fp32/distilbert_base_fp32.pb",
+      "main_script": "run_inference.py",
+      "batch_size": 128
+    },
+    "gpt2_medium_sq": {
+      "model_src_dir": "nlp/large_language_models/quantization/ptq/smoothquant",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/tensorflow/gpt2-medium",
+      "main_script": "main.py",
+      "batch_size": 16
+    },
+    "opt_125m_sq": {
+      "model_src_dir": "nlp/large_language_models/quantization/ptq/smoothquant",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/tensorflow/facebook-opt-125m",
+      "main_script": "main.py",
+      "batch_size": 16
+    },
   }
 }
 
```

examples/tensorflow/nlp/distilbert_base/quantization/ptq/README.md

Lines changed: 17 additions & 0 deletions
```diff
@@ -111,6 +111,23 @@ Where (Default values are shown in the square brackets):
 * $INTRA_THREADS [28]-- The number of intra op parallelism thread to use, which can be set to the number of physical core per socket
 
 
+### Run Smooth Quant to improve int8 accuracy
+
+#### Tuning
+```shell
+bash run_tuning.sh \
+    --input_model=$INPUT_MODEL \
+    --dataset_location=$DATASET_DIR \
+    --output_model=$OUTPUT_MODEL \
+    --batch_size=$BATCH_SIZE \
+    --max_seq_length=$MAX_SEQ \
+    --warmup_steps=$WARMUPS \
+    --num_inter=$INTER_THREADS \
+    --num_intra=$INTRA_THREADS \
+    --sq=True
+```
+
+
 Details of enabling Intel® Neural Compressor on DistilBERT base for TensorFlow
 =========================
 
```
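
For example, using the SST-2 dataset and FP32 DistilBERT paths registered in `model_params_tensorflow.json` above (the output path is only an illustrative name; the remaining flags fall back to the defaults set in `run_tuning.sh`):

```shell
bash run_tuning.sh \
    --input_model=/tf_dataset2/models/tensorflow/distilbert_base/fp32/distilbert_base_fp32.pb \
    --dataset_location=/tf_dataset2/datasets/sst2_validation_dataset \
    --output_model=./distilbert_base_sq_int8.pb \
    --batch_size=128 \
    --sq=True
```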

examples/tensorflow/nlp/distilbert_base/quantization/ptq/run_inference.py

Lines changed: 13 additions & 7 deletions
```diff
@@ -72,6 +72,7 @@ def boolean_string(s):
                         dest="tune",
                         default=False
                         )
+arg_parser.add_argument('--sq', type=boolean_string, dest='sq', help='smooth quantization', default=False)
 arg_parser.add_argument("--benchmark", type=boolean_string,
                         help="whether to do benchmark",
                         dest="benchmark",
@@ -201,7 +202,7 @@ def validate_args(self):
             logger.warning("Warmup steps greater than max possible value of 22." + \
                 " Setting to max value of ", MAX_WARMUP_STEPS)
             ARGS.warmup_steps = MAX_WARMUP_STEPS
-        if ARGS.tune or (ARGS.benchmark and ARGS.mode == "accuracy"):
+        if ARGS.tune or ARGS.sq or (ARGS.benchmark and ARGS.mode == "accuracy"):
             ARGS.steps = MAX_STEPS
         elif ARGS.benchmark:
             if ARGS.steps > (MAX_STEPS - MAX_WARMUP_STEPS):
@@ -271,7 +272,7 @@ def eval_func(self, graph):
                 else:
                     pred = sess.run(output, feed_dict=feed_dict)
                 run_time = time.time() - start_time
-                if ARGS.tune or (ARGS.benchmark and ARGS.mode == "accuracy"):
+                if ARGS.tune or ARGS.sq or (ARGS.benchmark and ARGS.mode == "accuracy"):
                     total_correct_predictions += self.get_correct_predictions(pred, labels)
                 total_time += run_time
         # save profiling file
@@ -287,7 +288,7 @@ def eval_func(self, graph):
             with open(profiling_file, 'w') as trace_file:
                 trace_file.write(trace.generate_chrome_trace_format(show_memory=False))
         time_per_batch = total_time / float(ARGS.steps / ARGS.batch_size)
-        if ARGS.tune or (ARGS.benchmark and ARGS.mode == "accuracy"):
+        if ARGS.tune or ARGS.sq or (ARGS.benchmark and ARGS.mode == "accuracy"):
             accuracy = total_correct_predictions / ARGS.steps
             logger.info("Accuracy: {:.4f}".format(accuracy))
         if self.dataloader.batch_size == 1:
@@ -297,12 +298,17 @@ def eval_func(self, graph):
 
     def run(self):
         graph = self.load_graph()
-        if ARGS.tune:
+        if ARGS.tune or ARGS.sq:
             from neural_compressor import quantization
             from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion
-            accuracy_criterion = AccuracyCriterion(tolerable_loss=0.02)
-            config = PostTrainingQuantConfig(calibration_sampling_size=[500],
-                                             accuracy_criterion=accuracy_criterion)
+            if ARGS.sq:
+                config = PostTrainingQuantConfig(calibration_sampling_size=[500],
+                                                 quant_level=1,
+                                                 recipes={"smooth_quant": True, "smooth_quant_args": {'alpha': 0.6}})
+            else:
+                accuracy_criterion = AccuracyCriterion(tolerable_loss=0.02)
+                config = PostTrainingQuantConfig(calibration_sampling_size=[500],
+                                                 accuracy_criterion=accuracy_criterion)
             q_model = quantization.fit(model=graph, conf=config, calib_dataloader=self.dataloader,
                                        eval_func=self.eval_func)
             try:
```

examples/tensorflow/nlp/distilbert_base/quantization/ptq/run_tuning.sh

Lines changed: 5 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@ function init_params {
   num_inter=2
   num_intra=28
   tune=True
+  sq=False
 
   for var in "$@"
   do
@@ -52,6 +53,9 @@ function init_params {
       --tune=*)
           tune=$(echo ${var} |cut -f2 -d=)
           ;;
+      --sq=*)
+          sq=$(echo ${var} |cut -f2 -d=)
+          ;;
     esac
   done
 
@@ -65,6 +69,7 @@ function run_tuning {
       --data-location=${dataset_location} \
       --output-graph=${output_model} \
       --tune=${tune} \
+      --sq=${sq} \
       --warmup-steps=${warmup_steps} \
       --batch-size=${batch_size} \
       --max-seq-length=${max_seq_length} \
```
Lines changed: 52 additions & 0 deletions
```diff
@@ -0,0 +1,52 @@
+Step-by-Step
+============
+
+This document is used to list steps of reproducing TensorFlow Intel® Neural Compressor quantization and smooth quantization of language models such as OPT and GPT2.
+
+## Prerequisite
+
+```shell
+# Install Intel® Neural Compressor
+pip install neural-compressor
+pip install -r requirements
+```
+## Run
+
+
+### Basic quantization
+
+```
+python main.py --model_name_or_path <MODEL_NAME>
+```
+
+`<MODEL_NAME>` can be following:
+
+- gpt2-medium
+- facebook/opt-125m
+
+### Smooth quant
+
+```shell
+bash run_tuning.sh --input_model=<MODEL_NAME>
+```
+
+Or you can use
+
+```
+python main.py --model_name_or_path <MODEL_NAME> --sq
+```
+
+## Benchmark
+
+### Get the FP32 performance
+
+```shell
+bash run_benchmark.sh --input_model=<MODEL_NAME>
+```
+
+### Get the INT8 performance
+
+```shell
+bash run_benchmark.sh --input_model=<MODEL_NAME> --int8=true
+```
+
```
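
Putting the commands above together for the two models registered in `model_params_tensorflow.json` (a usage illustration, assuming the scripts are run from this example directory):

```shell
# Smooth quant, either through the tuning script or directly through main.py
bash run_tuning.sh --input_model=gpt2-medium
python main.py --model_name_or_path facebook/opt-125m --sq

# FP32 vs. INT8 performance
bash run_benchmark.sh --input_model=facebook/opt-125m
bash run_benchmark.sh --input_model=facebook/opt-125m --int8=true
```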
Lines changed: 184 additions & 0 deletions
```diff
@@ -0,0 +1,184 @@
+import os.path
+import transformers
+import tensorflow as tf
+from tqdm import tqdm
+import sys
+import argparse
+from datasets import load_dataset
+import numpy as np
+import time
+
+sys.path.insert(0, './')
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--int8', action='store_true', help="eval fp32 model or int8 model")
+parser.add_argument('--model_name_or_path', type=str, default='facebook/opt-125m')
+parser.add_argument('--batch_size', type=int, default=16)
+parser.add_argument('--warmup', type=int, default=10)
+args = parser.parse_args()
+
+class Evaluator:
+    def __init__(self, dataset, tokenizer, device, batch_size=args.batch_size):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.device = device
+        self.dataloader = INCDataloader(dataset, tokenizer, batch_size, device)
+
+    def evaluate(self, model):
+        # model.eval()
+        # The task is to predict the last word of the input.
+        total, hit = 0, 0
+        index = 1
+        for input_ids, label, label_indices in tqdm(self.dataloader):
+            # TFCausalLMOutputWithPast len: 2
+            # first element shape (16, 196, 50272)
+            # second element shape (16, 12, 196, 64)
+            outputs = model(input_ids)
+            last_token_logits = outputs[0].numpy()[np.arange(len(label_indices)), label_indices, :]
+            pred = last_token_logits.argmax(axis=-1)
+            total += label.shape[0]
+            hit += (pred == label.numpy()).sum().item()
+            index += 1
+        acc = hit / total
+        print(acc, flush=True)
+        return acc
+
+    def get_attention_mask(self, input_ids):
+        return tf.constant(1 - (input_ids==1).numpy().astype(int))
+
+    def evaluate_tf_v1(self, model):
+        # return 0.99 # TODO debug remove
+        total, hit = 0, 0
+        index = 1
+        infer = model.signatures["serving_default"]
+        overall_infer_duration = 0
+        for input_ids, label, label_indices in tqdm(self.dataloader):
+            attention_mask = self.get_attention_mask(input_ids)
+            input_ids = tf.constant(input_ids.numpy(), dtype=infer.inputs[0].dtype)
+            attention_mask = tf.constant(attention_mask.numpy(), dtype=infer.inputs[0].dtype)
+            start = time.time()
+            results = infer(input_ids=input_ids, attention_mask=attention_mask)  # len: 25 Identity: [16, 196, 50272], Identity_1: [16, 12, 196, 64]
+            batch_infer_time = time.time() - start
+            if index > args.warmup:
+                overall_infer_duration += batch_infer_time
+            last_token_logits = results['Identity'].numpy()[np.arange(len(label_indices)), label_indices, :]
+            pred = last_token_logits.argmax(axis=-1)
+            total += label.shape[0]
+            hit += (pred == label.numpy()).sum().item()
+            index += 1
+        acc = hit / total
+        print("\nEvaluation result: ")
+        print(f"Batch size = {args.batch_size}")
+        print(f"Accuracy: {acc}")
+        print(
+            f"Throughput: {(len(self.dataloader) - args.warmup * args.batch_size) / overall_infer_duration} samples/sec"
+        )
+
+class INCDataloader:
+    # for_calib=True in quantization, only input_id is needed, =False in evaluation need label
+    def __init__(self, dataset, tokenizer, batch_size=1, device='cpu', for_calib=False):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.device = device
+        self.batch_size = batch_size
+        self.for_calib = for_calib
+        import math
+        self.length = math.ceil(len(dataset) / self.batch_size)  # batch number
+        self.pad_len = 196
+
+        # tokenize the dataset
+        def tokenize_function(examples):
+            example = self.tokenizer(examples['text'])
+            return example
+
+        self.dataset = self.dataset.map(tokenize_function, batched=True)
+        self.dataset.set_format(type='tensorflow', columns=['input_ids'])
+    def get_attention_mask(self, input_ids):
+        return 1 - (input_ids==1).numpy().astype(int)
+    def pad_input(self, input):  # input: a record
+        input_id = input['input_ids']
+        if input_id.numpy().shape[0] > self.pad_len:  # truncate the sequence to pad_len if the sequence is longer than pad_len
+            input_id = input_id[:self.pad_len]
+        label = input_id[-1]
+        pad_len = self.pad_len - input_id.numpy().shape[0]
+        label_index = -2 - pad_len  # last logit index
+        input_id = tf.pad(input_id, tf.constant([[0,pad_len]]), constant_values=1)
+        input_id = tf.expand_dims(input_id, axis=0)
+        label = tf.expand_dims(label, axis=0)
+        return (input_id, label, label_index)
+
+    def __iter__(self):
+        if self.for_calib:
+            labels = None
+            # label_indices = None
+            for idx, record in enumerate(self.dataset):
+                input_id, label, label_index = self.pad_input(record)
+                attention_mask = self.get_attention_mask(input_id)
+                # compose attention_mask and input_id together
+                # during the calibration, it requires to yield a <attention_mask, input_id>
+                # cur_input = tf.constant(np.append(attention_mask, input_id.numpy(), axis=0))
+                cur_input = {"input_ids": input_id.numpy(), "attention_mask": attention_mask}
+                assert self.batch_size == 1
+                yield (cur_input, label)
+        else:
+            input_ids = None
+            labels = None
+            label_indices = None
+            for idx, record in enumerate(self.dataset):
+                input_id, label, label_index = self.pad_input(record)
+                if input_ids is None:
+                    input_ids = input_id
+                    labels = label
+                    label_indices = [label_index]
+                else:
+                    input_ids = tf.concat([input_ids, input_id], 0)
+                    labels = tf.concat([labels, label], 0)
+
+                    label_indices.append(label_index)
+
+                if (idx + 1) % self.batch_size == 0:
+                    yield (input_ids, labels, label_indices)
+                    input_ids = None
+                    labels = None
+                    label_indices = None
+            if (idx + 1) % self.batch_size != 0:
+                yield (input_ids, labels, label_indices)
+
+    def __len__(self):
+        return self.length
+
+from datasets import load_dataset
+
+model_name = args.model_name_or_path
+tokenizer = transformers.AutoTokenizer.from_pretrained(
+    model_name,
+)
+eval_dataset = load_dataset('lambada', split='validation')
+
+evaluator = Evaluator(eval_dataset, tokenizer, 'cpu')
+
+if args.int8:
+    print("benchmarking int8 model")
+    int8_folder = model_name.split('/')[-1] + "_int8"
+    if not os.path.exists(int8_folder):
+        print(f"could not find int8 folder {int8_folder} ")
+        exit()
+    model = tf.saved_model.load(int8_folder)  # tensorflow.python.trackable.autotrackable.AutoTrackable object
+else:
+    print("benchmaking fp32 model")
+    model = transformers.TFAutoModelForCausalLM.from_pretrained(model_name)
+    # fp32_folder = model_name.split('/')[-1] + "_fp32"
+    # model.save(fp32_folder)
+    # model = tf.keras.models.load_model(fp32_folder)
+    from neural_compressor.experimental import common
+    def keras2SavedModel(model):
+        model = common.Model(model)
+        return model.model
+    model = keras2SavedModel(model)  # tensorflow.python.trackable.autotrackable.AutoTrackable object
+
+# TODO current neural_compressor.benchmark does not support AutoTrackable model, we will write our own
+# from neural_compressor.benchmark import fit
+# from neural_compressor.config import BenchmarkConfig
+# conf = BenchmarkConfig(cores_per_instance=28, num_of_instance=1)
+# fit(model, conf, b_func=evaluator.evaluate_tf_v1)
+evaluator.evaluate_tf_v1(model)
```
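
The `INCDataloader` above advertises a `for_calib=True` mode for quantization, but the quantization entry point itself (`main.py`) is not part of this excerpt. A hedged sketch of how that dataloader could feed `quantization.fit` with the same `smooth_quant` recipe used in `run_inference.py`; the module name `benchmark`, the sampling size, and the exact fit arguments are assumptions:

```python
# Hypothetical calibration + SmoothQuant sketch; not the committed main.py.
import transformers
from datasets import load_dataset
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig

from benchmark import INCDataloader  # assumed module name for the script above

model_name = "facebook/opt-125m"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.TFAutoModelForCausalLM.from_pretrained(model_name)

# LAMBADA validation split, tokenized and padded by INCDataloader; for_calib=True yields
# {"input_ids", "attention_mask"} dicts with batch_size 1, as its assert requires.
calib_dataset = load_dataset('lambada', split='validation')
calib_dataloader = INCDataloader(calib_dataset, tokenizer, batch_size=1, for_calib=True)

config = PostTrainingQuantConfig(
    calibration_sampling_size=[32],  # assumed sampling size
    recipes={"smooth_quant": True, "smooth_quant_args": {"alpha": 0.6}},
)
q_model = quantization.fit(model=model, conf=config, calib_dataloader=calib_dataloader)
q_model.save(model_name.split('/')[-1] + "_int8")  # the folder the --int8 path above loads
```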
