@@ -136,15 +136,42 @@ def tensorboard(self, tensorboard):
136136
137137
138138class BenchmarkConfig :
139- def __init__ (self , warmup = 5 , iteration = - 1 , cores_per_instance = None , num_of_instance = None ,
140- inter_num_of_threads = None , intra_num_of_threads = None ):
139+ def __init__ (self ,
140+ inputs = [],
141+ outputs = [],
142+ warmup = 5 ,
143+ iteration = - 1 ,
144+ cores_per_instance = None ,
145+ num_of_instance = None ,
146+ inter_num_of_threads = None ,
147+ intra_num_of_threads = None ):
148+ self ._inputs = inputs
149+ self ._outputs = outputs
141150 self ._warmup = warmup
142151 self ._iteration = iteration
143152 self ._cores_per_instance = cores_per_instance
144153 self ._num_of_instance = num_of_instance
145154 self ._inter_num_of_threads = inter_num_of_threads
146155 self ._intra_num_of_threads = intra_num_of_threads
147156
157+ @property
158+ def outputs (self ):
159+ return self ._outputs
160+
161+ @outputs .setter
162+ def outputs (self , outputs ):
163+ if check_value ('outputs' , outputs , str ):
164+ self ._outputs = outputs
165+
166+ @property
167+ def inputs (self ):
168+ return self ._inputs
169+
170+ @inputs .setter
171+ def inputs (self , inputs ):
172+ if check_value ('inputs' , inputs , str ):
173+ self ._inputs = inputs
174+
148175 @property
149176 def warmup (self ):
150177 return self ._warmup
@@ -285,7 +312,7 @@ def __init__(self,
285312 max_trials = 100 ,
286313 performance_only = False ,
287314 reduce_range = None ,
288- extra_precisions = [],
315+ extra_precisions = ["bf16" ],
289316 accuracy_criterion = accuracy_criterion ):
290317 self ._inputs = inputs
291318 self ._outputs = outputs
@@ -503,16 +530,16 @@ def strategy(self, strategy):
503530
504531class PostTrainingQuantConfig (_BaseQuantizationConfig ):
505532 def __init__ (self ,
506- device = ' cpu' ,
533+ device = " cpu" ,
507534 backend = "NA" ,
508535 inputs = [],
509536 outputs = [],
510- approach = ' auto' ,
537+ approach = " auto" ,
511538 calibration_sampling_size = [100 ],
512539 op_type_list = None ,
513540 op_name_list = None ,
514541 reduce_range = None ,
515- extra_precisions = [],
542+ extra_precisions = ["bf16" ],
516543 tuning_criterion = tuning_criterion ,
517544 accuracy_criterion = accuracy_criterion ,
518545 ):
@@ -551,7 +578,7 @@ def __init__(self,
551578 op_type_list = None ,
552579 op_name_list = None ,
553580 reduce_range = None ,
554- extra_precisions = []):
581+ extra_precisions = ["bf16" ]):
555582 super ().__init__ (inputs = inputs , outputs = outputs , device = device , backend = backend ,
556583 op_type_list = op_type_list , op_name_list = op_name_list ,
557584 reduce_range = reduce_range , extra_precisions = extra_precisions )
@@ -789,16 +816,16 @@ def dynamic_axes(self, dynamic_axes):
789816
790817
791818class Torch2ONNXConfig (ExportConfig ):
792- def __init__ (
793- self ,
794- dtype = "int8" ,
795- opset_version = 14 ,
796- quant_format = "QDQ" ,
797- example_inputs = None ,
798- input_names = None ,
799- output_names = None ,
800- dynamic_axes = None ,
801- ** kwargs ,
819+ def __init__ (
820+ self ,
821+ dtype = "int8" ,
822+ opset_version = 14 ,
823+ quant_format = "QDQ" ,
824+ example_inputs = None ,
825+ input_names = None ,
826+ output_names = None ,
827+ dynamic_axes = None ,
828+ ** kwargs ,
802829 ):
803830 super ().__init__ (
804831 dtype = dtype ,
@@ -813,16 +840,16 @@ def __init__(
813840
814841
815842class TF2ONNXConfig (ExportConfig ):
816- def __init__ (
817- self ,
818- dtype = "int8" ,
819- opset_version = 14 ,
820- quant_format = "QDQ" ,
821- example_inputs = None ,
822- input_names = None ,
823- output_names = None ,
824- dynamic_axes = None ,
825- ** kwargs ,
843+ def __init__ (
844+ self ,
845+ dtype = "int8" ,
846+ opset_version = 14 ,
847+ quant_format = "QDQ" ,
848+ example_inputs = None ,
849+ input_names = None ,
850+ output_names = None ,
851+ dynamic_axes = None ,
852+ ** kwargs ,
826853 ):
827854 super ().__init__ (
828855 dtype = dtype ,
@@ -837,7 +864,7 @@ def __init__(
837864
838865
839866def set_random_seed (seed : int ):
840- options .random_seed
867+ options .random_seed = seed
841868
842869
843870def set_workspace (workspace : str ):
0 commit comments