
Commit f582622 (1 parent: 6f2fe41)

Fix TF QAT UT issues (#266)

Signed-off-by: zehao-intel <[email protected]>

2 files changed: 12 additions, 0 deletions

neural_compressor/adaptor/tensorflow.py (11 additions, 0 deletions)
@@ -525,6 +525,17 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         Returns:
             tf.compat.v1.GraphDef: the quantized model
         """
+        if self.approach == "quant_aware_training":
+            assert q_func is not None, "quantization aware training mode \
+                is not configured correctly"
+
+            from neural_compressor.experimental import common
+            qat_model = q_func(model)
+
+            return self.convert(common.Model(qat_model), 'QAT', 'default')
+
+        assert q_func is None, \
+            "post-training quantization mode does not support a calibration function for TensorFlow!"
         self.tuning_cfg_to_fw(tune_cfg)
         logger.debug("Dump quantization configurations:")
         logger.debug(self.quantize_config)
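
With this change, the adaptor dispatches quant-aware training to the user-supplied q_func and converts the trained result via the 'QAT' path instead of falling through to the post-training flow. Below is a minimal sketch of what such a q_func might look like, assuming the model passed in is (or wraps) a float Keras model and using the standard tensorflow_model_optimization API; the dataset names (train_images, train_labels) and the one-epoch fine-tune are illustrative assumptions, not part of this commit:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

def q_func(model):
    # Hypothetical QAT training function supplied by the user as q_func.
    # Wraps the float Keras model with fake-quant nodes, fine-tunes it
    # briefly, and returns the trained quantization-aware model that the
    # convert(..., 'QAT', 'default') step above will consume.
    q_aware_model = tfmot.quantization.keras.quantize_model(model)
    q_aware_model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])
    # train_images / train_labels are placeholder dataset names.
    q_aware_model.fit(train_images, train_labels, epochs=1)
    return q_aware_model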

test/quantization/test_tensorflow_qat.py (1 addition, 0 deletions)
@@ -88,6 +88,7 @@ def test_qat(self):
         compression_manager = training.prepare_compression('./baseline_model', config)
         compression_manager.callbacks.on_train_begin()

+        q_aware_model = compression_manager.model
         # `quantize_model` requires a recompile.
         q_aware_model.compile(optimizer='adam',
                               loss=tf.keras.losses.SparseCategoricalCrossentropy(
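
The test fix is to fetch the quantization-aware model from the compression manager before recompiling it, which is what the added line does. A hedged sketch of how the surrounding test likely continues, following the standard tfmot QAT recipe; the loss arguments, metrics, and dataset names are assumptions, not the actual test code:

# Continuation sketch (assumed, not from the commit): recompile and briefly
# fine-tune the quant-aware model, then close out the QAT session.
q_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
q_aware_model.fit(train_images, train_labels, epochs=1)  # placeholder data
compression_manager.callbacks.on_train_end()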
