diff --git a/neural_compressor/mix_precision.py b/neural_compressor/mix_precision.py index 1f5bc0445ba..f997e9f5471 100644 --- a/neural_compressor/mix_precision.py +++ b/neural_compressor/mix_precision.py @@ -252,35 +252,47 @@ def metric(self): def metric(self, user_metric): """Set metric class or a dict of built-in metric configures. - 1. neural_compressor have many built-in metrics, user can pass a metric configure dict to tell neural - compressor what metric will be use. - You can set multi-metrics to evaluate the performance of a specific model. + 1. neural_compressor have many built-in metrics, + user can pass a metric configure dict to tell neural compressor what metric will be use. + You also can set multi-metrics to evaluate the performance of a specific model. Single metric: {topk: 1} - Multi-metrics: {topk: 1, MSE: {compare_label: False}, weight: [0.5, 0.5], higher_is_better: [True, False] } - Refer to this [file](../docs/source/metric.md#supported-built-in-metric-matrix) for built-in metric list - 2. User also can set specific metric through this api. The metric class should take the outputs of the model or - postprocess(if have) as inputs, neural_compressor built-in metric always take(predictions, labels) as inputs - for update, and user_metric.metric_cls should be sub_class of neural_compressor.metric.BaseMetric. + For the built-in metrics, please refer to below link: + https://github.com/intel/neural-compressor/blob/master/docs/source/metric.md#supported-built-in-metric-matrix. + + 2. User also can get the built-in metrics by neural_compressor.Metric: + Metric(name="topk", k=1) + 3. User also can set specific metric through this api. The metric class should take the outputs of the model or + postprocess(if have) as inputs, neural_compressor built-in metric always take(predictions, labels) + as inputs for update, and user_metric.metric_cls should be sub_class of neural_compressor.metric.BaseMetric. Args: - user_metric(neural_compressor.metric.Metric or a dict of built-in metric configures): + user_metric(neural_compressor.metric.Metric or a dict of built-in metric configurations): The object of Metric or a dict of built-in metric configurations. + """ - from .metric import Metric as NCMetric, METRICS + from .metric import Metric as NCMetric + from .metric import METRICS if isinstance(user_metric, dict): metric_cfg = user_metric else: if isinstance(user_metric, NCMetric): - name = user_metric.name - metric_cls = user_metric.metric_cls - metric_cfg = {name: {**user_metric.kwargs}} + if user_metric.metric_cls is None: + name = user_metric.name + metric_cls = METRICS(self.conf.mixed_precision.framework).metrics[name] + metric_cfg = {name: {**user_metric.kwargs}} + self._metric = metric_cfg + return + else: + name = user_metric.name + metric_cls = user_metric.metric_cls + metric_cfg = {name: {**user_metric.kwargs}} else: for i in ['reset', 'update', 'result']: assert hasattr(user_metric, i), 'Please realise {} function' \