Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions neural_compressor/mix_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,35 +252,47 @@ def metric(self):
def metric(self, user_metric):
"""Set metric class or a dict of built-in metric configures.

1. neural_compressor have many built-in metrics, user can pass a metric configure dict to tell neural
compressor what metric will be use.
You can set multi-metrics to evaluate the performance of a specific model.
1. neural_compressor have many built-in metrics,
user can pass a metric configure dict to tell neural compressor what metric will be use.
You also can set multi-metrics to evaluate the performance of a specific model.
Single metric:
{topk: 1}

Multi-metrics:
{topk: 1,
MSE: {compare_label: False},
weight: [0.5, 0.5],
higher_is_better: [True, False]
}
Refer to this [file](../docs/source/metric.md#supported-built-in-metric-matrix) for built-in metric list
2. User also can set specific metric through this api. The metric class should take the outputs of the model or
postprocess(if have) as inputs, neural_compressor built-in metric always take(predictions, labels) as inputs
for update, and user_metric.metric_cls should be sub_class of neural_compressor.metric.BaseMetric.
For the built-in metrics, please refer to below link:
https://github.com/intel/neural-compressor/blob/master/docs/source/metric.md#supported-built-in-metric-matrix.

2. User also can get the built-in metrics by neural_compressor.Metric:
Metric(name="topk", k=1)
3. User also can set specific metric through this api. The metric class should take the outputs of the model or
postprocess(if have) as inputs, neural_compressor built-in metric always take(predictions, labels)
as inputs for update, and user_metric.metric_cls should be sub_class of neural_compressor.metric.BaseMetric.

Args:
user_metric(neural_compressor.metric.Metric or a dict of built-in metric configures):
user_metric(neural_compressor.metric.Metric or a dict of built-in metric configurations):
The object of Metric or a dict of built-in metric configurations.

"""
from .metric import Metric as NCMetric, METRICS
from .metric import Metric as NCMetric
from .metric import METRICS
if isinstance(user_metric, dict):
metric_cfg = user_metric
else:
if isinstance(user_metric, NCMetric):
name = user_metric.name
metric_cls = user_metric.metric_cls
metric_cfg = {name: {**user_metric.kwargs}}
if user_metric.metric_cls is None:
name = user_metric.name
metric_cls = METRICS(self.conf.mixed_precision.framework).metrics[name]
metric_cfg = {name: {**user_metric.kwargs}}
self._metric = metric_cfg
return
else:
name = user_metric.name
metric_cls = user_metric.metric_cls
metric_cfg = {name: {**user_metric.kwargs}}
else:
for i in ['reset', 'update', 'result']:
assert hasattr(user_metric, i), 'Please realise {} function' \
Expand Down