Skip to content

Commit 842a6e8

Browse files
committed
Merge branch 'master' into zixuan/new_config
Conflicts: neural_compressor/config.py
2 parents 35f699c + 84a6946 commit 842a6e8

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

neural_compressor/config.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,22 @@ def tensorboard(self, tensorboard):
149149
class BenchmarkConfig:
150150
"""Config Class for Benchmark.
151151
152+
Args:
153+
inputs (list, optional): A list of strings containing the inputs of model. Default is an empty list.
154+
outputs (list, optional): A list of strings containing the outputs of model. Default is an empty list.
155+
backend (str, optional): Backend name for model execution. Supported values include: 'default', 'itex',
156+
'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep'. Default value is 'default'.
157+
warmup (int, optional): The number of iterations to perform warmup before running performance tests.
158+
Default value is 5.
159+
iteration (int, optional): The number of iterations to run performance tests. Default is -1.
160+
cores_per_instance (int, optional): The number of CPU cores to use per instance. Default value is None.
161+
num_of_instance (int, optional): The number of instances to use for performance testing.
162+
Default value is None.
163+
inter_num_of_threads (int, optional): The number of threads to use for inter-thread operations.
164+
Default value is None.
165+
intra_num_of_threads (int, optional): The number of threads to use for intra-thread operations.
166+
Default value is None.
167+
152168
Example::
153169
154170
# Run benchmark according to config
@@ -1156,6 +1172,21 @@ def teacher_model(self, teacher_model):
11561172

11571173
class MixedPrecisionConfig(_BaseQuantizationConfig):
11581174
"""Config Class for MixedPrecision.
1175+
1176+
Args:
1177+
device (str, optional): device for execution. Support 'cpu' and 'gpu', default is 'cpu'
1178+
backend (str, optional): backend for model execution. Support 'default', 'itex', 'ipex',
1179+
'onnxrt_trt_ep', 'onnxrt_cuda_ep', default is 'default'
1180+
precision (str, optional): target precision for mix precision conversion.
1181+
Support 'bf16' and 'fp16', default is 'bf16'
1182+
inputs (list, optional): inputs of model, default is []
1183+
outputs (list, optional): outputs of model, default is []
1184+
tuning_criterion (TuningCriterion object, optional): accuracy tuning settings, it won't work
1185+
if there is no accuracy tuning process
1186+
accuracy_criterion (AccuracyCriterion object, optional): accuracy constraint settings, it won't
1187+
work if there is no accuracy tuning process
1188+
excluded_precisions (list, optional): precisions to be excluded during mix precision conversion,
1189+
default is []
11591190
11601191
Example::
11611192
@@ -1330,7 +1361,28 @@ def __init__(
13301361

13311362

13321363
class TF2ONNXConfig(ExportConfig):
1333-
"""Config Class for TF2ONNX."""
1364+
"""Config Class for TF2ONNX.
1365+
1366+
Args:
1367+
dtype (str, optional): The data type of export target model. Supports 'fp32' and 'int8'.
1368+
Defaults to 'int8'.
1369+
opset_version (int, optional): The version of the ONNX operator set to use. Defaults to 14.
1370+
quant_format (str, optional): The quantization format for the export target model.
1371+
Supports 'default', 'QDQ' and 'QOperator'. Defaults to 'QDQ'.
1372+
example_inputs (list, optional): A list example inputs to use for tracing the model.
1373+
Defaults to None.
1374+
input_names (list, optional): A list of model input names. Defaults to None.
1375+
output_names (list, optional): A list of model output names. Defaults to None.
1376+
dynamic_axes (dict, optional): A dictionary of dynamic axis information. Defaults to None.
1377+
**kwargs: Additional keyword arguments.
1378+
1379+
Examples::
1380+
1381+
# tensorflow QDQ int8 model 'q_model' export to ONNX int8 model
1382+
from neural_compressor.config import TF2ONNXConfig
1383+
config = TF2ONNXConfig()
1384+
q_model.export(output_graph, config)
1385+
"""
13341386
def __init__(
13351387
self,
13361388
dtype="int8",

0 commit comments

Comments
 (0)