
Commit b50ba8d

Refactor Quantization Aware Training of TF backend (#250)
Signed-off-by: zehao-intel <[email protected]>
1 parent cceaf12

File tree

36 files changed (+2064, -300 lines)


.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt

Lines changed: 3 additions & 0 deletions
@@ -2439,6 +2439,9 @@ AWSSageMakerSupport
 sagemaker
 BenchmarkConfig
 xpu
+dgpu
+BenchmarkConfig
+QuantizationAwareTrainingConfig
 Startup
 doesn
 startup

examples/.config/model_params_tensorflow.json

Lines changed: 14 additions & 0 deletions
@@ -1940,6 +1940,13 @@
       "batch_size": 1,
       "new_benchmark": false
     },
+    "mnist_keras": {
+      "model_src_dir": "image_recognition/keras_models/mnist/quantization/qat",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/tensorflow/mnist_keras/saved_model/",
+      "main_script": "main.py",
+      "batch_size": 32
+    },
     "resnet50_fashion": {
       "model_src_dir": "image_recognition/keras_models/resnet50_fashion/quantization/ptq",
       "dataset_location": "/tf_dataset2/datasets/mnist/FashionMNIST",
@@ -1958,6 +1965,13 @@
       "batch_size": 1,
       "new_benchmark": true
     },
+    "resnet50_keras_qat": {
+      "model_src_dir": "image_recognition/keras_models/resnet50/quantization/qat",
+      "dataset_location": "/tf_dataset/dataset/imagenet",
+      "input_model": "/tf_dataset2/models/tensorflow/resnet50_keras/resnet50",
+      "main_script": "main.py",
+      "batch_size": 32
+    },
     "resnet50_keras_h5": {
       "model_src_dir": "image_recognition/keras_models/resnet50/quantization/ptq",
       "dataset_location": "/tf_dataset/dataset/imagenet",

Lines changed: 88 additions & 26 deletions
@@ -1,56 +1,118 @@
 Step-by-Step
 ============
 
-This document is used to list steps of reproducing TensorFlow keras Intel® Neural Compressor QAT conversion.
+This document describes how to apply QAT to TensorFlow Keras models using Intel® Neural Compressor.
 This example can run on Intel CPUs and GPUs.
 
-
 ## Prerequisite
 
 ### 1. Installation
 ```shell
 # Install Intel® Neural Compressor
 pip install neural-compressor
 ```
-### 2. Install Intel Tensorflow and TensorFlow Model Optimization
+### 2. Install requirements
+TensorFlow and intel-extension-for-tensorflow must be installed to run this QAT example.
+The Intel Extension for TensorFlow for Intel CPUs is installed by default.
 ```shell
-pip install intel-tensorflow==2.4.0
-pip install tensorflow_model_optimization==0.5.0
+pip install -r requirements.txt
 ```
-> Note: To generate correct qat model with tensorflow_model_optimization 0.5.0, pls use TensorFlow 2.4 or above.
+> Note: See the supported TensorFlow [versions](../../../../../../../README.md).
 
-### 3. Install Intel Extension for Tensorflow
+### 3. Benchmarking the model on Intel GPU (Optional)
 
-#### Quantizing the model on Intel GPU
-Intel Extension for Tensorflow is mandatory to be installed for quantizing the model on Intel GPUs.
+To benchmark the model on Intel GPUs, the Intel Extension for TensorFlow for Intel GPUs is required.
 
 ```shell
 pip install --upgrade intel-extension-for-tensorflow[gpu]
 ```
-For any more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers)
-
-#### Quantizing the model on Intel CPU(Experimental)
-Intel Extension for Tensorflow for Intel CPUs is experimental currently. It's not mandatory for quantizing the model on Intel CPUs.
 
-```shell
-pip install --upgrade intel-extension-for-tensorflow[cpu]
-```
+Please refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/ubuntu/ubuntu-focal-dc.html) for the latest Intel GPU driver installation.
+For more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers).
 
 ### 4. Prepare Pretrained model
 
-Run the `train.py` script to get pretrained fp32 model.
+The pretrained model is provided by [Keras Applications](https://keras.io/api/applications/). To prepare the model, run the following:
+```
 
-### 5. Prepare QAT model
-
-Run the `qat.py` script to get QAT model which in fact is a fp32 model with quant/dequant pair inserted.
-
-## Write Yaml config file
-In examples directory, there is a mnist.yaml for tuning the model on Intel CPUs. The 'framework' in the yaml is set to 'tensorflow'. If running this example on Intel GPUs, the 'framework' should be set to 'tensorflow_itex' and the device in yaml file should be set to 'gpu'. The mnist_itex.yaml is prepared for the GPU case. We could remove most of items and only keep mandatory item for tuning. We also implement a calibration dataloader and have evaluation field for creation of evaluation function at internal neural_compressor.
+python prepare_model.py --output_model=/path/to/model
+```
+`--output_model`: the model should be saved in SavedModel or H5 format.
 
 ## Run Command
 ```shell
-python convert.py # to convert QAT model to quantized model.
-
-python benchmark.py # to run accuracy benchmark.
+bash run_tuning.sh --input_model=./path/to/model --output_model=./result
+bash run_benchmark.sh --input_model=./path/to/model --mode=performance --batch_size=32
 ```
 
+Details of enabling Intel® Neural Compressor to apply QAT
+=========================
+
+This is a tutorial on how to apply QAT with Intel® Neural Compressor.
+## User Code Analysis
+1. The user specifies the fp32 *model* to be quantized; the dataset is downloaded automatically. In this step, QDQ patterns are inserted into the Keras model, but the fp32 model is not yet converted to an int8 model.
+
+2. The user specifies the *model* with QDQ patterns inserted and an evaluation function to run the benchmark. The model obtained in the previous step runs on the ITEX backend, where it is fused and then inferred.
+
+### Quantization Config
+The quantization config class has default settings for running on Intel CPUs. If running this example on Intel GPUs, the 'backend' parameter should be set to 'itex' and the 'device' parameter should be set to 'gpu'.
+
+```
+config = QuantizationAwareTrainingConfig(
+    device="gpu",
+    backend="itex",
+    ...
+)
+```
+
+### Code update
+
+After the prepare step is done, we add quantization and benchmark code to generate the quantized model and run the benchmark.
+
+#### Tune
+```python
+logger.info('start quantizing the model...')
+from neural_compressor import training, QuantizationAwareTrainingConfig
+config = QuantizationAwareTrainingConfig()
+# create a compression_manager instance to implement QAT
+compression_manager = training.prepare_compression(FLAGS.input_model, config)
+# QDQ patterns will be inserted into the input keras model
+compression_manager.callbacks.on_train_begin()
+# get the model with QDQ patterns inserted
+q_aware_model = compression_manager.model.model
+
+# training code defined by users
+q_aware_model.compile(optimizer='adam',
+                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+                      metrics=['accuracy'])
+q_aware_model.summary()
+train_images_subset = train_images[0:1000]
+train_labels_subset = train_labels[0:1000]
+q_aware_model.fit(train_images_subset, train_labels_subset,
+                  batch_size=500, epochs=1, validation_split=0.1)
+_, q_aware_model_accuracy = q_aware_model.evaluate(
+    test_images, test_labels, verbose=0)
+print('Quant test accuracy:', q_aware_model_accuracy)
+
+# apply some post-processing steps and save the output model
+compression_manager.callbacks.on_train_end()
+compression_manager.save(FLAGS.output_model)
+```
+#### Benchmark
+```python
+from neural_compressor.benchmark import fit
+from neural_compressor.experimental import common
+from neural_compressor.config import BenchmarkConfig
+assert FLAGS.mode == 'performance' or FLAGS.mode == 'accuracy', \
+    "Benchmark only supports performance or accuracy mode."
+
+# convert the quantized keras model to graph_def so that it can be fused by ITEX
+model = common.Model(FLAGS.input_model).graph_def
+if FLAGS.mode == 'performance':
+    conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=7)
+    fit(model, conf, b_func=evaluate)
+elif FLAGS.mode == 'accuracy':
+    accuracy = evaluate(model)
+    print('Batch size = %d' % FLAGS.batch_size)
+    print("Accuracy: %.5f" % accuracy)
+```
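
The Tune snippet above trains on `train_images`/`train_labels` and evaluates on `test_images`/`test_labels` without showing how they are obtained. A minimal sketch of the assumed data preparation for the MNIST case, using the Keras dataset loader (illustrative only, not part of this commit):

```python
# Assumed data preparation for the Tune snippet above; not taken from this commit.
import tensorflow as tf

# Load MNIST and scale pixel values to [0, 1]; the example's Keras model is
# assumed to accept 28x28 grayscale images with sparse integer labels.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0
```

Any equivalent pipeline works; the snippet only needs arrays whose shapes match the pretrained model's inputs.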

examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/benchmark.py

Lines changed: 0 additions & 28 deletions
This file was deleted.
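
The deleted `benchmark.py` is superseded by the Benchmark snippet in the README above, which relies on an `evaluate` callable and a `FLAGS` object presumably defined in the example's `main.py`. A minimal sketch of what such an `evaluate` function could look like for a frozen `graph_def`, assuming an evaluation dataset of `(images, labels)` batches and placeholder tensor names (hypothetical, not the commit's code):

```python
# Hypothetical evaluate() matching the Benchmark snippet above; not taken from this commit.
import time

import numpy as np
import tensorflow as tf


def evaluate(graph_def,
             eval_dataset=None,        # assumed tf.data.Dataset yielding (images, labels) batches
             input_name="input_1",     # assumed input tensor name; check the actual model
             output_name="Identity"):  # assumed output tensor name
    """Import the frozen graph_def, run it over eval_dataset, and return top-1 accuracy."""
    if eval_dataset is None:
        # Fall back to the MNIST test split as a stand-in evaluation set (assumption).
        (_, _), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
        test_images = test_images.astype("float32") / 255.0
        eval_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(32)

    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name="")
    input_tensor = graph.get_tensor_by_name(input_name + ":0")
    output_tensor = graph.get_tensor_by_name(output_name + ":0")

    correct, total, latencies = 0, 0, []
    with tf.compat.v1.Session(graph=graph) as sess:
        for images, labels in eval_dataset:
            start = time.time()
            preds = sess.run(output_tensor, feed_dict={input_tensor: images.numpy()})
            latencies.append(time.time() - start)
            correct += int(np.sum(np.argmax(preds, axis=1) == labels.numpy().flatten()))
            total += int(labels.shape[0])

    print("Average latency per batch: %.4f s" % float(np.mean(latencies)))
    return correct / max(total, 1)
```

In the README snippet the same function serves both modes: `fit(model, conf, b_func=evaluate)` drives it for performance measurement, and accuracy mode calls it directly.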

examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/convert.py

Lines changed: 0 additions & 7 deletions
This file was deleted.
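
With `convert.py` also removed, the refactored example is driven by `prepare_model.py`, `run_tuning.sh`, and `run_benchmark.sh` as described in the README. A minimal sketch of what a `prepare_model.py` along those lines might contain, assuming a Keras Applications model such as ResNet50 (hypothetical, not the script added by this commit):

```python
# Hypothetical prepare_model.py along the lines the README describes; not the commit's actual script.
import argparse

import tensorflow as tf


def main():
    parser = argparse.ArgumentParser(description="Export a pretrained Keras Applications model.")
    parser.add_argument("--output_model", type=str, required=True,
                        help="Destination path; a .h5 suffix saves HDF5, otherwise SavedModel.")
    args = parser.parse_args()

    # Download the pretrained fp32 model from Keras Applications.
    model = tf.keras.applications.ResNet50(weights="imagenet")

    # Keras picks the format from the path: a SavedModel directory or a single .h5 file.
    model.save(args.output_model)


if __name__ == "__main__":
    main()
```

Usage then matches the README: `python prepare_model.py --output_model=/path/to/model`.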
