diff --git a/release-next.md b/release-next.md
index 504e4763..7fc64ba9 100644
--- a/release-next.md
+++ b/release-next.md
@@ -12,6 +12,22 @@
xf.fit(train_df)
result = xf.transform(train_df, as_csr=True)
```
+
+- **Permutation Feature Importance for model interpretability.**
+
+ [PR#279](https://github.com/microsoft/NimbusML/pull/279)
+  Adds a `permutation_feature_importance()` method to `Pipeline` and
+  predictor estimators, enabling evaluation of model-wide feature
+  importances on any dataset with the same schema as the dataset used
+  to fit the `Pipeline`.
+
+ ```python
+ pipe = Pipeline([
+ LogisticRegressionBinaryClassifier(label='label', feature=['feature'])
+ ])
+ pipe.fit(data)
+ pipe.permutation_feature_importance(data)
+ ```
- **Initial implementation of LpScaler.**
diff --git a/src/DotNetBridge/DotNetBridge.csproj b/src/DotNetBridge/DotNetBridge.csproj
index db737e30..822db6aa 100644
--- a/src/DotNetBridge/DotNetBridge.csproj
+++ b/src/DotNetBridge/DotNetBridge.csproj
@@ -32,19 +32,19 @@
all
runtime; build; native; contentfiles; analyzers
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
-
+
diff --git a/src/Platforms/build.csproj b/src/Platforms/build.csproj
index 10f89106..3db67054 100644
--- a/src/Platforms/build.csproj
+++ b/src/Platforms/build.csproj
@@ -11,19 +11,19 @@
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
-
+
diff --git a/src/python/nimbusml.pyproj b/src/python/nimbusml.pyproj
index 345e7ccf..768636af 100644
--- a/src/python/nimbusml.pyproj
+++ b/src/python/nimbusml.pyproj
@@ -181,6 +181,7 @@
+
@@ -436,6 +437,7 @@
+
@@ -682,6 +684,7 @@
+
diff --git a/src/python/nimbusml/base_predictor.py b/src/python/nimbusml/base_predictor.py
index f33f746c..538e7b5c 100644
--- a/src/python/nimbusml/base_predictor.py
+++ b/src/python/nimbusml/base_predictor.py
@@ -88,7 +88,13 @@ def _invoke_inference_method(self, method, X, **params):
@trace
def get_feature_contributions(self, X, **params):
- return self._invoke_inference_method('get_feature_contributions', X, **params)
+ return self._invoke_inference_method('get_feature_contributions',
+ X, **params)
+
+ @trace
+ def permutation_feature_importance(self, X, **params):
+ return self._invoke_inference_method('permutation_feature_importance',
+                                             X, **params)
+
@trace
def predict(self, X, **params):
diff --git a/src/python/nimbusml/examples/PermutationFeatureImportance.py b/src/python/nimbusml/examples/PermutationFeatureImportance.py
new file mode 100644
index 00000000..44a476ba
--- /dev/null
+++ b/src/python/nimbusml/examples/PermutationFeatureImportance.py
@@ -0,0 +1,173 @@
+###############################################################################
+# Permutation Feature Importance (PFI)
+
+# Permutation feature importance (PFI) is a technique to determine the global
+# importance of features in a trained machine learning model. PFI is a simple
+# yet powerful technique motivated by Breiman in section 10 of his Random
+# Forests paper (Machine Learning, 2001). The advantage of the PFI method is
+# that it is model agnostic - it works with any model that can be evaluated -
+# and it can use any dataset, not just the training set, to compute feature
+# importance metrics.
+
+# PFI works by taking a labeled dataset, choosing a feature, and permuting the
+# values for that feature across all the examples, so that each example now has
+# a random value for the feature and the original values for all other
+# features. The evaluation metric (e.g. NDCG) is then calculated for this
+# modified dataset, and the change in the evaluation metric from the original
+# dataset is computed. The larger the change in the evaluation metric, the more
+# important the feature is to the model, i.e. the most important features are
+# those that the model is most sensitive to. PFI performs this permutation
+# analysis across all the features of a model, one after another.
+
+# PFI is supported for binary classifiers, classifiers, regressors, and
+# rankers.
+
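+# To make the procedure concrete, here is a minimal, model-agnostic sketch of
+# the permutation loop itself. It is an illustration only: `pfi_sketch` and
+# `evaluate` are hypothetical names, with `evaluate` standing in for any
+# caller-supplied function mapping a feature matrix to a scalar metric.
+# nimbusml computes PFI natively via permutation_feature_importance(), as
+# shown in the sections below.
+import numpy as np
+
+def pfi_sketch(evaluate, X, seed=0):
+    rng = np.random.default_rng(seed)
+    baseline = evaluate(X)  # metric on the unmodified data
+    changes = []
+    for j in range(X.shape[1]):
+        X_perm = X.copy()
+        # permute column j; all other columns keep their original values
+        X_perm[:, j] = rng.permutation(X_perm[:, j])
+        changes.append(evaluate(X_perm) - baseline)
+    return changes
+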
+from nimbusml import Pipeline, FileDataStream
+from nimbusml.datasets import get_dataset
+from nimbusml.ensemble import LightGbmRanker
+from nimbusml.feature_extraction.categorical import OneHotVectorizer
+from nimbusml.linear_model import LogisticRegressionBinaryClassifier, \
+ FastLinearClassifier, FastLinearRegressor
+from nimbusml.preprocessing import ToKey
+from numpy.testing import assert_almost_equal
+
+# data input (as a FileDataStream)
+adult_path = get_dataset('uciadult_train').as_filepath()
+classification_data = FileDataStream.read_csv(adult_path)
+print(classification_data.head())
+# label workclass education ... capital-loss hours-per-week
+# 0 0 Private 11th ... 0 40
+# 1 0 Private HS-grad ... 0 50
+# 2 1 Local-gov Assoc-acdm ... 0 40
+# 3 1 Private Some-college ... 0 40
+# 4 0 ? Some-college ... 0 30
+
+######################################
+# PFI for Binary Classification models
+######################################
+# define the training pipeline with a binary classifier
+binary_pipeline = Pipeline([
+ OneHotVectorizer(columns=['education']),
+ LogisticRegressionBinaryClassifier(
+ feature=['age', 'education'], label='label')])
+
+# train the model
+binary_model = binary_pipeline.fit(classification_data)
+
+# get permutation feature importance
+binary_pfi = binary_model.permutation_feature_importance(classification_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. AUC.
+# Since AUC is an increasing metric, the highest negative changes indicate the
+# most important features.
+print("============== PFI for Binary Classification Model ==============")
+print(binary_pfi.sort_values('AreaUnderRocCurve').head())
+# FeatureName AreaUnderRocCurve AreaUnderRocCurve.StdErr ...
+# 0 age -0.081604 0.0 ...
+# 6 education.Prof-school -0.012964 0.0 ...
+# 10 education.Doctorate -0.012863 0.0 ...
+# 8 education.Bachelors -0.010593 0.0 ...
+# 2 education.HS-grad -0.005918 0.0 ...
+
+
+###############################
+# PFI for Classification models
+###############################
+# define the training pipeline with a classifier
+# use 1 thread and no shuffling to force determinism
+multiclass_pipeline = Pipeline([
+ OneHotVectorizer(columns=['education']),
+ FastLinearClassifier(feature=['age', 'education'], label='label',
+ number_of_threads=1, shuffle=False)])
+
+# train the model
+multiclass_model = multiclass_pipeline.fit(classification_data)
+
+# get permutation feature importance
+multiclass_pfi = multiclass_model.permutation_feature_importance(classification_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. Macro
+# accuracy. Since Macro accuracy is an increasing metric, the highest negative
+# changes indicate the most important features.
+print("================== PFI for Classification Model ==================")
+print(multiclass_pfi.sort_values('MacroAccuracy').head())
+# FeatureName MacroAccuracy ... MicroAccuracy ...
+# 10 education.Doctorate -0.028233 ... -0.020 ...
+# 0 age -0.001750 ... 0.002 ...
+# 6 education.Prof-school -0.001750 ... 0.002 ...
+# 9 education.Masters -0.001299 ... -0.002 ...
+# 1 education.11th 0.000000 ... 0.000 ...
+
+###########################
+# PFI for Regression models
+###########################
+# load input data
+infert_path = get_dataset('infert').as_filepath()
+regression_data = FileDataStream.read_csv(infert_path)
+print(regression_data.head())
+# age case education induced parity ... row_num spontaneous ...
+# 0 26 1 0-5yrs 1 6 ... 1 2 ...
+# 1 42 1 0-5yrs 1 1 ... 2 0 ...
+# 2 39 1 0-5yrs 2 6 ... 3 0 ...
+# 3 34 1 0-5yrs 2 4 ... 4 0 ...
+# 4 35 1 6-11yrs 1 3 ... 5 1 ...
+
+# define the training pipeline with a regressor
+# use 1 thread and no shuffling to force determinism
+regression_pipeline = Pipeline([
+ OneHotVectorizer(columns=['education']),
+ FastLinearRegressor(feature=['induced', 'education'], label='age',
+ number_of_threads=1, shuffle=False)])
+
+# train the model
+regression_model = regression_pipeline.fit(regression_data)
+
+# get permutation feature importance
+regression_pfi = regression_model.permutation_feature_importance(regression_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. MAE.
+# Since MAE is a decreasing metric, the highest positive changes indicate the
+# most important features.
+print("==================== PFI for Regression Model ====================")
+print(regression_pfi.sort_values('MeanAbsoluteError', ascending=False).head())
+# FeatureName MeanAbsoluteError ... RSquared RSquared.StdErr
+# 3  education.12+ yrs    0.393451  ...  -0.146338  0.0
+# 0  induced              0.085804  ...  -0.026168  0.0
+# 1  education.0-5yrs     0.064460  ...  -0.027587  0.0
+# 2  education.6-11yrs   -0.000047  ...   0.000059  0.0
+
+########################
+# PFI for Ranking models
+########################
+# load input data
+ticket_path = get_dataset('gen_tickettrain').as_filepath()
+ranking_data = FileDataStream.read_csv(ticket_path)
+print(ranking_data.head())
+# rank group carrier price Class dep_day nbr_stops duration
+# 0 2 1 AA 240 3 1 0 12.0
+# 1 1 1 AA 300 3 0 1 15.0
+# 2 1 1 AA 360 3 0 2 18.0
+# 3 0 1 AA 540 2 0 0 12.0
+# 4 1 1 AA 600 2 0 1 15.0
+
+# define the training pipeline with a ranker
+ranking_pipeline = Pipeline([
+ ToKey(columns=['group']),
+ LightGbmRanker(feature=['Class', 'dep_day', 'duration'],
+ label='rank', group_id='group')])
+
+# train the model
+ranking_model = ranking_pipeline.fit(ranking_data)
+
+# get permutation feature importance
+ranking_pfi = ranking_model.permutation_feature_importance(ranking_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. DCG@1.
+# Since DCG is an increasing metric, the highest negative changes indicate the
+# most important features.
+print("===================== PFI for Ranking Model =====================")
+print(ranking_pfi.sort_values('DCG@1').head())
+# Feature DCG@1 DCG@2 DCG@3 ... NDCG@1 NDCG@2 ...
+# 0 Class -4.869096 -7.030914 -5.948893 ... -0.420238 -0.407281 ...
+# 2 duration -2.344379 -3.595958 -3.956632 ... -0.232143 -0.231539 ...
+# 1 dep_day 0.000000 0.000000 0.000000 ... 0.000000 0.000000 ...
diff --git a/src/python/nimbusml/internal/entrypoints/transforms_permutationfeatureimportance.py b/src/python/nimbusml/internal/entrypoints/transforms_permutationfeatureimportance.py
new file mode 100644
index 00000000..18ff2e51
--- /dev/null
+++ b/src/python/nimbusml/internal/entrypoints/transforms_permutationfeatureimportance.py
@@ -0,0 +1,81 @@
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+Transforms.PermutationFeatureImportance
+"""
+
+import numbers
+
+from ..utils.entrypoints import EntryPoint
+from ..utils.utils import try_set, unlist
+
+
+def transforms_permutationfeatureimportance(
+ data,
+ predictor_model,
+ metrics=None,
+ use_feature_weight_filter=False,
+ number_of_examples_to_use=None,
+ permutation_count=1,
+ **params):
+ """
+ **Description**
+ Permutation Feature Importance (PFI)
+
+ :param data: Input dataset (inputs).
+ :param predictor_model: The path to the model file (inputs).
+ :param use_feature_weight_filter: Use feature weights to pre-
+ filter features (inputs).
+ :param number_of_examples_to_use: Limit the number of examples to
+ evaluate on (inputs).
+ :param permutation_count: The number of permutations to perform
+ (inputs).
+ :param metrics: The PFI metrics (outputs).
+ """
+
+ entrypoint_name = 'Transforms.PermutationFeatureImportance'
+ inputs = {}
+ outputs = {}
+
+ if data is not None:
+ inputs['Data'] = try_set(
+ obj=data,
+ none_acceptable=False,
+ is_of_type=str)
+ if predictor_model is not None:
+ inputs['PredictorModel'] = try_set(
+ obj=predictor_model,
+ none_acceptable=False,
+ is_of_type=str)
+ if use_feature_weight_filter is not None:
+ inputs['UseFeatureWeightFilter'] = try_set(
+ obj=use_feature_weight_filter,
+ none_acceptable=True,
+ is_of_type=bool)
+ if number_of_examples_to_use is not None:
+ inputs['NumberOfExamplesToUse'] = try_set(
+ obj=number_of_examples_to_use,
+ none_acceptable=True,
+ is_of_type=numbers.Real)
+ if permutation_count is not None:
+ inputs['PermutationCount'] = try_set(
+ obj=permutation_count,
+ none_acceptable=True,
+ is_of_type=numbers.Real)
+ if metrics is not None:
+ outputs['Metrics'] = try_set(
+ obj=metrics,
+ none_acceptable=False,
+ is_of_type=str)
+
+ input_variables = {
+ x for x in unlist(inputs.values())
+ if isinstance(x, str) and x.startswith("$")}
+ output_variables = {
+ x for x in unlist(outputs.values())
+ if isinstance(x, str) and x.startswith("$")}
+
+ entrypoint = EntryPoint(
+ name=entrypoint_name, inputs=inputs, outputs=outputs,
+ input_variables=input_variables,
+ output_variables=output_variables)
+ return entrypoint
diff --git a/src/python/nimbusml/pipeline.py b/src/python/nimbusml/pipeline.py
index b3be72f8..fa2542c2 100644
--- a/src/python/nimbusml/pipeline.py
+++ b/src/python/nimbusml/pipeline.py
@@ -60,6 +60,8 @@
transforms_modelcombiner
from .internal.entrypoints.transforms_optionalcolumncreator import \
transforms_optionalcolumncreator
+from .internal.entrypoints.transforms_permutationfeatureimportance import \
+ transforms_permutationfeatureimportance
from .internal.entrypoints \
.transforms_predictedlabelcolumnoriginalvalueconverter import \
transforms_predictedlabelcolumnoriginalvalueconverter
@@ -1738,7 +1740,7 @@ def get_feature_contributions(self, X, top=10, bottom=10, verbose=0,
to report.
:param bottom: The number of negative contributions with highest
magnitude to report.
- :return: dataframe of containing the raw data, predicted label, score,
+ :return: dataframe containing the raw data, predicted label, score,
probabilities, and feature contributions.
"""
self.verbose = verbose
@@ -1855,6 +1857,181 @@ def get_schema(self, verbose=0, **params):
return out_data
+ @trace
+ def permutation_feature_importance(self, X, number_of_examples=None,
+ permutation_count=1,
+ filter_zero_weight_features=False,
+ verbose=0, as_binary_data_stream=False,
+ **params):
+ """
+ Permutation feature importance (PFI) is a technique to determine the
+ global importance of features in a trained machine learning model. PFI
+ is a simple yet powerful technique motivated by Breiman in section 10
+ of his Random Forests paper (Machine Learning, 2001). The advantage of
+ the PFI method is that it is model agnostic - it works with any model
+ that can be evaluated - and it can use any dataset, not just the
+ training set, to compute feature importance metrics.
+
+ PFI works by taking a labeled dataset, choosing a feature, and
+ permuting the values for that feature across all the examples, so that
+ each example now has a random value for the feature and the original
+ values for all other features. The evaluation metric (e.g. NDCG) is
+ then calculated for this modified dataset, and the change in the
+ evaluation metric from the original dataset is computed. The larger the
+ change in the evaluation metric, the more important the feature is to
+ the model, i.e. the most important features are those that the model is
+ most sensitive to. PFI works by performing this permutation analysis
+ across all the features of a model, one after another.
+
+ Note that for increasing metrics (e.g. AUC, accuracy, R-Squared, NDCG),
+ the most important features will be those with the highest negative
+ mean change in the metric. Conversely, for decreasing metrics (e.g.
+ Mean Squared Error, Log loss), the most important features will be
+ those with the highest positive mean change in the metric.
+
+ PFI is supported for binary classifiers, classifiers, regressors, and
+ rankers.
+
+        The mean change and the standard error of the mean are evaluated
+        for the following metrics:
+
+ * Binary Classification:
+
+ * Area under ROC curve (AUC)
+ * Accuracy
+ * Positive precision
+ * Positive recall
+ * Negative precision
+ * Negative recall
+ * F1 score
+ * Area under Precision-Recall curve (AUPRC)
+
+ * Multiclass classification:
+
+ * Macro accuracy
+ * Micro accuracy
+ * Log loss
+ * Log loss reduction
+ * Top k accuracy
+ * Per-class log loss
+
+ * Regression:
+
+ * Mean absolute error (MAE)
+ * Mean squared error (MSE)
+ * Root mean squared error (RMSE)
+ * Loss function
+ * R-Squared
+
+        * Ranking:
+
+ * Discounted cumulative gains (DCG) @1, @2, and @3
+ * Normalized discounted cumulative gains (NDCG) @1, @2, and @3
+
+ **Reference**
+
+        `Breiman, L. Random Forests. Machine Learning (2001) 45: 5.
+        <https://doi.org/10.1023/A:1010933404324>`_
+
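+        A typical call on a fitted pipeline looks like the following
+        sketch (names are placeholders; the metric columns in the result
+        depend on the evaluation type)::
+
+            pipe.fit(train_data)
+            pfi = pipe.permutation_feature_importance(test_data)
+            print(pfi.sort_values('AreaUnderRocCurve').head())
+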
+ :param X: {array-like [n_samples, n_features],
+ :py:class:`nimbusml.FileDataStream` }
+ :param number_of_examples: Limit the number of examples to evaluate on.
+ ``'None'`` means all examples in the dataset are used.
+ :param permutation_count: The number of permutations to perform.
+        :param filter_zero_weight_features: Pre-filter features with zero
+            weight. PFI will not be evaluated on these features.
+ :return: dataframe containing the mean change in evaluation metrics and
+ standard error of the mean for each feature. Features with the
+ largest change in a metric are the most important in the model with
+ respect to that metric.
+ """
+ self.verbose = verbose
+
+ if not self._is_fitted:
+ raise ValueError(
+                "Model is not fitted. Train or load a model before "
+                "permutation_feature_importance().")
+
+ X, _, _, _, _, schema, _, _ = self._preprocess_X_y(X)
+
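+        # Build the entrypoint graph: when X is a FileDataStream, a text
+        # loader node feeds $data; the PFI entrypoint then consumes $data and
+        # $predictor_model and emits its metrics as $output_data.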
+ all_nodes = []
+ inputs = dict([('data', ''), ('predictor_model', self.model)])
+ if isinstance(X, FileDataStream):
+ importtext_node = data_customtextloader(
+ input_file="$file",
+ data="$data",
+ custom_schema=schema.to_string(
+ add_sep=True))
+ all_nodes = [importtext_node]
+ inputs = dict([('file', ''), ('predictor_model', self.model)])
+
+ pfi_node = transforms_permutationfeatureimportance(
+ data="$data",
+ predictor_model="$predictor_model",
+ metrics="$output_data",
+ permutation_count=permutation_count,
+ number_of_examples_to_use=number_of_examples,
+ use_feature_weight_filter=filter_zero_weight_features)
+
+ all_nodes.extend([pfi_node])
+
+ outputs = dict(output_data="")
+
+        data_output_format = DataOutputFormat.IDV if as_binary_data_stream \
+            else DataOutputFormat.DF
+
+ graph = Graph(
+ inputs,
+ outputs,
+ data_output_format,
+ *all_nodes)
+
+ class_name = type(self).__name__
+ method_name = inspect.currentframe().f_code.co_name
+ telemetry_info = ".".join([class_name, method_name])
+
+ try:
+ (out_model, out_data, out_metrics, _) = graph.run(
+ X=X,
+ random_state=self.random_state,
+ model=self.model,
+ verbose=verbose,
+ telemetry_info=telemetry_info,
+ **params)
+ except RuntimeError as e:
+ raise e
+
+ out_data = self._fix_pfi_columns(out_data)
+
+ return out_data
+
+ def _fix_pfi_columns(self, data):
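+        # Rename the raw metric columns returned by the entrypoint to the
+        # documented names, e.g. 'DiscountedCumulativeGains.0' -> 'DCG@1' and
+        # 'AreaUnderRocCurveStdErr' -> 'AreaUnderRocCurve.StdErr'.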
+ cols = []
+ for i in range(len(data.columns)):
+ if 'StdErr' in data.columns.values[i]:
+                if data.columns.values[i][:15] == 'PerClassLogLoss':
+ cols.append('PerClassLogLoss' + \
+ data.columns.values[i][21:] + '.StdErr')
+ elif data.columns.values[i][:10] == 'Discounted':
+ pos = int(data.columns.values[i][-1]) + 1
+ cols.append('DCG@' + str(pos) + '.StdErr')
+ elif data.columns.values[i][:10] == 'Normalized':
+ pos = int(data.columns.values[i][-1]) + 1
+ cols.append('NDCG@' + str(pos) + '.StdErr')
+ else:
+ cols.append(data.columns.values[i][:-6] + '.StdErr')
+ else:
+ if data.columns.values[i][:10] == 'Discounted':
+ pos = int(data.columns.values[i][26]) + 1
+ cols.append('DCG@' + str(pos))
+ elif data.columns.values[i][:10] == 'Normalized':
+ pos = int(data.columns.values[i][36]) + 1
+ cols.append('NDCG@' + str(pos))
+ else:
+ cols.append(data.columns.values[i])
+ data.columns = cols
+
+ return data
+
@trace
def _predict(self, X, y=None,
evaltype='auto', group_id=None,
diff --git a/src/python/nimbusml/tests/pipeline/test_permutation_feature_importance.py b/src/python/nimbusml/tests/pipeline/test_permutation_feature_importance.py
new file mode 100644
index 00000000..347b2798
--- /dev/null
+++ b/src/python/nimbusml/tests/pipeline/test_permutation_feature_importance.py
@@ -0,0 +1,125 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------------------------
+import os
+import unittest
+
+from nimbusml import FileDataStream
+from nimbusml import Pipeline
+from nimbusml.datasets import get_dataset
+from nimbusml.ensemble import LightGbmRanker
+from nimbusml.feature_extraction.categorical import OneHotVectorizer
+from nimbusml.linear_model import LogisticRegressionBinaryClassifier, \
+ FastLinearClassifier, FastLinearRegressor
+from nimbusml.preprocessing import ToKey
+from numpy.testing import assert_almost_equal
+from pandas.testing import assert_frame_equal
+
+class TestPermutationFeatureImportance(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(self):
+ adult_path = get_dataset('uciadult_train').as_filepath()
+ self.classification_data = FileDataStream.read_csv(adult_path)
+ binary_pipeline = Pipeline([
+ OneHotVectorizer(columns=['education']),
+ LogisticRegressionBinaryClassifier(
+ feature=['age', 'education'], label='label',
+ number_of_threads=1)])
+ self.binary_model = binary_pipeline.fit(self.classification_data)
+ self.binary_pfi = self.binary_model.permutation_feature_importance(self.classification_data)
+ classifier_pipeline = Pipeline([
+ OneHotVectorizer(columns=['education']),
+ FastLinearClassifier(feature=['age', 'education'], label='label',
+ number_of_threads=1, shuffle=False)])
+ self.classifier_model = classifier_pipeline.fit(self.classification_data)
+ self.classifier_pfi = self.classifier_model.permutation_feature_importance(self.classification_data)
+
+ infert_path = get_dataset('infert').as_filepath()
+ self.regression_data = FileDataStream.read_csv(infert_path)
+ regressor_pipeline = Pipeline([
+ OneHotVectorizer(columns=['education']),
+ FastLinearRegressor(feature=['induced', 'education'], label='age',
+ number_of_threads=1, shuffle=False)])
+ self.regressor_model = regressor_pipeline.fit(self.regression_data)
+ self.regressor_pfi = self.regressor_model.permutation_feature_importance(self.regression_data)
+
+ ticket_path = get_dataset('gen_tickettrain').as_filepath()
+ self.ranking_data = FileDataStream.read_csv(ticket_path)
+ ranker_pipeline = Pipeline([
+ ToKey(columns=['group']),
+ LightGbmRanker(feature=['Class', 'dep_day', 'duration'],
+ label='rank', group_id='group',
+ random_state=0, number_of_threads=1)])
+ self.ranker_model = ranker_pipeline.fit(self.ranking_data)
+ self.ranker_pfi = self.ranker_model.permutation_feature_importance(self.ranking_data)
+
+ def test_binary_classifier(self):
+ assert_almost_equal(self.binary_pfi['AreaUnderRocCurve'].sum(), -0.140824, 6)
+ assert_almost_equal(self.binary_pfi['PositivePrecision'].sum(), -0.482143, 6)
+ assert_almost_equal(self.binary_pfi['PositiveRecall'].sum(), -0.0695652, 6)
+ assert_almost_equal(self.binary_pfi['NegativePrecision'].sum(), -0.0139899, 6)
+ assert_almost_equal(self.binary_pfi['NegativeRecall'].sum(), -0.00779221, 6)
+ assert_almost_equal(self.binary_pfi['F1Score'].sum(), -0.126983, 6)
+ assert_almost_equal(self.binary_pfi['AreaUnderPrecisionRecallCurve'].sum(), -0.19365, 5)
+
+ def test_binary_classifier_from_loaded_model(self):
+ model_path = "model.zip"
+ self.binary_model.save_model(model_path)
+ loaded_model = Pipeline()
+ loaded_model.load_model(model_path)
+ pfi_from_loaded = loaded_model.permutation_feature_importance(self.classification_data)
+ assert_frame_equal(self.binary_pfi, pfi_from_loaded)
+ os.remove(model_path)
+
+    def test_classifier(self):
+ assert_almost_equal(self.classifier_pfi['MacroAccuracy'].sum(), -0.0256352, 6)
+ assert_almost_equal(self.classifier_pfi['LogLoss'].sum(), 0.158811, 6)
+ assert_almost_equal(self.classifier_pfi['LogLossReduction'].sum(), -0.29449, 5)
+ assert_almost_equal(self.classifier_pfi['PerClassLogLoss.0'].sum(), 0.0808459, 6)
+ assert_almost_equal(self.classifier_pfi['PerClassLogLoss.1'].sum(), 0.419826, 6)
+
+ def test_classifier_from_loaded_model(self):
+ model_path = "model.zip"
+ self.classifier_model.save_model(model_path)
+ loaded_model = Pipeline()
+ loaded_model.load_model(model_path)
+ pfi_from_loaded = loaded_model.permutation_feature_importance(self.classification_data)
+ assert_frame_equal(self.classifier_pfi, pfi_from_loaded)
+ os.remove(model_path)
+
+ def test_regressor(self):
+ assert_almost_equal(self.regressor_pfi['MeanAbsoluteError'].sum(), 0.504701, 6)
+ assert_almost_equal(self.regressor_pfi['MeanSquaredError'].sum(), 5.59277, 5)
+ assert_almost_equal(self.regressor_pfi['RootMeanSquaredError'].sum(), 0.553048, 6)
+ assert_almost_equal(self.regressor_pfi['RSquared'].sum(), -0.203612, 6)
+
+ def test_regressor_from_loaded_model(self):
+ model_path = "model.zip"
+ self.regressor_model.save_model(model_path)
+ loaded_model = Pipeline()
+ loaded_model.load_model(model_path)
+ pfi_from_loaded = loaded_model.permutation_feature_importance(self.regression_data)
+ assert_frame_equal(self.regressor_pfi, pfi_from_loaded)
+ os.remove(model_path)
+
+ def test_ranker(self):
+ assert_almost_equal(self.ranker_pfi['DCG@1'].sum(), -2.16404, 5)
+ assert_almost_equal(self.ranker_pfi['DCG@2'].sum(), -3.5294, 4)
+ assert_almost_equal(self.ranker_pfi['DCG@3'].sum(), -4.9721, 4)
+ assert_almost_equal(self.ranker_pfi['NDCG@1'].sum(), -0.114286, 6)
+ assert_almost_equal(self.ranker_pfi['NDCG@2'].sum(), -0.198631, 6)
+ assert_almost_equal(self.ranker_pfi['NDCG@3'].sum(), -0.236544, 6)
+
+ def test_ranker_from_loaded_model(self):
+ model_path = "model.zip"
+ self.ranker_model.save_model(model_path)
+ loaded_model = Pipeline()
+ loaded_model.load_model(model_path)
+ pfi_from_loaded = loaded_model.permutation_feature_importance(self.ranking_data)
+ assert_frame_equal(self.ranker_pfi, pfi_from_loaded)
+ os.remove(model_path)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py
index e47ce638..7dfd5eb8 100644
--- a/src/python/tests/test_estimator_checks.py
+++ b/src/python/tests/test_estimator_checks.py
@@ -16,6 +16,7 @@
from nimbusml.ensemble import LightGbmRegressor
from nimbusml.feature_extraction.text import NGramFeaturizer
from nimbusml.internal.entrypoints._ngramextractor_ngram import n_gram
+from nimbusml.linear_model import SgdBinaryClassifier
from nimbusml.preprocessing import TensorFlowScorer
from nimbusml.preprocessing.filter import SkipFilter, TakeFilter
from nimbusml.timeseries import (IidSpikeDetector, IidChangePointDetector,
@@ -201,6 +202,7 @@
'LightGbmRanker': LightGbmRanker(
minimum_example_count_per_group=1, minimum_example_count_per_leaf=1),
'NGramFeaturizer': NGramFeaturizer(word_feature_extractor=n_gram()),
+ 'SgdBinaryClassifier': SgdBinaryClassifier(number_of_threads=1, shuffle=False),
'SkipFilter': SkipFilter(count=5),
'TakeFilter': TakeFilter(count=100000),
'IidSpikeDetector': IidSpikeDetector(columns=['F0']),
diff --git a/src/python/tools/manifest.json b/src/python/tools/manifest.json
index a68890fe..c8e6d6e5 100644
--- a/src/python/tools/manifest.json
+++ b/src/python/tools/manifest.json
@@ -21710,6 +21710,79 @@
"ITransformOutput"
]
},
+ {
+ "Name": "Transforms.PermutationFeatureImportance",
+ "Desc": "Permutation Feature Importance (PFI)",
+ "FriendlyName": "PFI",
+ "ShortName": "PFI",
+ "Inputs": [
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "PredictorModel",
+ "Type": "PredictorModel",
+ "Desc": "The path to the model file",
+ "Aliases": [
+ "path"
+ ],
+ "Required": true,
+ "SortOrder": 150.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "UseFeatureWeightFilter",
+ "Type": "Bool",
+ "Desc": "Use feature weights to pre-filter features",
+ "Aliases": [
+ "usefw"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": false
+ },
+ {
+ "Name": "NumberOfExamplesToUse",
+ "Type": "Int",
+ "Desc": "Limit the number of examples to evaluate on",
+ "Aliases": [
+ "numexamples"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": true,
+ "Default": null
+ },
+ {
+ "Name": "PermutationCount",
+ "Type": "Int",
+ "Desc": "The number of permutations to perform",
+ "Aliases": [
+ "permutations"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": 1
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "Metrics",
+ "Type": "DataView",
+ "Desc": "The PFI metrics"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ]
+ },
{
"Name": "Transforms.PredictedLabelColumnOriginalValueConverter",
"Desc": "Transforms a predicted label column to its original values, unless it is of type bool.",