From 8f1f318bcb177b2d153caac06f27a51e09e94893 Mon Sep 17 00:00:00 2001 From: Clark Chin Date: Thu, 8 Dec 2022 16:09:30 +0800 Subject: [PATCH 01/14] add keras-in/keras-out to INC 1. support keras-level transformation instead of freezing to graph_def 2. support specific Q/DQ/QDense/QConv2D op in keras level 3. support mnist model to keras --- examples/keras/minist/README.md | 41 ++ examples/keras/minist/conf.yaml | 28 + examples/keras/minist/mnist.py | 95 +++ examples/keras/minist/requirements.txt | 2 + neural_compressor/adaptor/keras.py | 551 ++++++++++++++++++ neural_compressor/adaptor/keras.yaml | 67 +++ .../adaptor/keras_utils/quantizer.py | 268 +++++++++ neural_compressor/conf/pythonic_config.py | 13 +- .../data/dataloaders/dataloader.py | 1 + neural_compressor/model/keras_model.py | 75 +++ neural_compressor/model/model.py | 133 +---- neural_compressor/strategy/strategy.py | 3 + 12 files changed, 1168 insertions(+), 109 deletions(-) create mode 100644 examples/keras/minist/README.md create mode 100644 examples/keras/minist/conf.yaml create mode 100644 examples/keras/minist/mnist.py create mode 100644 examples/keras/minist/requirements.txt create mode 100644 neural_compressor/adaptor/keras.py create mode 100644 neural_compressor/adaptor/keras.yaml create mode 100644 neural_compressor/adaptor/keras_utils/quantizer.py create mode 100644 neural_compressor/model/keras_model.py diff --git a/examples/keras/minist/README.md b/examples/keras/minist/README.md new file mode 100644 index 00000000000..163b7c96a70 --- /dev/null +++ b/examples/keras/minist/README.md @@ -0,0 +1,41 @@ +Step-by-Step +============ + +This document lists the steps to reproduce the Keras mnist model tuning results via Neural Compressor. +This example can run on Intel CPUs. + +# Prerequisite + +### 1. Installation +Python 3.6 or higher is recommended. + +```shell +# Install Intel® Neural Compressor +pip install neural-compressor +``` + +### 2. 
Install TensorFlow +```shell +pip install tensorflow +``` +> Note: Requires TensorFlow version > 2.10.0. + +### 3. Install Dependency Packages +```shell +cd examples/keras/minist/ +pip install -r requirements.txt +``` + +#### Quantizing the model on Intel CPU (Experimental) +Intel Extension for TensorFlow for Intel CPUs is currently experimental. It is not required for quantizing the model on Intel CPUs. + +```shell +pip install --upgrade intel-extension-for-tensorflow[cpu] +``` + +# Run + + ```shell + cd examples/keras/minist/ + python mnist.py + ``` diff --git a/examples/keras/minist/conf.yaml b/examples/keras/minist/conf.yaml new file mode 100644 index 00000000000..43828fb6332 --- /dev/null +++ b/examples/keras/minist/conf.yaml @@ -0,0 +1,28 @@ +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: 1.0 + +model: + name: mnist + framework: keras # mandatory. possible values are keras, tensorflow, mxnet, pytorch, pytorch_ipex, onnxrt_integerops and onnxrt_qlinearops. 
def build_dataset():
    """Load MNIST and return ``(x_train, y_train, x_test, y_test)``.

    Images are converted to float32, divided by 255 and given a trailing
    channel axis; labels are one-hot encoded over ``num_classes`` classes.
    """
    def _as_model_input(images):
        # Scale to the [0, 1] range, then append the channel axis so each
        # image has shape (28, 28, 1).
        return np.expand_dims(images.astype("float32") / 255, -1)

    # Load the data and split it between train and test sets.
    (train_images, train_labels), (test_images, test_labels) = \
        keras.datasets.mnist.load_data()

    return (
        _as_model_input(train_images),
        # Convert class vectors to binary class matrices.
        keras.utils.to_categorical(train_labels, num_classes),
        _as_model_input(test_images),
        keras.utils.to_categorical(test_labels, num_classes),
    )
def eval_func(model):
    """Return the accuracy of ``model`` on the MNIST test split.

    The dataset is rebuilt on every call; only the test images and labels are
    used. The model is re-compiled with the accuracy metric in graph mode
    before evaluating.
    """
    _, _, test_images, test_labels = build_dataset()
    model.compile(metrics=["accuracy"], run_eagerly=False)
    # ``evaluate`` returns [loss, accuracy]; index 1 is the accuracy.
    return model.evaluate(test_images, test_labels)[1]
def _add_supported_quantized_objects(custom_objects):
    """Register the quantization layer classes in ``custom_objects``.

    Each class from ``keras_utils.quantizer`` is stored under its own class
    name so that Keras can resolve it when rebuilding a model from JSON.
    The mapping is updated in place and also returned.
    """
    from neural_compressor.adaptor.keras_utils import quantizer as quantizer_module

    for layer_name in ("Quantize", "DeQuantize", "FakeQuant", "QConv2D", "QDense"):
        custom_objects[layer_name] = getattr(quantizer_module, layer_name)
    return custom_objects
adaptor layer. + + ''' + def __init__(self, framework_specific_info): + super(KerasAdaptor, self).__init__(framework_specific_info) + self.framework_specific_info = framework_specific_info + self.approach = deep_get(self.framework_specific_info, 'approach', False) + self.quantize_config = {'op_wise_config': {}} + self.device = self.framework_specific_info['device'] + self.work_dir = os.path.abspath(self.framework_specific_info['workspace_path']) + self.recipes = deep_get(self.framework_specific_info, 'recipes', {}) + os.makedirs(self.work_dir, exist_ok=True) + + self.pre_optimizer_handle = None + + self.fp32_ops = [] + + self.query_handler = KerasQuery(local_config_file=os.path.join( + os.path.dirname(__file__), 'keras.yaml')) + + self.fp32_results = [] + self.fp32_preds_as_label = False + self.benchmark = (GLOBAL_STATE.STATE == MODE.BENCHMARK) + self.callbacks = [] + self.optype_statistics = None + + + def tuning_cfg_to_fw(self, tuning_cfg): + self.quantize_config['calib_iteration'] = tuning_cfg['calib_iteration'] + self.quantize_config['device'] = self.device + self.quantize_config['advance'] = deep_get(tuning_cfg, 'advance') + fp32_ops = [] + dispatched_op_names = [j[0] for j in tuning_cfg['op']] + + invalid_op_names = [i for i in self.quantize_config['op_wise_config'] + if i not in dispatched_op_names] + + for op_name in invalid_op_names: + self.quantize_config['op_wise_config'].pop(op_name) + + for each_op_info in tuning_cfg['op']: + op_name = each_op_info[0] + + if tuning_cfg['op'][each_op_info]['activation']['dtype'] == 'fp32': + if op_name in self.quantize_config['op_wise_config']: + self.quantize_config['op_wise_config'].pop(op_name) + fp32_ops.append(op_name) + continue + + is_perchannel = False + bit = None + if 'weight' in tuning_cfg['op'][each_op_info]: + is_perchannel = tuning_cfg['op'][each_op_info]['weight'][ + 'granularity'] == 'per_channel' + #bit = tuning_cfg['op'][each_op_info]['weight']['bit'] + weight_bit = bit if bit else 7.0 + + algorithm = 
tuning_cfg['op'][each_op_info]['activation']['algorithm'] + + is_asymmetric = False + if 'activation' in tuning_cfg['op'][each_op_info]: + is_asymmetric = tuning_cfg['op'][each_op_info]['activation']['scheme'] == 'asym' + self.quantize_config['op_wise_config'][op_name] = (is_perchannel, + algorithm, + is_asymmetric, + weight_bit) + self.fp32_ops = fp32_ops + + @dump_elapsed_time("Pass quantize model") + def quantize(self, tune_cfg, model, dataloader, q_func=None): + '''Execute the quantize process on the specified model. + + Args: + tune_cfg(dict): The chosen tuning configuration. + model (object): The model to do quantization. + dataloader(object): The dataloader used to load quantization dataset. + q_func (optional): training function for quantization aware training mode. + ''' + import pdb; pdb.set_trace() + self.tuning_cfg_to_fw(tune_cfg) + logger.debug("Dump quantization configurations:") + logger.debug(self.quantize_config) + calib_sampling_size = tune_cfg.get('calib_sampling_size', 1) + if isinstance(dataloader, BaseDataLoader): + batch_size = dataloader.batch_size + for i in range(batch_size): + if calib_sampling_size % (batch_size - i) == 0: + calib_batch_size = batch_size - i + if i != 0: # pragma: no cover + logger.warning("Reset `calibration.dataloader.batch_size` field " + "to {}".format(calib_batch_size) + + " to make sure the sampling_size is " + "divisible exactly by batch size") + break + tmp_iterations = int(math.ceil(calib_sampling_size / calib_batch_size)) + dataloader.batch(calib_batch_size) + self.quantize_config['calib_iteration'] = tmp_iterations + + else: # pragma: no cover + if hasattr(dataloader, 'batch_size') and \ + calib_sampling_size % dataloader.batch_size != 0: + iter = self.quantize_config['calib_iteration'] + logger.warning( + "Please note that calibration sampling size {} " \ + "isn't divisible exactly by batch size {}. " \ + "So the real sampling size is {}.". 
+ format(calib_sampling_size, dataloader.batch_size, + dataloader.batch_size * iter)) + q_layers = [] + # check the tuning_config here + for idx, layer in enumerate(self.fp32_layers): + layer_config = layer["config"] + if layer["class_name"] in ["Conv2D", "Dense"]: + fake_quant_name = 'fake_quant_' + str(idx) + q_layers.append({'class_name': 'FakeQuant', + 'config': {'mode': 'per_tensor', 'axis': 1, 'name': fake_quant_name}}) + q_layers.append(layer) + else: + q_layers.append(layer) + + keras_object = model._model_object + json_model = copy.deepcopy(json.loads(keras_object.to_json())) + json_model['config']['layers'] = q_layers + quantized_model = self._restore_model_from_json(json_model) + + converted_model = self._calibrate(quantized_model, dataloader, + self.quantize_config['calib_iteration']) + + from neural_compressor.model.keras_model import KerasModel + converted_model = KerasModel(converted_model) + return converted_model + + def _calibrate(self, model, dataloader, calib_interation): + # run eagerly to fetch the numpy min/max + model.compile(run_eagerly=True) + results = {} + for idx, (inputs, labels) in enumerate(dataloader): + outputs = model.predict_on_batch(inputs) + json_model = copy.deepcopy(json.loads(model.to_json())) + config = json_model["config"] + layers = config["layers"] + for layer in layers: + if layer['class_name'] == 'FakeQuant': + min_value = layer['config']['min_value'] + max_value = layer['config']['max_value'] + if layer['config']['name'] not in results: + results[layer['config']['name']] = { + 'min': [min_value], 'max': [max_value]} + else: + results[layer['config']['name']]['min'].append(min_value) + results[layer['config']['name']]['max'].append(max_value) + if idx + 1 == calib_interation: + break + + # insert the calibrated min/max to Q/DQ + json_model = copy.deepcopy(json.loads(model.to_json())) + config = json_model["config"] + layers = config["layers"] + q_layers = [] + for layer in layers: + layer_config = 
copy.deepcopy(layer['config']) + if layer['class_name'] == 'FakeQuant': + min_value = min(results[layer['config']['name']]['min']) + max_value = max(results[layer['config']['name']]['max']) + q_layers.append({'class_name': 'Quantize', + 'config': {'min_range': min_value, + 'max_range': max_value, + }}) + q_layers.append({'class_name': 'DeQuantize', + 'config': {'min_range': min_value, + 'max_range': max_value, + }}) + elif layer['class_name'] == 'Conv2D' or layer['class_name'] == 'Dense': + # index 0 is weight, index 1 is bias + q_layer_name = 'Q' + layer['class_name'] + kernel = self.layer_weights[layer['config']['name']][0] + layer_config['min_value'] = str(kernel.min()) + layer_config['max_value'] = str(kernel.max()) + q_layers.append({'class_name': q_layer_name, 'config': layer_config}) + else: + q_layers.append(layer) + + json_model['config']['layers'] = q_layers + quantized_model = self._restore_model_from_json(json_model) + return quantized_model + + def _restore_model_from_json(self, json_model): + # We need to keep a dictionary of custom objects as our quantized library + # is not recognized by keras. 
+ from tensorflow.keras.models import model_from_json + custom_objects = {} + custom_objects = _add_supported_quantized_objects(custom_objects) + qmodel = model_from_json(json.dumps(json_model), custom_objects=custom_objects) + qmodel = self._set_weights(qmodel, self.layer_weights) + return qmodel + + # set fp32 weights to qmodel + def _set_weights(self, qmodel, layer_weights): + for qlayer in qmodel.layers: + if qlayer.get_weights(): + if qlayer.name in layer_weights: + qlayer.set_weights(layer_weights[qlayer.name]) + else: + hit_layer = False + for sub_layer in qlayer.submodules: + if sub_layer.name in layer_weights: + qlayer.set_weights(layer_weights[sub_layer.name]) + hit_layer = True + break + if not hit_layer: + raise ValueError('Can not match the module weights....') + return qmodel + + + @dump_elapsed_time(customized_msg="Model inference") + def evaluate(self, model, dataloader, postprocess=None, + metric=None, measurer=None, iteration=-1, tensorboard=False): + '''The function is used to run evaluation on validation dataset. + + Args: + model (object): The model to do calibration. + dataloader (generator): generate the data and labels. + postprocess (object, optional): process the result from the model + metric (object, optional): Depends on model category. Defaults to None. + measurer (object, optional): for precise benchmark measurement. + iteration(int, optional): control steps of mini-batch + tensorboard (boolean, optional): for tensorboard inspect tensor. 
+ ''' + logger.info("Start to evaluate the Keras model.") + results = [] + for idx, (inputs, labels) in enumerate(dataloader): + # use predict on batch + if measurer is not None: + measurer.start() + predictions = model.predict_on_batch(inputs) + measurer.end() + else: + predictions = model.predict_on_batch(inputs) + + if self.fp32_preds_as_label: + self.fp32_results.append(predictions) if fp32_baseline else \ + results.append(predictions) + + if postprocess is not None: + predictions, labels = postprocess((predictions, labels)) + if metrics: + for metric in metrics: + if not hasattr(metric, "compare_label") or \ + (hasattr(metric, "compare_label") and metric.compare_label): + metric.update(predictions, labels) + if idx + 1 == iteration: + break + return results + + def query_fw_capability(self, model): + '''The function is used to return framework tuning capability. + + Args: + model (object): The model to query quantization tuning capability. + ''' + self.pre_optimized_model = model + fp32_config = {'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}} + int8_type = self.query_handler.get_op_types_by_precision(precision='int8') + op_capability = self.query_handler.get_quantization_capability() + conv_config = copy.deepcopy(op_capability['int8']['Conv2D']) + dense_config = copy.deepcopy(op_capability['int8']['Dense']) + other_config = copy.deepcopy(op_capability['int8']['default']) + + # get the layers info + keras_object = model._model_object + json_model = copy.deepcopy(json.loads(keras_object.to_json())) + config = json_model["config"] + self.fp32_layers = config["layers"] + + # get fp32 layer weights + self.layer_weights = {} + for layer in keras_object.layers: + if layer.get_weights(): + self.layer_weights[layer.name] = copy.deepcopy(layer.get_weights()) + + quantizable_op_details = OrderedDict() + for details in self.fp32_layers: + node_op = details['class_name'] + node_name = details['config']['name'] + if node_op == 'Conv2D': + # 
quantizable_op_details[(node_name, node_op)] = [conv_config, fp32_config] + quantizable_op_details[(node_name, node_op)] = [conv_config, fp32_config] + elif node_op == 'Dense': + quantizable_op_details[(node_name, node_op)] = [dense_config, fp32_config] + else: + quantizable_op_details[(node_name, node_op)] = [fp32_config] + + capability = { + 'opwise': copy.deepcopy(quantizable_op_details), + 'optypewise': self.get_optype_wise_ability(quantizable_op_details), + } + logger.debug("Dump framework quantization capability:") + logger.debug(capability) + + return capability + + def get_optype_wise_ability(self, quantizable_op_details): + """Get the op type wise capability by generating the union value of each op type. + Returns: + [string dict]: the key is op type while the value is the + detail configurations of activation and weight for this op type. + """ + res = OrderedDict() + for op in quantizable_op_details: + if op[1] not in res: + res[op[1]] = {'activation': quantizable_op_details[op][0]['activation']} + if 'weight' in quantizable_op_details[op][0]: + res[op[1]]['weight'] = quantizable_op_details[op][0]['weight'] + return res + + def inspect_tensor(self, model, dataloader, op_list=[], iteration_list=[], + inspect_type='activation', save_to_disk=False): + '''The function is used by tune strategy class for dumping tensor info. + + Args: + model (object): The model to inspect. + dataloader (object): The dataloader used to feed into. + op_list (list): The op name in the fp32 model for dumpping. + iteration_list (list): The iteration list containing iterations to dump. + inspect_type (str): The valid value are 'weight', 'activation', 'all'. + save_to_disk (bool): Save to disk or memory. + + Return: + Numpy Array Dict + { + 'weight': { + 'node0_name': {'weight0_name': numpy.array, 'bias0_name': numpy.array, ...}, + 'node1_name': {'weight1_name': numpy.array, 'bias1_name': numpy.array, ...}, + ... 
+ }, + 'activation': [ + # iter 0 + { + 'node0_name': {'output0_name': numpy.array, 'output1_name': numpy.array, ...} + 'node1_name': {'output1_name': numpy.array, 'output1_name': numpy.array, ...} + ... + }, + # iter 1 + ... + ] + } + ''' + pass + + def set_tensor(self, model, tensor_dict): + '''The function is used by tune strategy class for setting tensor back to model. + + Args: + model (object): The model to set tensor. Usually it is quantized model. + tensor_dict (dict): The tensor dict to set. Note the numpy array contains float + value, adaptor layer has the responsibility to quantize to + int8 or int32 to set into the quantized model if needed. + The dict format is something like: + { + 'weight0_name': numpy.array, + 'bias0_name': numpy.array, + ... + } + ''' + pass + + def quantize_input(self, model): + ''' quantize the model to be able to take quantized input + + Args: + model (object): The model to quantize input + + Return: + model (object): The quantized input model + scale (float): The scale for dataloader to generate quantized input + ''' + return model, 1. + + def _pre_eval_hook(self, model, *args, **kwargs): + '''The function is used to do some preprocession before evaluation phase. + + Return: + model + ''' + return model + + def _post_eval_hook(self, model, *args, **kwargs): + '''The function is used to do some post process after complete evaluation. + ''' + pass + + def save(self, model, path): + '''The function is used by tune strategy class for saving model. + + Args: + model (object): The model to saved. + path (string): The path where to save. + ''' + model.save(path) + + def convert(self, model, source, destinatin): + '''The function is used to convert a source model format to another. + + Args: + model (neural_compressor.model): base model to be converted. + source (string): The source model format. + destination (string): The destination model format. 
class KerasQuery(QueryBackendCapability):
    """Expose the quantization capabilities declared in the adaptor YAML file.

    The YAML file is parsed once at construction time; the section whose
    ``version.name`` equals the installed TensorFlow version is selected,
    falling back to the section named ``default``.
    """

    def __init__(self, local_config_file=None):
        """Load the capability file.

        Args:
            local_config_file (string): path of the YAML capability file.
        """
        super().__init__()
        self.version = tf.version.VERSION
        self.cfg = local_config_file
        self.cur_config = None
        self._one_shot_query()

    def _one_shot_query(self):
        """Parse ``self.cfg`` and cache the section for the current version.

        Raises:
            ValueError: if the parsed content cannot be interpreted; the
                original exception is attached as the cause.
        """
        with open(self.cfg) as f:
            content = yaml.safe_load(f)
            try:
                self.cur_config = self._get_specified_version_cfg(content)
            except Exception as e:
                logger.info("Fail to parse {} due to {}.".format(self.cfg, str(e)))
                self.cur_config = None
                # Chain the cause so the underlying parsing problem stays visible.
                raise ValueError("Please check if the format of {} follows Neural Compressor yaml schema.".
                                 format(self.cfg)) from e

    def _get_specified_version_cfg(self, data):
        """Get the configuration for the current runtime.
        If there's no matched configuration in the input yaml, we'll
        use the `default` field of yaml.

        Args:
            data (Yaml content): input yaml file.

        Returns:
            [dictionary]: the content for specific version, or None when the
                yaml has neither a matching nor a `default` section.
        """
        default_config = None
        for sub_data in data:
            if sub_data['version']['name'] == self.version:
                return sub_data

            if sub_data['version']['name'] == 'default':
                default_config = sub_data

        return default_config

    def get_version(self):
        """Get the current backend version infomation.

        Returns:
            [string]: version string.
        """
        return self.cur_config['version']['name']

    def get_precisions(self):
        """Get supported precisions for current backend.

        Returns:
            [string list]: the precisions' name.
        """
        return self.cur_config['precisions']['names']

    def get_op_types(self):
        """Get the supported op types by all precisions.

        Returns:
            [dictionary list]: A list composed of dictionary which key is precision
            and value is the op types.
        """
        return self.cur_config['ops']

    def get_quantization_capability(self):
        """Get the supported op types' quantization capability.

        Returns:
            [dictionary list]: A list composed of dictionary which key is precision
            and value is a dict that describes all op types' quantization capability.
        """
        return self.cur_config['capabilities']

    def get_op_types_by_precision(self, precision):
        """Get op types per precision

        Args:
            precision (string): precision name

        Returns:
            [string list]: A list composed of op type.
        """
        assert precision in self.cur_config['ops']
        return self.cur_config['ops'][precision]
+## +# +- + version: + name: 'default' + + precisions: &common_precisions + names: int8, fp32 + valid_mixed_precisions: [] + + ops: &common_ops + int8: ['Conv2D', 'Dense'] + fp32: ['*'] # '*' means all op types + + capabilities: &common_capabilities + int8: { + 'Conv2D': { + 'weight': { + 'dtype': ['int8'], + 'scheme': ['sym'], + 'granularity': ['per_tensor'], + 'algorithm': ['minmax'] + }, + 'activation': { + 'dtype': ['int8'], + 'scheme': ['sym'], + 'granularity': ['per_tensor'], + 'algorithm': ['minmax'] + } + }, + 'Dense': { + 'weight': { + 'dtype': ['int8'], + 'scheme': ['sym'], + 'algorithm': ['minmax'], + 'granularity': ['per_tensor'], + }, + 'activation': { + 'dtype': ['int8'], + 'scheme': ['sym'], + 'algorithm': ['minmax'], + 'granularity': ['per_tensor'], + } + }, + 'default': { + 'activation': { + 'dtype': ['int8'], + 'quant_mode': 'static', + 'scheme': ['sym'], + 'algorithm': ['minmax'], + 'granularity': ['per_tensor'] + } + }, + } diff --git a/neural_compressor/adaptor/keras_utils/quantizer.py b/neural_compressor/adaptor/keras_utils/quantizer.py new file mode 100644 index 00000000000..c9407a51af5 --- /dev/null +++ b/neural_compressor/adaptor/keras_utils/quantizer.py @@ -0,0 +1,268 @@ +# Copyright 2019 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ============================================================================== + +import sys +import numpy as np +import tensorflow as tf + +from tensorflow.python.eager import context +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import activations +from tensorflow.python.keras import backend +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.input_spec import InputSpec +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.util.tf_export import keras_export + +from tensorflow.keras.layers import Layer +from tensorflow.python.keras.layers.convolutional import Conv +from tensorflow.python.keras.layers.core import Dense + +class FakeQuant(Layer): + def __init__(self, mode='per_tensor', axis=None, **kwargs): + super(FakeQuant, self).__init__(**kwargs) + self.mode = mode + self.axis = axis + self.min_value = tf.constant(np.finfo(np.float32).max, dtype=tf.float32) + self.max_value = tf.constant(np.finfo(np.float32).min, dtype=tf.float32) + + def call(self, inputs): + if self.mode == 'per_tensor': + self.min_value = tf.math.reduce_min(inputs) + self.max_value = tf.math.reduce_max(inputs) + # self.min_value = tf.math.reduce_min(inputs) + # self.max_value = tf.math.reduce_max(inputs) + else: + self.min_value = tf.math.reduce_min(inputs, axis=self.axis) + 
class Quantize(Layer):
    """Keras layer wrapping ``tf.quantization.quantize`` with a fixed range.

    Args:
        min_range: lower bound of the float range; converted with ``float()``.
        max_range: upper bound of the float range; converted with ``float()``.
        T: quantized output dtype passed to ``tf.quantization.quantize``.
        mode: quantization mode passed through unchanged.
        round_mode: rounding mode passed through unchanged.
        narrow_range: passed through unchanged.
        axis: passed through unchanged.
    """

    def __init__(self, min_range, max_range, T=tf.qint8, mode='SCALED',
                 round_mode='HALF_AWAY_FROM_ZERO', narrow_range=False, axis=None):
        super(Quantize, self).__init__()
        self.min_range = float(min_range)
        self.max_range = float(max_range)
        self.T = T
        self.mode = mode
        self.round_mode = round_mode
        self.narrow_range = narrow_range
        self.axis = axis

    def call(self, inputs):
        """Return the quantized tensor; the returned min/max are discarded."""
        outputs, _, _ = tf.quantization.quantize(inputs, self.min_range,
                                                 self.max_range, self.T,
                                                 mode=self.mode, round_mode=self.round_mode,
                                                 narrow_range=self.narrow_range, axis=self.axis)
        return outputs

    def get_config(self):
        """Return the constructor arguments of this layer.

        Every key matches an ``__init__`` parameter name so that
        ``from_config(get_config())`` can rebuild the layer.
        """
        return {'min_range': self.min_range, 'max_range': self.max_range,
                'T': self.T, 'mode': self.mode, 'round_mode': self.round_mode,
                'narrow_range': self.narrow_range, 'axis': self.axis}

    @classmethod
    def from_config(cls, config):
        """Build a layer from a dict produced by ``get_config``."""
        return cls(**config)
bias_initializer=initializers.get(bias_initializer), + kernel_regularizer=regularizers.get(kernel_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + kernel_constraint=constraints.get(kernel_constraint), + bias_constraint=constraints.get(bias_constraint), **kwargs) + self.weight_quantizer = Quantize(float(min_value), float(max_value)) + self.weight_dequantizer = DeQuantize(float(min_value), float(max_value)) + + def call(self, inputs): + input_shape = inputs.shape + + if self._is_causal: # Apply causal padding to inputs for Conv1D. + inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs)) + + # add the Q/DQ here + kernel = self.weight_quantizer(self.kernel) + kernel = self.weight_dequantizer(kernel) + outputs = self._convolution_op(inputs, kernel) + + if self.use_bias: + output_rank = outputs.shape.rank + if self.rank == 1 and self._channels_first: + # nn.bias_add does not accept a 1D input tensor. + bias = array_ops.reshape(self.bias, (1, self.filters, 1)) + outputs += bias + else: + # Handle multiple batch dimensions. 
+ if output_rank is not None and output_rank > 2 + self.rank: + + def _apply_fn(o): + return nn.bias_add(o, self.bias, data_format=self._tf_data_format) + + outputs = conv_utils.squeeze_batch_dims( + outputs, _apply_fn, inner_rank=self.rank + 1) + else: + outputs = nn.bias_add( + outputs, self.bias, data_format=self._tf_data_format) + + if not context.executing_eagerly(): + # Infer the static output shape: + out_shape = self.compute_output_shape(input_shape) + outputs.set_shape(out_shape) + + if self.activation is not None: + return self.activation(outputs) + return outputs + +class QDense(Dense): + def __init__(self, + units, + activation=None, + use_bias=True, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + min_value=-10000, + max_value=10000, + **kwargs): + super(QDense, self).__init__( + units=units, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs) + self.weight_quantizer = Quantize(float(min_value), float(max_value)) + self.weight_dequantizer = DeQuantize(float(min_value), float(max_value)) + + def call(self, inputs): + if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype: + inputs = math_ops.cast(inputs, dtype=self._compute_dtype_object) + + # add the Q/DQ here + # (TODO) we have not try sparse dense and may have issues + kernel = self.weight_quantizer(self.kernel) + kernel = self.weight_dequantizer(kernel) + rank = inputs.shape.rank + if rank == 2 or rank is None: + # We use embedding_lookup_sparse as a more efficient matmul operation for + # large sparse input tensors. 
The op will result in a sparse gradient, as + # opposed to sparse_ops.sparse_tensor_dense_matmul which results in dense + # gradients. This can lead to significant speedups, see b/171762937. + if isinstance(inputs, sparse_tensor.SparseTensor): + # We need to fill empty rows, as the op assumes at least one id per row. + inputs, _ = sparse_ops.sparse_fill_empty_rows(inputs, 0) + # We need to do some munging of our input to use the embedding lookup as + # a matrix multiply. We split our input matrix into separate ids and + # weights tensors. The values of the ids tensor should be the column + # indices of our input matrix and the values of the weights tensor + # can continue to be the actual matrix weights. + # The column arrangement of ids and weights + # will be summed over and does not matter. See the documentation for + # sparse_ops.sparse_tensor_dense_matmul for a more detailed explanation + # of the inputs to both ops. + ids = sparse_tensor.SparseTensor( + indices=inputs.indices, + values=inputs.indices[:, 1], + dense_shape=inputs.dense_shape) + weights = inputs + outputs = embedding_ops.embedding_lookup_sparse_v2( + kernel, ids, weights, combiner='sum') + else: + outputs = gen_math_ops.MatMul(a=inputs, b=kernel) + # Broadcast kernel to inputs. + else: + outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]]) + # Reshape the output back to the original ndim of the input. 
+ if not context.executing_eagerly(): + shape = inputs.shape.as_list() + output_shape = shape[:-1] + [kernel.shape[-1]] + outputs.set_shape(output_shape) + + if self.use_bias: + outputs = nn_ops.bias_add(outputs, self.bias) + + if self.activation is not None: + outputs = self.activation(outputs) + return outputs + + +class DeQuantize(Layer): + def __init__(self, min_range, max_range, mode='SCALED', + narrow_range=False, axis=None): + super(DeQuantize, self).__init__() + self.min_range = min_range + self.max_range = max_range + self.mode = mode + self.narrow_range = narrow_range + self.axis = axis + + def call(self, inputs): + return tf.quantization.dequantize(inputs, float(self.min_range), + float(self.max_range), mode=self.mode, + narrow_range=self.narrow_range, axis=self.axis) + def get_config(self): + return {'min_range': self.min_range, 'max_range': self.max_range, + 'mode': self.mode, 'narrow': self.narrow_range, 'axis': self.axis, + 'dtype': self.dtype} + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/neural_compressor/conf/pythonic_config.py b/neural_compressor/conf/pythonic_config.py index c9975a9ebc6..708407cd24c 100644 --- a/neural_compressor/conf/pythonic_config.py +++ b/neural_compressor/conf/pythonic_config.py @@ -183,6 +183,10 @@ class TensorFlow(MXNet): def __init__(self, precisions=None): super().__init__(precisions) +class Keras(MXNet): + def __init__(self, precisions=None): + super().__init__(precisions) + class PyTorch(MXNet): def __init__(self, precisions=None): super().__init__(precisions) @@ -241,6 +245,7 @@ def search(self, search): nas = NASConfig() onnxruntime_config = ONNX() tensorflow_config = TensorFlow() +keras_config = Keras() pytorch_config = PyTorch() mxnet_config = MXNet() @@ -256,7 +261,8 @@ def __init__(self, onnxruntime=onnxruntime_config, tensorflow=tensorflow_config, pytorch=pytorch_config, - mxnet=mxnet_config): + mxnet=mxnet_config, + keras=keras_config): self._quantization = quantization 
self._benchmark = benchmark self._options = options @@ -267,6 +273,7 @@ def __init__(self, self._tensorflow = tensorflow self._pytorch = pytorch self._mxnet = mxnet + self._keras = keras @property def distillation(self): @@ -280,6 +287,10 @@ def nas(self): def tensorflow(self): return self._tensorflow + @property + def keras(self): + return self._keras + @property def pytorch(self): return self._pytorch diff --git a/neural_compressor/experimental/data/dataloaders/dataloader.py b/neural_compressor/experimental/data/dataloaders/dataloader.py index c3463b875eb..c879b0b45d1 100644 --- a/neural_compressor/experimental/data/dataloaders/dataloader.py +++ b/neural_compressor/experimental/data/dataloaders/dataloader.py @@ -25,6 +25,7 @@ DATALOADERS = {"tensorflow": TensorflowDataLoader, "tensorflow_itex": TensorflowDataLoader, + "keras": TensorflowDataLoader, "mxnet": MXNetDataLoader, "pytorch": PyTorchDataLoader, "pytorch_ipex": PyTorchDataLoader, diff --git a/neural_compressor/model/keras_model.py b/neural_compressor/model/keras_model.py new file mode 100644 index 00000000000..a4e248c54fb --- /dev/null +++ b/neural_compressor/model/keras_model.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from abc import abstractmethod +from neural_compressor.model.base_model import BaseModel +from neural_compressor.utils.utility import LazyImport +tf = LazyImport('tensorflow') + +class KerasModel(BaseModel): + """Build KerasModel object + + Args: + model (string or keras model object): model path or model object + kwargs (dict): other required parameters + + """ + + def __init__(self, model, **kwargs): + self.component = None + self._model = model + if not isinstance(model, tf.keras.Model): + self._model_object = tf.keras.models.load_model(self._model) + else: + self._model_object = self._model + self._q_config = None + + @property + def q_config(self): + return self._q_config + + @q_config.setter + def q_config(self, q_config): + self._q_config = q_config + + @property + def model(self): + return self._model + + @property + def graph_info(self): + ''' return {Node: Node_type} like {'conv0': 'conv2d'} ''' + #(TODO) get the graph info + return None + + @abstractmethod + def save(self, root, *args, **kwargs): + self._model_object.save(root) + + @abstractmethod + def export( + self, + save_path: str, + conf, + ): + pass + + @abstractmethod + def framework(self): + return 'keras' + diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index 59a87d51a29..bc63217d7e5 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -28,6 +28,7 @@ from neural_compressor.conf import config as cfg from neural_compressor.model.base_model import BaseModel from neural_compressor.model.onnx_model import ONNXModel +from neural_compressor.model.keras_model import KerasModel TORCH = False if importlib.util.find_spec('torch'): @@ -60,35 +61,13 @@ def get_model_type(model): return 'graph' elif isinstance(model, tf.compat.v1.GraphDef): return 'graph_def' - elif isinstance(model, tf.keras.Model): - return 'keras' elif isinstance(model, tf.compat.v1.estimator.Estimator): return 'estimator' elif isinstance(model, str): model 
= os.path.abspath(os.path.expanduser(model)) - if (model.endswith('.h5') and os.path.isfile(model)): - if version1_lt_version2(tf.version.VERSION, '2.3.0'): - logger.warn("keras model running on tensorflow 2.2.0 and" - " lower may have problem.") - model = tf.keras.models.load_model(model) - if isinstance(model, tf.keras.Model): - return 'keras' if (model.endswith('.pb') and os.path.isfile(model)): if is_saved_model_format(os.path.dirname(model)): - # Warning: TF compatibility issue to load saved model. TF 2.3 keras.load - # can load saved model from TF backend, but TF 2.4 cannot. - try: - if version1_lt_version2(tf.version.VERSION, '2.3.0'): - logger.warn("keras model running on tensorflow 2.2.0 and" - " lower may have problem.") - model = tf.keras.models.load_model(model) - if isinstance(model, tf.keras.Model): - return 'keras' - else: - return 'saved_model' - except: - # can't use keras load - return 'saved_model' + return 'saved_model' else: return 'frozen_pb' elif model.endswith('.ckpt') and os.path.isfile(model): @@ -97,20 +76,7 @@ def get_model_type(model): if is_ckpt_format(model): return 'checkpoint' elif is_saved_model_format(model): - # it's very ugly tf version issue, in tf2.3 keras.load can - #batch_size_(batch_size), load saved model from tf backend, but tf2.4 it will crash - try: - if version1_lt_version2(tf.version.VERSION, '2.3.0'): - logger.warn("keras model running on tensorflow 2.2.0 and" - " lower may have problem.") - model = tf.keras.models.load_model(model) - if isinstance(model, tf.keras.Model): - return 'keras' - else: - return 'saved_model' - except: - # can't use keras load - return 'saved_model' + return 'saved_model' elif os.path.isfile(model + '.pb'): return 'frozen_pb' @@ -152,6 +118,25 @@ def _is_pytorch(model): except: return 'NA' + def _is_keras(model): + if isinstance(model, str): + model = os.path.abspath(os.path.expanduser(model)) + if (model.endswith('.h5') and os.path.isfile(model)) or \ + 
is_saved_model_format(os.path.dirname(model)) or \ + (os.path.isdir(model) and is_saved_model_format(model)): + if version1_lt_version2(tf.version.VERSION, '2.10.0'): + logger.warn("keras model running on tensorflow 2.10.0 and" + " lower not support intel ITEX.") + return 'NA' + try: + model = tf.keras.models.load_model(model) + except: + return 'NA' + if isinstance(model, tf.keras.Model) and hasattr(model, 'to_json'): + return 'keras' + else: + return 'NA' + def _is_tensorflow(model): try: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -186,7 +171,7 @@ def _is_mxnet(model): if isinstance(model, TensorflowBaseModel): return 'tensorflow' - checker = [_is_tensorflow, _is_pytorch, _is_onnxruntime, _is_mxnet] + checker = [_is_keras, _is_tensorflow, _is_pytorch, _is_onnxruntime, _is_mxnet] for handler in checker: fwk_name = handler(model) if fwk_name != 'NA': @@ -407,28 +392,6 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_ grappler_meta_graph_def, graph_id=b"tf_graph") return opt, input_tensor_names, output_tensor_names -def check_keras_format(model, saved_model_dir): - from tensorflow.python import saved_model - from tensorflow.python.saved_model.load import load - from tensorflow.python.saved_model import save_options - from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info - version = 'saved_model_v2' - try: - saved_model.save( - model, - saved_model_dir, - options=save_options.SaveOptions(save_debug_info=True)) - except: - return 'trackable_object' - saved_model_proto, _ = parse_saved_model_with_debug_info(saved_model_dir) - saved_model_version = saved_model_proto.saved_model_schema_version - if saved_model_version == 0: - return 'saved_model_v1' - if saved_model_version not in [1, 2]: - raise ValueError("SavedModel file format({0}) is not supported".format( - saved_model_version)) - return version - def get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names): 
from tensorflow.python.saved_model import tag_constants @@ -527,52 +490,6 @@ def get_graph_from_saved_model_v1(model): sess, graph.as_graph_def(), output_nodes) return graph_def, inputs, outputs -def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): - """Build session with keras model - - Args: - model (string or tf.keras.Model): model path or tf.keras.Model object - input_tensor_names (list of string): input_tensor_names of model - output_tensor_names (list of string): output_tensor_names of model - - Returns: - sess (tf.compat.v1.Session): tf.compat.v1.Session object - input_tensor_names (list of string): validated input_tensor_names - output_tensor_names (list of string): validated output_tensor_names - """ - temp_dir = tempfile.mkdtemp() - if tf.version.VERSION > '2.1.0': - if not isinstance(model, tf.keras.Model): - model = tf.keras.models.load_model(model) - keras_format = check_keras_format(model, temp_dir) - if keras_format == 'saved_model_v2': - try: - graph_def, input_names, output_names = get_graph_from_saved_model_v2( - temp_dir, input_tensor_names, output_tensor_names) - if '_FusedBatchNormEx' in [node.op for node in graph_def.node]: - keras_format = 'trackable_object' - except: - keras_format = 'trackable_object' - if keras_format == 'trackable_object': - try: - graph_def, input_names, output_names = get_graph_from_original_keras_v2( - model, temp_dir) - except: - keras_format = 'saved_model_v1' - if keras_format == 'saved_model_v1': - try: - tf.keras.backend.set_learning_phase(0) - graph_def, input_names, output_names = get_graph_from_saved_model_v1(model) - except: - raise ValueError('Not supported keras model type...') - - # tensorflow 1.x use v1 convert method - else: - tf.keras.backend.set_learning_phase(0) - graph_def, input_names, output_names = get_graph_from_saved_model_v1(model) - shutil.rmtree(temp_dir, True) - return graph_def_session(graph_def, input_names, output_names, **kwargs) - def slim_session(model, 
input_tensor_names, output_tensor_names, **kwargs): """Build session with slim model @@ -738,7 +655,6 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs 'graph_def': graph_def_session, 'graph': graph_session, 'saved_model': saved_model_session, - 'keras': keras_session, 'checkpoint': checkpoint_session, 'estimator': estimator_session, 'slim': slim_session,} @@ -1107,7 +1023,7 @@ def graph_def(self, graph_def): 'estimator': TensorflowBaseModel, 'slim': TensorflowBaseModel, 'saved_model': TensorflowSavedModelModel, - 'keras': TensorflowSavedModelModel,} + } class TensorflowModel(object): def __new__(cls, model_type, root, **kwargs): @@ -1162,6 +1078,7 @@ def save(self, root=None): MODELS = {'tensorflow': TensorflowModel, 'tensorflow_itex': TensorflowModel, + 'keras': KerasModel, 'mxnet': MXNetModel, 'pytorch': PyTorchModel if TORCH else None, 'pytorch_ipex': PyTorchIpexModel if TORCH else None, diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index 7be1897a948..59fbb88ebfa 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -532,6 +532,9 @@ def set_framework_info(self, q_dataloader, q_func=None): 'recipes': self.cfg.quantization.recipes, 'performance_only': self.cfg.tuning.exit_policy.performance_only, 'use_bf16': self.cfg.use_bf16 if self.cfg.use_bf16 is not None else False}) + if 'keras' in framework: + framework_specific_info.update({ + 'workspace_path': self.cfg.tuning.workspace.path, }) if framework == 'mxnet': framework_specific_info.update({"q_dataloader": q_dataloader}) if 'onnxrt' in framework.lower(): From 9c18029c306bdf7ff8f71d6b0e29cd5ee205bb2c Mon Sep 17 00:00:00 2001 From: Clark Chin Date: Thu, 8 Dec 2022 16:26:06 +0800 Subject: [PATCH 02/14] clean the code --- examples/keras/{minist => mnist}/README.md | 0 examples/keras/{minist => mnist}/conf.yaml | 0 examples/keras/{minist => mnist}/mnist.py | 5 ---- .../keras/{minist => 
mnist}/requirements.txt | 0 neural_compressor/adaptor/keras.py | 24 +++---------------- 5 files changed, 3 insertions(+), 26 deletions(-) rename examples/keras/{minist => mnist}/README.md (100%) rename examples/keras/{minist => mnist}/conf.yaml (100%) rename examples/keras/{minist => mnist}/mnist.py (93%) rename examples/keras/{minist => mnist}/requirements.txt (100%) diff --git a/examples/keras/minist/README.md b/examples/keras/mnist/README.md similarity index 100% rename from examples/keras/minist/README.md rename to examples/keras/mnist/README.md diff --git a/examples/keras/minist/conf.yaml b/examples/keras/mnist/conf.yaml similarity index 100% rename from examples/keras/minist/conf.yaml rename to examples/keras/mnist/conf.yaml diff --git a/examples/keras/minist/mnist.py b/examples/keras/mnist/mnist.py similarity index 93% rename from examples/keras/minist/mnist.py rename to examples/keras/mnist/mnist.py index 68d53da32c8..ea35307b468 100644 --- a/examples/keras/minist/mnist.py +++ b/examples/keras/mnist/mnist.py @@ -64,11 +64,6 @@ def build_model(x_train, y_train, x_test, y_test): metrics=["accuracy"], run_eagerly=True) model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1) model.summary() - # start_time = time.time() - # fp32_score = model.evaluate(x_test, y_test, verbose=0) - # duration = time.time() - start_time - # print("Fp32 accuracy:", fp32_score[1]) - # print("Fp32 duration:", duration) if not os.path.exists('fp32_model'): model.save('fp32_model') return model diff --git a/examples/keras/minist/requirements.txt b/examples/keras/mnist/requirements.txt similarity index 100% rename from examples/keras/minist/requirements.txt rename to examples/keras/mnist/requirements.txt diff --git a/neural_compressor/adaptor/keras.py b/neural_compressor/adaptor/keras.py index b396fb22e65..bf7d698b864 100644 --- a/neural_compressor/adaptor/keras.py +++ b/neural_compressor/adaptor/keras.py @@ -29,17 +29,8 @@ from ..utils import logger from 
..conf.dotdict import deep_get from ..experimental.data.dataloaders.base_dataloader import BaseDataLoader - tf = LazyImport('tensorflow') -# REGISTERED_LAYERS = [ -# "Quantize", -# "DeQuantize", -# "FakeQuant", -# "QConv2D", -# "QDense", -# ] - def _add_supported_quantized_objects(custom_objects): """Map all the quantized objects.""" from neural_compressor.adaptor.keras_utils.quantizer import Quantize, DeQuantize @@ -66,10 +57,9 @@ def __init__(self, framework_specific_info): self.recipes = deep_get(self.framework_specific_info, 'recipes', {}) os.makedirs(self.work_dir, exist_ok=True) + self.pre_optimized_model = None self.pre_optimizer_handle = None - self.fp32_ops = [] - self.query_handler = KerasQuery(local_config_file=os.path.join( os.path.dirname(__file__), 'keras.yaml')) @@ -79,14 +69,12 @@ def __init__(self, framework_specific_info): self.callbacks = [] self.optype_statistics = None - def tuning_cfg_to_fw(self, tuning_cfg): self.quantize_config['calib_iteration'] = tuning_cfg['calib_iteration'] self.quantize_config['device'] = self.device self.quantize_config['advance'] = deep_get(tuning_cfg, 'advance') fp32_ops = [] dispatched_op_names = [j[0] for j in tuning_cfg['op']] - invalid_op_names = [i for i in self.quantize_config['op_wise_config'] if i not in dispatched_op_names] @@ -95,7 +83,6 @@ def tuning_cfg_to_fw(self, tuning_cfg): for each_op_info in tuning_cfg['op']: op_name = each_op_info[0] - if tuning_cfg['op'][each_op_info]['activation']['dtype'] == 'fp32': if op_name in self.quantize_config['op_wise_config']: self.quantize_config['op_wise_config'].pop(op_name) @@ -109,9 +96,7 @@ def tuning_cfg_to_fw(self, tuning_cfg): 'granularity'] == 'per_channel' #bit = tuning_cfg['op'][each_op_info]['weight']['bit'] weight_bit = bit if bit else 7.0 - algorithm = tuning_cfg['op'][each_op_info]['activation']['algorithm'] - is_asymmetric = False if 'activation' in tuning_cfg['op'][each_op_info]: is_asymmetric = tuning_cfg['op'][each_op_info]['activation']['scheme'] == 
'asym' @@ -131,7 +116,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): dataloader(object): The dataloader used to load quantization dataset. q_func (optional): training function for quantization aware training mode. ''' - import pdb; pdb.set_trace() self.tuning_cfg_to_fw(tune_cfg) logger.debug("Dump quantization configurations:") logger.debug(self.quantize_config) @@ -240,10 +224,10 @@ def _calibrate(self, model, dataloader, calib_interation): return quantized_model def _restore_model_from_json(self, json_model): - # We need to keep a dictionary of custom objects as our quantized library - # is not recognized by keras. from tensorflow.keras.models import model_from_json custom_objects = {} + # We need to keep a dictionary of custom objects as our quantized library + # is not recognized by keras. custom_objects = _add_supported_quantized_objects(custom_objects) qmodel = model_from_json(json.dumps(json_model), custom_objects=custom_objects) qmodel = self._set_weights(qmodel, self.layer_weights) @@ -266,7 +250,6 @@ def _set_weights(self, qmodel, layer_weights): raise ValueError('Can not match the module weights....') return qmodel - @dump_elapsed_time(customized_msg="Model inference") def evaluate(self, model, dataloader, postprocess=None, metric=None, measurer=None, iteration=-1, tensorboard=False): @@ -464,7 +447,6 @@ def convert(self, model, source, destinatin): pass class KerasQuery(QueryBackendCapability): - def __init__(self, local_config_file=None): super().__init__() self.version = tf.version.VERSION From 69911c15fe38ff5714340cfd1d30faf315a5bb01 Mon Sep 17 00:00:00 2001 From: Clark Chin Date: Thu, 8 Dec 2022 17:31:04 +0800 Subject: [PATCH 03/14] fix the tune_cfg issue --- neural_compressor/adaptor/keras.py | 10 ++++++---- neural_compressor/adaptor/keras.yaml | 2 ++ neural_compressor/adaptor/keras_utils/quantizer.py | 8 +++----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/neural_compressor/adaptor/keras.py 
b/neural_compressor/adaptor/keras.py index bf7d698b864..5f031d5b37b 100644 --- a/neural_compressor/adaptor/keras.py +++ b/neural_compressor/adaptor/keras.py @@ -146,13 +146,16 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): format(calib_sampling_size, dataloader.batch_size, dataloader.batch_size * iter)) q_layers = [] - # check the tuning_config here for idx, layer in enumerate(self.fp32_layers): layer_config = layer["config"] - if layer["class_name"] in ["Conv2D", "Dense"]: + if layer["class_name"] in ["Conv2D", "Dense"] and \ + layer['config']['name'] in self.quantize_config['op_wise_config']: + op_config = self.quantize_config['op_wise_config'][layer['config']['name']] + mode = 'per_channel' if op_config[0] else 'per_tensor' + #(TODO) support asym/sym fake_quant_name = 'fake_quant_' + str(idx) q_layers.append({'class_name': 'FakeQuant', - 'config': {'mode': 'per_tensor', 'axis': 1, 'name': fake_quant_name}}) + 'config': {'mode': 'per_tensor', 'name': fake_quant_name}}) q_layers.append(layer) else: q_layers.append(layer) @@ -321,7 +324,6 @@ def query_fw_capability(self, model): node_op = details['class_name'] node_name = details['config']['name'] if node_op == 'Conv2D': - # quantizable_op_details[(node_name, node_op)] = [conv_config, fp32_config] quantizable_op_details[(node_name, node_op)] = [conv_config, fp32_config] elif node_op == 'Dense': quantizable_op_details[(node_name, node_op)] = [dense_config, fp32_config] diff --git a/neural_compressor/adaptor/keras.yaml b/neural_compressor/adaptor/keras.yaml index caac51828c4..291eb43dc0d 100644 --- a/neural_compressor/adaptor/keras.yaml +++ b/neural_compressor/adaptor/keras.yaml @@ -36,6 +36,7 @@ }, 'activation': { 'dtype': ['int8'], + 'quant_mode': 'static', 'scheme': ['sym'], 'granularity': ['per_tensor'], 'algorithm': ['minmax'] @@ -50,6 +51,7 @@ }, 'activation': { 'dtype': ['int8'], + 'quant_mode': 'static', 'scheme': ['sym'], 'algorithm': ['minmax'], 'granularity': ['per_tensor'], diff --git 
a/neural_compressor/adaptor/keras_utils/quantizer.py b/neural_compressor/adaptor/keras_utils/quantizer.py index c9407a51af5..241721d517e 100644 --- a/neural_compressor/adaptor/keras_utils/quantizer.py +++ b/neural_compressor/adaptor/keras_utils/quantizer.py @@ -46,10 +46,10 @@ from tensorflow.python.keras.layers.core import Dense class FakeQuant(Layer): - def __init__(self, mode='per_tensor', axis=None, **kwargs): + def __init__(self, mode='per_tensor', **kwargs): super(FakeQuant, self).__init__(**kwargs) self.mode = mode - self.axis = axis + self.axis = 1 if mode == 'per_channel' else 0 self.min_value = tf.constant(np.finfo(np.float32).max, dtype=tf.float32) self.max_value = tf.constant(np.finfo(np.float32).min, dtype=tf.float32) @@ -57,8 +57,6 @@ def call(self, inputs): if self.mode == 'per_tensor': self.min_value = tf.math.reduce_min(inputs) self.max_value = tf.math.reduce_max(inputs) - # self.min_value = tf.math.reduce_min(inputs) - # self.max_value = tf.math.reduce_max(inputs) else: self.min_value = tf.math.reduce_min(inputs, axis=self.axis) self.max_value = tf.math.reduce_max(inputs, axis=self.axis) @@ -69,7 +67,7 @@ def from_config(cls, config): return cls(**config) def get_config(self): - return {'mode': self.mode, 'axis': self.axis, + return {'mode': self.mode, 'min_value': self.min_value.numpy(), 'max_value': self.max_value.numpy(), 'name': self.name} From 8c6c743d45534de68a086151c1da568761b4741d Mon Sep 17 00:00:00 2001 From: Clark Chin Date: Fri, 9 Dec 2022 11:08:06 +0800 Subject: [PATCH 04/14] fix the dataloader issue --- examples/keras/mnist/mnist.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/keras/mnist/mnist.py b/examples/keras/mnist/mnist.py index ea35307b468..d7997ef6ba8 100644 --- a/examples/keras/mnist/mnist.py +++ b/examples/keras/mnist/mnist.py @@ -38,6 +38,17 @@ def build_dataset(): y_test = keras.utils.to_categorical(y_test, num_classes) return x_train, y_train, x_test, y_test +class Dataset(): + 
def __init__(self, ): + self.inputs, self.labels, _, _ = build_dataset() + + def __getitem__(self, idx): + return self.inputs[idx], self.labels[idx] + + def __len__(self): + assert len(self.inputs) == len(self.labels), 'inputs should have equal len with labels' + return len(self.inputs) + def build_model(x_train, y_train, x_test, y_test): if os.path.exists('fp32_model'): model = keras.models.load_model('fp32_model') @@ -80,7 +91,7 @@ def main(): from neural_compressor.experimental import Quantization, common quantizer = Quantization('./conf.yaml') quantizer.model = common.Model(model) - quantizer.calib_dataloader = common.DataLoader((x_train[:100], y_train[:100]), batch_size=10) + quantizer.calib_dataloader = common.DataLoader(Dataset(), batch_size=10) quantizer.eval_func = eval_func quantized_model = quantizer.fit() From a9f3dab6c145eb33cfe6b893a682897f5b7d47d5 Mon Sep 17 00:00:00 2001 From: Clark Chin Date: Fri, 9 Dec 2022 12:07:05 +0800 Subject: [PATCH 05/14] fix several issues 1. fix keras saved model load issue 2. fix evaluate keras object/model object issue 3. fix metric registry 4. fix benchmark performance/acc 5. 
fix dataloader mis-batch --- neural_compressor/adaptor/keras.py | 10 +++++++--- neural_compressor/experimental/benchmark.py | 2 ++ neural_compressor/experimental/metric/metric.py | 4 +++- neural_compressor/model/keras_model.py | 2 +- neural_compressor/model/model.py | 1 + 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/neural_compressor/adaptor/keras.py b/neural_compressor/adaptor/keras.py index 5f031d5b37b..76b6b47723e 100644 --- a/neural_compressor/adaptor/keras.py +++ b/neural_compressor/adaptor/keras.py @@ -255,7 +255,8 @@ def _set_weights(self, qmodel, layer_weights): @dump_elapsed_time(customized_msg="Model inference") def evaluate(self, model, dataloader, postprocess=None, - metric=None, measurer=None, iteration=-1, tensorboard=False): + metrics=None, measurer=None, iteration=-1, + tensorboard=False, fp32_baseline=False): '''The function is used to run evaluation on validation dataset. Args: @@ -266,17 +267,20 @@ def evaluate(self, model, dataloader, postprocess=None, measurer (object, optional): for precise benchmark measurement. iteration(int, optional): control steps of mini-batch tensorboard (boolean, optional): for tensorboard inspect tensor. 
+ fp32_baseline (boolen, optional): only for compare_label=False pipeline ''' + # use keras object + keras_model = model.model logger.info("Start to evaluate the Keras model.") results = [] for idx, (inputs, labels) in enumerate(dataloader): # use predict on batch if measurer is not None: measurer.start() - predictions = model.predict_on_batch(inputs) + predictions = keras_model.predict_on_batch(inputs) measurer.end() else: - predictions = model.predict_on_batch(inputs) + predictions = keras_model.predict_on_batch(inputs) if self.fp32_preds_as_label: self.fp32_results.append(predictions) if fp32_baseline else \ diff --git a/neural_compressor/experimental/benchmark.py b/neural_compressor/experimental/benchmark.py index 00329dabd43..f35e145551f 100644 --- a/neural_compressor/experimental/benchmark.py +++ b/neural_compressor/experimental/benchmark.py @@ -305,6 +305,8 @@ def run_instance(self, mode): "outputs": cfg.model.outputs, \ "recipes": cfg.model.recipes, \ 'workspace_path': cfg.tuning.workspace.path}) + if framework == 'keras': + framework_specific_info.update({'workspace_path': cfg.tuning.workspace.path}) if framework == 'mxnet': framework_specific_info.update({"b_dataloader": self._b_dataloader}) if 'onnxrt' in framework.lower(): diff --git a/neural_compressor/experimental/metric/metric.py b/neural_compressor/experimental/metric/metric.py index b02b52cc861..31f0550b071 100644 --- a/neural_compressor/experimental/metric/metric.py +++ b/neural_compressor/experimental/metric/metric.py @@ -113,6 +113,7 @@ def __init__(self) -> None: framework_metrics = {"tensorflow": TensorflowMetrics, "tensorflow_itex": TensorflowMetrics, + "keras": TensorflowMetrics, "mxnet": MXNetMetrics, "pytorch": PyTorchMetrics, "pytorch_ipex": PyTorchMetrics, @@ -132,6 +133,7 @@ def __init__(self) -> None: registry_metrics = {"tensorflow": TENSORFLOW_METRICS, "tensorflow_itex": TENSORFLOW_ITEX_METRICS, + "keras": TENSORFLOW_METRICS, "mxnet": MXNET_METRICS, "pytorch": PYTORCH_METRICS, 
"pytorch_ipex": PYTORCH_METRICS, @@ -156,7 +158,7 @@ def __init__(self, framework: str): Args: framework: The framwork name. """ - assert framework in ("tensorflow", "tensorflow_itex", + assert framework in ("tensorflow", "tensorflow_itex","keras", "pytorch", "pytorch_ipex", "pytorch_fx", "onnxrt_qdq", "onnxrt_qlinearops", "onnxrt_integerops", "mxnet", "onnxrt_qoperator"), \ diff --git a/neural_compressor/model/keras_model.py b/neural_compressor/model/keras_model.py index a4e248c54fb..f7034e4e635 100644 --- a/neural_compressor/model/keras_model.py +++ b/neural_compressor/model/keras_model.py @@ -49,7 +49,7 @@ def q_config(self, q_config): @property def model(self): - return self._model + return self._model_object @property def graph_info(self): diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index bc63217d7e5..96fff1c7417 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -120,6 +120,7 @@ def _is_pytorch(model): def _is_keras(model): if isinstance(model, str): + from neural_compressor.adaptor.tf_utils.util import is_saved_model_format model = os.path.abspath(os.path.expanduser(model)) if (model.endswith('.h5') and os.path.isfile(model)) or \ is_saved_model_format(os.path.dirname(model)) or \ From fa5e6abcfb34cdf0050456bcdd011a50f0b2d836 Mon Sep 17 00:00:00 2001 From: "Lv, Liang1" Date: Fri, 9 Dec 2022 13:41:53 +0800 Subject: [PATCH 06/14] add ut case Signed-off-by: Lv, Liang1 --- examples/keras/mnist/mnist.py | 17 +- neural_compressor/adaptor/keras.py | 4 +- .../adaptor/keras_utils/quantizer.py | 8 +- neural_compressor/experimental/benchmark.py | 5 +- neural_compressor/model/model.py | 1 + test/itex/test_keras_in_keras_out.py | 188 ++++++++++++++++++ 6 files changed, 209 insertions(+), 14 deletions(-) create mode 100644 test/itex/test_keras_in_keras_out.py diff --git a/examples/keras/mnist/mnist.py b/examples/keras/mnist/mnist.py index ea35307b468..b912c3ab51c 100644 --- 
a/examples/keras/mnist/mnist.py +++ b/examples/keras/mnist/mnist.py @@ -77,14 +77,17 @@ def eval_func(model): def main(): x_train, y_train, x_test, y_test = build_dataset() model = build_model(x_train, y_train, x_test, y_test) - from neural_compressor.experimental import Quantization, common - quantizer = Quantization('./conf.yaml') - quantizer.model = common.Model(model) - quantizer.calib_dataloader = common.DataLoader((x_train[:100], y_train[:100]), batch_size=10) - quantizer.eval_func = eval_func - quantized_model = quantizer.fit() - + from neural_compressor.quantization import fit + from neural_compressor.config import PostTrainingQuantConfig, set_random_seed + from neural_compressor.experimental import common + set_random_seed(9527) + config = PostTrainingQuantConfig() + quantized_model = fit(model, + conf=config, + calib_dataloader=common.DataLoader((x_train[:100], y_train[:100]), batch_size=10), + eval_func=eval_func) + if __name__ == '__main__': main() diff --git a/neural_compressor/adaptor/keras.py b/neural_compressor/adaptor/keras.py index 5f031d5b37b..8dff2d24d24 100644 --- a/neural_compressor/adaptor/keras.py +++ b/neural_compressor/adaptor/keras.py @@ -53,9 +53,9 @@ def __init__(self, framework_specific_info): self.approach = deep_get(self.framework_specific_info, 'approach', False) self.quantize_config = {'op_wise_config': {}} self.device = self.framework_specific_info['device'] - self.work_dir = os.path.abspath(self.framework_specific_info['workspace_path']) + #self.work_dir = os.path.abspath(self.framework_specific_info['workspace_path']) self.recipes = deep_get(self.framework_specific_info, 'recipes', {}) - os.makedirs(self.work_dir, exist_ok=True) + #os.makedirs(self.work_dir, exist_ok=True) self.pre_optimized_model = None self.pre_optimizer_handle = None diff --git a/neural_compressor/adaptor/keras_utils/quantizer.py b/neural_compressor/adaptor/keras_utils/quantizer.py index 241721d517e..76742001ded 100644 --- 
a/neural_compressor/adaptor/keras_utils/quantizer.py +++ b/neural_compressor/adaptor/keras_utils/quantizer.py @@ -1,19 +1,19 @@ -# Copyright 2019 Google LLC +#!/usr/bin/env python +# -*- coding: utf-8 -*- # +# Copyright (c) 2022 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# ============================================================================== import sys import numpy as np diff --git a/neural_compressor/experimental/benchmark.py b/neural_compressor/experimental/benchmark.py index 00329dabd43..bff0633e72b 100644 --- a/neural_compressor/experimental/benchmark.py +++ b/neural_compressor/experimental/benchmark.py @@ -338,7 +338,9 @@ def run_instance(self, mode): b_dataloader_cfg = deep_get(cfg, 'evaluation.{}.dataloader'.format(mode)) self._b_dataloader = create_dataloader(self.framework, b_dataloader_cfg) + is_measure = False if self._b_func is None: + is_measure = True self._b_func = create_eval_func(self.framework, \ self._b_dataloader, \ adaptor, \ @@ -353,10 +355,11 @@ def run_instance(self, mode): assert len(objectives) == 1, 'benchmark supports one objective at a time' self.objectives = MultiObjective(objectives, cfg.tuning.accuracy_criterion, - is_measure=True) + is_measure=is_measure) if self._custom_b_func: val = self.objectives.evaluate(self._b_func, self._model.model) + return else: val = self.objectives.evaluate(self._b_func, self._model) # measurer contain info not only performance(eg, 
memory, model_size) diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index bc63217d7e5..0528f75a7f5 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -119,6 +119,7 @@ def _is_pytorch(model): return 'NA' def _is_keras(model): + from neural_compressor.adaptor.tf_utils.util import is_saved_model_format if isinstance(model, str): model = os.path.abspath(os.path.expanduser(model)) if (model.endswith('.h5') and os.path.isfile(model)) or \ diff --git a/test/itex/test_keras_in_keras_out.py b/test/itex/test_keras_in_keras_out.py new file mode 100644 index 00000000000..d2b9e888122 --- /dev/null +++ b/test/itex/test_keras_in_keras_out.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import time +import shutil +import numpy as np +import tensorflow as tf +from tensorflow import keras +from neural_compressor.utils import logger + +test_mode = 'accuracy' + +def build_model(): + # Load MNIST dataset + mnist = keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + + # Normalize the input image so that each pixel value is between 0 to 1. + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + # Define the model architecture. 
+ model = keras.Sequential([ + keras.layers.InputLayer(input_shape=(28, 28)), + keras.layers.Reshape(target_shape=(28, 28, 1)), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), + keras.layers.MaxPooling2D(pool_size=(2, 2)), + keras.layers.Flatten(), + keras.layers.Dense(10) + ]) + # Train the digit classification model + model.compile(optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True), + metrics=['accuracy']) + + model.fit( + train_images, + train_labels, + epochs=1, + validation_split=0.1, + ) + + _, baseline_model_accuracy = model.evaluate( + test_images, test_labels, verbose=0) + + print('Baseline test accuracy:', baseline_model_accuracy) + model.save("baseline_model") + +def build_dataset(): + # Load the data and split it between train and test sets + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, 10) + y_test = keras.utils.to_categorical(y_test, 10) + return x_train, y_train, x_test, y_test + +def eval_func(model): + x_train, y_train, x_test, y_test = build_dataset() + start = time.time() + model.compile(metrics=["accuracy"], run_eagerly=False) + score = model.evaluate(x_test, y_test) + end = time.time() + + if test_mode == 'performance': + latency = end - start + print("Latency: {:.3f} ms".format(latency * 1000)) + print("Throughput: {:.3f} data/sec".format(1. 
/ latency)) + return score[1] + +class Dataset(object): + def __init__(self, batch_size=100): + mnist = keras.datasets.mnist + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + + # Normalize the input image so that each pixel value is between 0 to 1. + self.train_images = train_images / 255.0 + self.test_images = test_images / 255.0 + self.train_labels = train_labels + self.test_labels = test_labels + + def __len__(self): + return len(self.test_images) + + def __getitem__(self, idx): + return self.test_images[idx], self.test_labels[idx] + + +class TestKerasInKerasOut(unittest.TestCase): + @classmethod + def setUpClass(self): + os.environ["ITEX_ONEDNN_GRAPH"] = '1' + + @classmethod + def tearDownClass(self): + shutil.rmtree('baseline_model',ignore_errors=True) + shutil.rmtree('itex_qdq_keras_model',ignore_errors=True) + + def test_keras_in_keras_out(self): + logger.info("Run test_keras_in_keras_out case...") + global test_mode + test_mode = 'accuracy' + build_model() + + from neural_compressor.quantization import fit + from neural_compressor.config import PostTrainingQuantConfig, set_random_seed + from neural_compressor.experimental import common + set_random_seed(9527) + config = PostTrainingQuantConfig() + logger.info("=================Run Quantization...") + q_model = fit(keras.models.load_model('./baseline_model'), + conf=config, + calib_dataloader=common.DataLoader(Dataset()), + eval_func=eval_func) + q_model.save("itex_qdq_keras_model") + model = keras.models.load_model('./itex_qdq_keras_model') + model.summary() + found_quantize = False + found_dequantize = False + for layer in model.layers: + if 'quantize' in layer.name: + found_quantize = True + if 'de_quantize' in layer.name: + found_dequantize = True + self.assertEqual(found_quantize, True) + self.assertEqual(found_dequantize, True) + + from neural_compressor.benchmark import fit + from neural_compressor.config import BenchmarkConfig + conf = BenchmarkConfig(iteration=100, 
cores_per_instance=1, num_of_instance=1) + logger.info("=================Run BenchMark...") + test_mode = 'performance' + fit(model, conf, b_func=eval_func) + + def test_keras_model_interface(self): + logger.info("Run test_keras_model_interface case...") + global test_mode + test_mode = 'accuracy' + build_model() + + from neural_compressor.quantization import fit + from neural_compressor.config import PostTrainingQuantConfig, set_random_seed + from neural_compressor.experimental import common + set_random_seed(9527) + config = PostTrainingQuantConfig() + q_model = fit(keras.models.load_model('./baseline_model'), + conf=config, + calib_dataloader=common.DataLoader(Dataset()), + eval_func=eval_func) + q_model.save("itex_qdq_keras_model") + self.assertEqual(q_model.framework(), 'keras') + + framework_config = { + 'framework': 'keras', + 'approach': 'post_training_static_quant' + } + q_model.q_config = framework_config + self.assertEqual(q_model.q_config['framework'], 'keras') + self.assertEqual(q_model.graph_info, None) + self.assertEqual(q_model.framework(), 'keras') + self.assertEqual(isinstance(q_model.model, tf.keras.Model), True) + +if __name__ == '__main__': + unittest.main() From 19b954c47930ce02fd0dc5a749ac13060daef933 Mon Sep 17 00:00:00 2001 From: "Lv, Liang1" Date: Fri, 9 Dec 2022 14:38:31 +0800 Subject: [PATCH 07/14] add itex to requirements.txt Signed-off-by: Lv, Liang1 --- examples/keras/mnist/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/keras/mnist/requirements.txt b/examples/keras/mnist/requirements.txt index 2755e1a41ac..cee1363064f 100644 --- a/examples/keras/mnist/requirements.txt +++ b/examples/keras/mnist/requirements.txt @@ -1,2 +1,3 @@ tensorflow neural-compressor +intel-extension-for-tensorflow[cpu] \ No newline at end of file From 8b2c428a4709f9bbf599bdf3bc64d64fc6445dd7 Mon Sep 17 00:00:00 2001 From: "Lv, Liang1" Date: Fri, 9 Dec 2022 15:00:06 +0800 Subject: [PATCH 08/14] rebase master Signed-off-by: Lv, Liang1 
--- examples/keras/mnist/mnist.py | 3 ++- test/itex/test_keras_in_keras_out.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/keras/mnist/mnist.py b/examples/keras/mnist/mnist.py index 1899c76accb..aa28a10ad92 100644 --- a/examples/keras/mnist/mnist.py +++ b/examples/keras/mnist/mnist.py @@ -90,7 +90,8 @@ def main(): model = build_model(x_train, y_train, x_test, y_test) from neural_compressor.quantization import fit - from neural_compressor.config import PostTrainingQuantConfig, set_random_seed + from neural_compressor.config import PostTrainingQuantConfig + from neural_compressor.utils.utility import set_random_seed from neural_compressor.experimental import common set_random_seed(9527) config = PostTrainingQuantConfig() diff --git a/test/itex/test_keras_in_keras_out.py b/test/itex/test_keras_in_keras_out.py index d2b9e888122..43990ac7b50 100644 --- a/test/itex/test_keras_in_keras_out.py +++ b/test/itex/test_keras_in_keras_out.py @@ -127,7 +127,8 @@ def test_keras_in_keras_out(self): build_model() from neural_compressor.quantization import fit - from neural_compressor.config import PostTrainingQuantConfig, set_random_seed + from neural_compressor.config import PostTrainingQuantConfig + from neural_compressor.utils.utility import set_random_seed from neural_compressor.experimental import common set_random_seed(9527) config = PostTrainingQuantConfig() @@ -163,7 +164,8 @@ def test_keras_model_interface(self): build_model() from neural_compressor.quantization import fit - from neural_compressor.config import PostTrainingQuantConfig, set_random_seed + from neural_compressor.config import PostTrainingQuantConfig + from neural_compressor.utils.utility import set_random_seed from neural_compressor.experimental import common set_random_seed(9527) config = PostTrainingQuantConfig() From 3cb0873d80bf1e15e14eef3887d4193794dd80c7 Mon Sep 17 00:00:00 2001 From: Clark Chin Date: Fri, 9 Dec 2022 16:23:36 +0800 Subject: [PATCH 09/14] remove the 
version check --- neural_compressor/model/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index c0601551910..bfabf384b9d 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -134,7 +134,6 @@ def _is_keras(model): if version1_lt_version2(tf.version.VERSION, '2.10.0'): logger.warn("keras model running on tensorflow 2.10.0 and" " lower not support intel ITEX.") - return 'NA' try: model = tf.keras.models.load_model(model) except: From 1662e0b7a96bcda6dbf478817cbd5d6672a44892 Mon Sep 17 00:00:00 2001 From: Clark Chin Date: Fri, 9 Dec 2022 17:35:47 +0800 Subject: [PATCH 10/14] change keras to tensorflow framework --- examples/keras/mnist/mnist.py | 4 +-- .../experimental/common/model.py | 5 ++- neural_compressor/experimental/component.py | 4 +++ neural_compressor/model/model.py | 36 ++++++++----------- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/examples/keras/mnist/mnist.py b/examples/keras/mnist/mnist.py index aa28a10ad92..74038fbfbad 100644 --- a/examples/keras/mnist/mnist.py +++ b/examples/keras/mnist/mnist.py @@ -40,7 +40,7 @@ def build_dataset(): class Dataset(): def __init__(self, ): - self.inputs, self.labels, _, _ = build_dataset() + _, _ , self.inputs, self.labels = build_dataset() def __getitem__(self, idx): return self.inputs[idx], self.labels[idx] @@ -97,7 +97,7 @@ def main(): config = PostTrainingQuantConfig() quantized_model = fit(model, conf=config, - calib_dataloader=common.DataLoader((x_train[:100], y_train[:100]), batch_size=10), + calib_dataloader=common.DataLoader(Dataset(), batch_size=10), eval_func=eval_func) if __name__ == '__main__': diff --git a/neural_compressor/experimental/common/model.py b/neural_compressor/experimental/common/model.py index f34a5c35b80..8768abc8c96 100644 --- a/neural_compressor/experimental/common/model.py +++ b/neural_compressor/experimental/common/model.py @@ -44,7 +44,10 @@ def 
__new__(cls, root, **kwargs): model_type = kwargs['modelType'] else: model_type = get_model_type(root) - model = MODELS['tensorflow'](model_type, root, **kwargs) + if model_type == 'keras': + model = MODELS['keras'](root, **kwargs) + else: + model = MODELS['tensorflow'](model_type, root, **kwargs) elif framework == 'pytorch': model = MODELS[framework](root, **kwargs) else: diff --git a/neural_compressor/experimental/component.py b/neural_compressor/experimental/component.py index 8afc1703c23..67afb53ca5e 100644 --- a/neural_compressor/experimental/component.py +++ b/neural_compressor/experimental/component.py @@ -472,6 +472,10 @@ def model(self, user_model): assert not isinstance(user_model, BaseModel), \ "Please pass an original framework model but not neural compressor model!" self.framework = get_model_fwk_name(user_model) + if self.framework == "tensorflow": + from ..model.model import get_model_type + if get_model_type(user_model) == 'keras': + self.framework = 'keras' if self.framework == "pytorch": if self.cfg.model.backend == "default": self.framework = "pytorch_fx" diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index bfabf384b9d..74f1c6e9e21 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -57,6 +57,20 @@ def get_model_type(model): """ from neural_compressor.adaptor.tf_utils.util import is_saved_model_format, is_ckpt_format + if isinstance(model, str): + model = os.path.abspath(os.path.expanduser(model)) + if (model.endswith('.h5') and os.path.isfile(model)) or \ + is_saved_model_format(os.path.dirname(model)) or \ + (os.path.isdir(model) and is_saved_model_format(model)): + if version1_lt_version2(tf.version.VERSION, '2.10.0'): + logger.warn("keras model running on tensorflow 2.10.0 and" + " lower not support intel ITEX.") + try: + model = tf.keras.models.load_model(model) + except: + pass + if isinstance(model, tf.keras.Model) and hasattr(model, 'to_json'): + return 'keras' if 
isinstance(model, tf.Graph): return 'graph' elif isinstance(model, tf.compat.v1.GraphDef): @@ -123,26 +137,6 @@ def _is_pytorch(model): except: return 'NA' - def _is_keras(model): - from neural_compressor.adaptor.tf_utils.util import is_saved_model_format - if isinstance(model, str): - from neural_compressor.adaptor.tf_utils.util import is_saved_model_format - model = os.path.abspath(os.path.expanduser(model)) - if (model.endswith('.h5') and os.path.isfile(model)) or \ - is_saved_model_format(os.path.dirname(model)) or \ - (os.path.isdir(model) and is_saved_model_format(model)): - if version1_lt_version2(tf.version.VERSION, '2.10.0'): - logger.warn("keras model running on tensorflow 2.10.0 and" - " lower not support intel ITEX.") - try: - model = tf.keras.models.load_model(model) - except: - return 'NA' - if isinstance(model, tf.keras.Model) and hasattr(model, 'to_json'): - return 'keras' - else: - return 'NA' - def _is_tensorflow(model): try: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -177,7 +171,7 @@ def _is_mxnet(model): if isinstance(model, TensorflowBaseModel): return 'tensorflow' - checker = [_is_keras, _is_tensorflow, _is_pytorch, _is_onnxruntime, _is_mxnet] + checker = [_is_tensorflow, _is_pytorch, _is_onnxruntime, _is_mxnet] for handler in checker: fwk_name = handler(model) if fwk_name != 'NA': From 9d5aabad7f9b80823c4ca85910ae0bc10e80ddef Mon Sep 17 00:00:00 2001 From: "Lv, Liang1" Date: Fri, 9 Dec 2022 21:30:47 +0800 Subject: [PATCH 11/14] fix ut issue Signed-off-by: Lv, Liang1 --- neural_compressor/model/keras_model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/neural_compressor/model/keras_model.py b/neural_compressor/model/keras_model.py index f7034e4e635..d0da97fdfad 100644 --- a/neural_compressor/model/keras_model.py +++ b/neural_compressor/model/keras_model.py @@ -73,3 +73,10 @@ def export( def framework(self): return 'keras' + def get_all_weight_names(self): + import tensorflow as tf + names = [] + for index, layer in 
enumerate(tf.keras.models.load_model(self._model).layers): + if len(layer.weights): + names.append(index) + return names From a97c86a824bcd80f94a116dca4a012bd503431ed Mon Sep 17 00:00:00 2001 From: chensuyue Date: Fri, 9 Dec 2022 23:24:42 +0800 Subject: [PATCH 12/14] update ut itex binary with nightly master version Signed-off-by: chensuyue --- .azure-pipelines/scripts/ut/env_setup.sh | 5 ++++- .azure-pipelines/scripts/ut/run_basic_itex.sh | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.azure-pipelines/scripts/ut/env_setup.sh b/.azure-pipelines/scripts/ut/env_setup.sh index 65a46b1f6dc..01c1a0b4a91 100644 --- a/.azure-pipelines/scripts/ut/env_setup.sh +++ b/.azure-pipelines/scripts/ut/env_setup.sh @@ -28,7 +28,10 @@ elif [[ "${tensorflow_version}" != "" ]]; then pip install intel-tensorflow==${tensorflow_version} fi -if [[ "${itex_version}" != "" ]]; then +if [[ "${itex_version}" == "nightly" ]]; then + pip install /tf_dataset/itex_binary/221209/intel_extension_for_tensorflow-1.1.0-cp38-cp38-linux_x86_64.whl + pip install /tf_dataset/itex_binary/221209/intel_extension_for_tensorflow_lib-1.1.0.0-cp38-cp38-linux_x86_64.whl +elif [[ "${itex_version}" != "" ]]; then pip install --upgrade intel-extension-for-tensorflow[cpu]==${itex_version} fi diff --git a/.azure-pipelines/scripts/ut/run_basic_itex.sh b/.azure-pipelines/scripts/ut/run_basic_itex.sh index 45278216f8d..c937992b7be 100644 --- a/.azure-pipelines/scripts/ut/run_basic_itex.sh +++ b/.azure-pipelines/scripts/ut/run_basic_itex.sh @@ -3,7 +3,7 @@ python -c "import neural_compressor as nc;print(nc.version.__version__)" echo "run basic itex" echo "specify fwk version..." -export itex_version='1.0.0' +export itex_version='nightly' export tensorflow_version='2.10.0-official' echo "set up UT env..." 
From 463403ff7590626d624c8efdb37bd916acac274c Mon Sep 17 00:00:00 2001 From: "Lv, Liang1" Date: Sat, 10 Dec 2022 15:13:17 +0800 Subject: [PATCH 13/14] only enable keras adaptor for itex backend Signed-off-by: Lv, Liang1 --- examples/keras/mnist/mnist.py | 2 +- .../adaptor/keras_utils/__init__.py | 17 +++++ neural_compressor/conf/config.py | 1 + neural_compressor/experimental/benchmark.py | 4 ++ .../experimental/common/model.py | 7 +- neural_compressor/experimental/component.py | 2 +- neural_compressor/model/keras_model.py | 8 --- neural_compressor/model/model.py | 69 +++++++++++++++++++ test/itex/test_keras_in_keras_out.py | 6 +- 9 files changed, 99 insertions(+), 17 deletions(-) create mode 100644 neural_compressor/adaptor/keras_utils/__init__.py diff --git a/examples/keras/mnist/mnist.py b/examples/keras/mnist/mnist.py index 74038fbfbad..a5cf1d7cebf 100644 --- a/examples/keras/mnist/mnist.py +++ b/examples/keras/mnist/mnist.py @@ -94,7 +94,7 @@ def main(): from neural_compressor.utils.utility import set_random_seed from neural_compressor.experimental import common set_random_seed(9527) - config = PostTrainingQuantConfig() + config = PostTrainingQuantConfig(backend='itex') quantized_model = fit(model, conf=config, calib_dataloader=common.DataLoader(Dataset(), batch_size=10), diff --git a/neural_compressor/adaptor/keras_utils/__init__.py b/neural_compressor/adaptor/keras_utils/__init__.py new file mode 100644 index 00000000000..ed04d17bdbe --- /dev/null +++ b/neural_compressor/adaptor/keras_utils/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/neural_compressor/conf/config.py b/neural_compressor/conf/config.py index 2cb547e6cbc..617bafd09eb 100644 --- a/neural_compressor/conf/config.py +++ b/neural_compressor/conf/config.py @@ -1391,6 +1391,7 @@ def map_pyconfig_to_cfg(self, pythonic_config): if pythonic_config.benchmark.outputs != []: mapping.update({'model.outputs': pythonic_config.benchmark.outputs}) mapping.update({ + 'model.backend': pythonic_config.benchmark.backend, 'evaluation.performance.warmup': pythonic_config.benchmark.warmup, 'evaluation.performance.iteration': pythonic_config.benchmark.iteration, 'evaluation.performance.configs.cores_per_instance': diff --git a/neural_compressor/experimental/benchmark.py b/neural_compressor/experimental/benchmark.py index 117cb67cbe1..7aff71cc14c 100644 --- a/neural_compressor/experimental/benchmark.py +++ b/neural_compressor/experimental/benchmark.py @@ -481,6 +481,10 @@ def model(self, user_model): assert not isinstance(user_model, BaseModel), \ "Please pass an original framework model but not neural compressor model!" 
self.framework = get_model_fwk_name(user_model) + if self.framework == "tensorflow": + from ..model.model import get_model_type + if get_model_type(user_model) == 'keras' and cfg.model.backend == 'itex': + self.framework = 'keras' if self.framework == "pytorch": if cfg.model.backend == "default": self.framework = "pytorch_fx" diff --git a/neural_compressor/experimental/common/model.py b/neural_compressor/experimental/common/model.py index 8768abc8c96..6fec668f9e8 100644 --- a/neural_compressor/experimental/common/model.py +++ b/neural_compressor/experimental/common/model.py @@ -44,10 +44,9 @@ def __new__(cls, root, **kwargs): model_type = kwargs['modelType'] else: model_type = get_model_type(root) - if model_type == 'keras': - model = MODELS['keras'](root, **kwargs) - else: - model = MODELS['tensorflow'](model_type, root, **kwargs) + model = MODELS['tensorflow'](model_type, root, **kwargs) + elif framework == 'keras': + model = MODELS['keras'](root, **kwargs) elif framework == 'pytorch': model = MODELS[framework](root, **kwargs) else: diff --git a/neural_compressor/experimental/component.py b/neural_compressor/experimental/component.py index 67afb53ca5e..4afbf2589e8 100644 --- a/neural_compressor/experimental/component.py +++ b/neural_compressor/experimental/component.py @@ -474,7 +474,7 @@ def model(self, user_model): self.framework = get_model_fwk_name(user_model) if self.framework == "tensorflow": from ..model.model import get_model_type - if get_model_type(user_model) == 'keras': + if get_model_type(user_model) == 'keras' and self.cfg.model.backend == 'itex': self.framework = 'keras' if self.framework == "pytorch": if self.cfg.model.backend == "default": diff --git a/neural_compressor/model/keras_model.py b/neural_compressor/model/keras_model.py index d0da97fdfad..f0995ceed59 100644 --- a/neural_compressor/model/keras_model.py +++ b/neural_compressor/model/keras_model.py @@ -72,11 +72,3 @@ def export( @abstractmethod def framework(self): return 'keras' - - def 
get_all_weight_names(self): - import tensorflow as tf - names = [] - for index, layer in enumerate(tf.keras.models.load_model(self._model).layers): - if len(layer.weights): - names.append(index) - return names diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index 74f1c6e9e21..cdeb526fa37 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -439,6 +439,28 @@ def get_graph_from_original_keras_v2(model, output_dir): output_names = [tensor.name.split(':')[0] for tensor in output_tensors] return graph_def, input_names, output_names +def check_keras_format(model, saved_model_dir): + from tensorflow.python import saved_model + from tensorflow.python.saved_model.load import load + from tensorflow.python.saved_model import save_options + from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info + version = 'saved_model_v2' + try: + saved_model.save( + model, + saved_model_dir, + options=save_options.SaveOptions(save_debug_info=True)) + except: + return 'trackable_object' + saved_model_proto, _ = parse_saved_model_with_debug_info(saved_model_dir) + saved_model_version = saved_model_proto.saved_model_schema_version + if saved_model_version == 0: + return 'saved_model_v1' + if saved_model_version not in [1, 2]: + raise ValueError("SavedModel file format({0}) is not supported".format( + saved_model_version)) + return version + def get_graph_from_saved_model_v1(model): from tensorflow.python.framework import ops from tensorflow.python.saved_model import constants @@ -482,6 +504,51 @@ def get_graph_from_saved_model_v1(model): sess, graph.as_graph_def(), output_nodes) return graph_def, inputs, outputs +def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): + """Build session with keras model + Args: + model (string or tf.keras.Model): model path or tf.keras.Model object + input_tensor_names (list of string): input_tensor_names of model + output_tensor_names (list 
of string): output_tensor_names of model + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object + input_tensor_names (list of string): validated input_tensor_names + output_tensor_names (list of string): validated output_tensor_names + """ + temp_dir = tempfile.mkdtemp() + if tf.version.VERSION > '2.1.0': + if not isinstance(model, tf.keras.Model): + model = tf.keras.models.load_model(model) + keras_format = check_keras_format(model, temp_dir) + if keras_format == 'saved_model_v2': + try: + graph_def, input_names, output_names = get_graph_from_saved_model_v2( + temp_dir, input_tensor_names, output_tensor_names) + if '_FusedBatchNormEx' in [node.op for node in graph_def.node]: + keras_format = 'trackable_object' + except: + keras_format = 'trackable_object' + if keras_format == 'trackable_object': + try: + graph_def, input_names, output_names = get_graph_from_original_keras_v2( + model, temp_dir) + except: + keras_format = 'saved_model_v1' + if keras_format == 'saved_model_v1': + try: + tf.keras.backend.set_learning_phase(0) + graph_def, input_names, output_names = get_graph_from_saved_model_v1(model) + except: + raise ValueError('Not supported keras model type...') + + # tensorflow 1.x use v1 convert method + else: + tf.keras.backend.set_learning_phase(0) + graph_def, input_names, output_names = get_graph_from_saved_model_v1(model) + shutil.rmtree(temp_dir, True) + return graph_def_session(graph_def, input_names, output_names, **kwargs) + + def slim_session(model, input_tensor_names, output_tensor_names, **kwargs): """Build session with slim model @@ -643,6 +710,7 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs 'graph_def': graph_def_session, 'graph': graph_session, 'saved_model': saved_model_session, + 'keras': keras_session, 'checkpoint': checkpoint_session, 'estimator': estimator_session, 'slim': slim_session,} @@ -1011,6 +1079,7 @@ def graph_def(self, graph_def): 'estimator': TensorflowBaseModel, 'slim': 
TensorflowBaseModel, 'saved_model': TensorflowSavedModelModel, + 'keras': TensorflowSavedModelModel } class TensorflowModel(object): diff --git a/test/itex/test_keras_in_keras_out.py b/test/itex/test_keras_in_keras_out.py index 43990ac7b50..aa776d1d6fd 100644 --- a/test/itex/test_keras_in_keras_out.py +++ b/test/itex/test_keras_in_keras_out.py @@ -131,7 +131,7 @@ def test_keras_in_keras_out(self): from neural_compressor.utils.utility import set_random_seed from neural_compressor.experimental import common set_random_seed(9527) - config = PostTrainingQuantConfig() + config = PostTrainingQuantConfig(backend='itex') logger.info("=================Run Quantization...") q_model = fit(keras.models.load_model('./baseline_model'), conf=config, @@ -152,7 +152,7 @@ def test_keras_in_keras_out(self): from neural_compressor.benchmark import fit from neural_compressor.config import BenchmarkConfig - conf = BenchmarkConfig(iteration=100, cores_per_instance=1, num_of_instance=1) + conf = BenchmarkConfig(backend='itex', iteration=100, cores_per_instance=1, num_of_instance=1) logger.info("=================Run BenchMark...") test_mode = 'performance' fit(model, conf, b_func=eval_func) @@ -168,7 +168,7 @@ def test_keras_model_interface(self): from neural_compressor.utils.utility import set_random_seed from neural_compressor.experimental import common set_random_seed(9527) - config = PostTrainingQuantConfig() + config = PostTrainingQuantConfig(backend='itex') q_model = fit(keras.models.load_model('./baseline_model'), conf=config, calib_dataloader=common.DataLoader(Dataset()), From 839b9a7dbe977f8e3c751596e7fda6e1c2bbefde Mon Sep 17 00:00:00 2001 From: "Lv, Liang1" Date: Sat, 10 Dec 2022 21:02:27 +0800 Subject: [PATCH 14/14] fix mnist copyright Signed-off-by: Lv, Liang1 --- examples/keras/mnist/mnist.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/keras/mnist/mnist.py b/examples/keras/mnist/mnist.py index a5cf1d7cebf..8e0fbf411e5 100644 --- 
a/examples/keras/mnist/mnist.py +++ b/examples/keras/mnist/mnist.py @@ -1,18 +1,20 @@ -# Copyright 2019 Google LLC +#!/usr/bin/env python +# -*- coding: utf-8 -*- # +# Copyright (c) 2022 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== + import os import tensorflow as tf import numpy as np