diff --git a/release-next.md b/release-next.md
index 504e4763..7fc64ba9 100644
--- a/release-next.md
+++ b/release-next.md
@@ -12,6 +12,22 @@
 xf.fit(train_df)
 result = xf.transform(train_df, as_csr=True)
 ```
+
+- **Permutation Feature Importance for model interpretability.**
+
+  [PR#279](https://github.com/microsoft/NimbusML/pull/279)
+  Adds a `permutation_feature_importance()` method to `Pipeline` and
+  predictor estimators, enabling evaluation of model-wide feature
+  importances on any dataset with the same schema as the dataset used
+  to fit the `Pipeline`.
+
+  ```python
+  pipe = Pipeline([
+      LogisticRegressionBinaryClassifier(label='label', feature=['feature'])
+  ])
+  pipe.fit(data)
+  pipe.permutation_feature_importance(data)
+  ```
 
 - **Initial implementation of LpScaler.**
diff --git a/src/DotNetBridge/DotNetBridge.csproj b/src/DotNetBridge/DotNetBridge.csproj
index db737e30..822db6aa 100644
--- a/src/DotNetBridge/DotNetBridge.csproj
+++ b/src/DotNetBridge/DotNetBridge.csproj
@@ -32,19 +32,19 @@
       all
       runtime; build; native; contentfiles; analyzers
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
-
+
diff --git a/src/Platforms/build.csproj b/src/Platforms/build.csproj
index 10f89106..3db67054 100644
--- a/src/Platforms/build.csproj
+++ b/src/Platforms/build.csproj
@@ -11,19 +11,19 @@
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
-
+
diff --git a/src/python/nimbusml.pyproj b/src/python/nimbusml.pyproj
index 345e7ccf..768636af 100644
--- a/src/python/nimbusml.pyproj
+++ b/src/python/nimbusml.pyproj
@@ -181,6 +181,7 @@
+
@@ -436,6 +437,7 @@
+
@@ -682,6 +684,7 @@
+
diff --git a/src/python/nimbusml/base_predictor.py b/src/python/nimbusml/base_predictor.py
index f33f746c..538e7b5c 100644
--- a/src/python/nimbusml/base_predictor.py
+++ b/src/python/nimbusml/base_predictor.py
@@ -88,7 +88,13 @@ def _invoke_inference_method(self, method, X, **params):
 
     @trace
     def get_feature_contributions(self, X, **params):
-        return self._invoke_inference_method('get_feature_contributions', X, **params)
+        return self._invoke_inference_method('get_feature_contributions',
+                                             X, **params)
+
+    @trace
+    def permutation_feature_importance(self, X, **params):
+        return self._invoke_inference_method('permutation_feature_importance',
+                                             X, **params)
 
     @trace
     def predict(self, X, **params):
diff --git a/src/python/nimbusml/examples/PermutationFeatureImportance.py b/src/python/nimbusml/examples/PermutationFeatureImportance.py
new file mode 100644
index 00000000..44a476ba
--- /dev/null
+++ b/src/python/nimbusml/examples/PermutationFeatureImportance.py
@@ -0,0 +1,173 @@
+###############################################################################
+# Permutation Feature Importance (PFI)
+
+# Permutation feature importance (PFI) is a technique to determine the global
+# importance of features in a trained machine learning model. PFI is a simple
+# yet powerful technique motivated by Breiman in section 10 of his Random
+# Forests paper (Machine Learning, 2001). The advantage of the PFI method is
+# that it is model agnostic - it works with any model that can be evaluated -
+# and it can use any dataset, not just the training set, to compute feature
+# importance metrics.
+
+# PFI works by taking a labeled dataset, choosing a feature, and permuting the
+# values for that feature across all the examples, so that each example now has
+# a random value for the feature and the original values for all other
+# features. The evaluation metric (e.g.
+# NDCG) is then calculated for this
+# modified dataset, and the change in the evaluation metric from the original
+# dataset is computed. The larger the change in the evaluation metric, the more
+# important the feature is to the model, i.e. the most important features are
+# those that the model is most sensitive to. PFI works by performing this
+# permutation analysis across all the features of a model, one after another.
+
+# PFI is supported for binary classifiers, classifiers, regressors, and
+# rankers.
+
+from nimbusml import Pipeline, FileDataStream
+from nimbusml.datasets import get_dataset
+from nimbusml.ensemble import LightGbmRanker
+from nimbusml.feature_extraction.categorical import OneHotVectorizer
+from nimbusml.linear_model import LogisticRegressionBinaryClassifier, \
+    FastLinearClassifier, FastLinearRegressor
+from nimbusml.preprocessing import ToKey
+
+# data input (as a FileDataStream)
+adult_path = get_dataset('uciadult_train').as_filepath()
+classification_data = FileDataStream.read_csv(adult_path)
+print(classification_data.head())
+#    label  workclass     education  ...  capital-loss  hours-per-week
+# 0      0    Private          11th  ...             0              40
+# 1      0    Private       HS-grad  ...             0              50
+# 2      1  Local-gov    Assoc-acdm  ...             0              40
+# 3      1    Private  Some-college  ...             0              40
+# 4      0          ?  Some-college  ...             0              30
+
+######################################
+# PFI for Binary Classification models
+######################################
+# define the training pipeline with a binary classifier
+binary_pipeline = Pipeline([
+    OneHotVectorizer(columns=['education']),
+    LogisticRegressionBinaryClassifier(
+        feature=['age', 'education'], label='label')])
+
+# train the model
+binary_model = binary_pipeline.fit(classification_data)
+
+# get permutation feature importance
+binary_pfi = binary_model.permutation_feature_importance(classification_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. AUC.
+# Since AUC is an increasing metric, the highest negative changes indicate the
+# most important features.
+print("============== PFI for Binary Classification Model ==============")
+print(binary_pfi.sort_values('AreaUnderRocCurve').head())
+#               FeatureName  AreaUnderRocCurve  AreaUnderRocCurve.StdErr  ...
+# 0                     age          -0.081604                       0.0  ...
+# 6   education.Prof-school          -0.012964                       0.0  ...
+# 10    education.Doctorate          -0.012863                       0.0  ...
+# 8     education.Bachelors          -0.010593                       0.0  ...
+# 2       education.HS-grad          -0.005918                       0.0  ...
+
+
+###############################
+# PFI for Classification models
+###############################
+# define the training pipeline with a classifier
+# use 1 thread and no shuffling to force determinism
+multiclass_pipeline = Pipeline([
+    OneHotVectorizer(columns=['education']),
+    FastLinearClassifier(feature=['age', 'education'], label='label',
+                         number_of_threads=1, shuffle=False)])
+
+# train the model
+multiclass_model = multiclass_pipeline.fit(classification_data)
+
+# get permutation feature importance
+multiclass_pfi = multiclass_model.permutation_feature_importance(
+    classification_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. Macro
+# accuracy. Since Macro accuracy is an increasing metric, the highest negative
+# changes indicate the most important features.
+print("================== PFI for Classification Model ==================")
+print(multiclass_pfi.sort_values('MacroAccuracy').head())
+#               FeatureName  MacroAccuracy  ...  MicroAccuracy  ...
+# 10    education.Doctorate      -0.028233  ...         -0.020  ...
+# 0                     age      -0.001750  ...          0.002  ...
+# 6   education.Prof-school      -0.001750  ...          0.002  ...
+# 9       education.Masters      -0.001299  ...         -0.002  ...
+# 1          education.11th       0.000000  ...          0.000  ...
+
+###########################
+# PFI for Regression models
+###########################
+# load input data
+infert_path = get_dataset('infert').as_filepath()
+regression_data = FileDataStream.read_csv(infert_path)
+print(regression_data.head())
+#    age  case education  induced  parity  ...  row_num  spontaneous  ...
+# 0   26     1    0-5yrs        1       6  ...        1            2  ...
+# 1   42     1    0-5yrs        1       1  ...        2            0  ...
+# 2   39     1    0-5yrs        2       6  ...        3            0  ...
+# 3   34     1    0-5yrs        2       4  ...        4            0  ...
+# 4   35     1   6-11yrs        1       3  ...        5            1  ...
+
+# define the training pipeline with a regressor
+# use 1 thread and no shuffling to force determinism
+regression_pipeline = Pipeline([
+    OneHotVectorizer(columns=['education']),
+    FastLinearRegressor(feature=['induced', 'education'], label='age',
+                        number_of_threads=1, shuffle=False)])
+
+# train the model
+regression_model = regression_pipeline.fit(regression_data)
+
+# get permutation feature importance
+regression_pfi = regression_model.permutation_feature_importance(
+    regression_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. MAE.
+# Since MAE is a decreasing metric, the highest positive changes indicate the
+# most important features.
+print("==================== PFI for Regression Model ====================")
+print(regression_pfi.sort_values('MeanAbsoluteError', ascending=False).head())
+#          FeatureName  MeanAbsoluteError  ...  RSquared  RSquared.StdErr
+# 3  education.12+ yrs           0.393451  ... -0.146338              0.0
+# 0            induced           0.085804  ... -0.026168              0.0
+# 1   education.0-5yrs           0.064460  ... -0.027587              0.0
+# 2  education.6-11yrs          -0.000047  ...  0.000059              0.0
+
+########################
+# PFI for Ranking models
+########################
+# load input data
+ticket_path = get_dataset('gen_tickettrain').as_filepath()
+ranking_data = FileDataStream.read_csv(ticket_path)
+print(ranking_data.head())
+#    rank  group carrier  price  Class  dep_day  nbr_stops  duration
+# 0     2      1      AA    240      3        1          0      12.0
+# 1     1      1      AA    300      3        0          1      15.0
+# 2     1      1      AA    360      3        0          2      18.0
+# 3     0      1      AA    540      2        0          0      12.0
+# 4     1      1      AA    600      2        0          1      15.0
+
+# define the training pipeline with a ranker
+ranking_pipeline = Pipeline([
+    ToKey(columns=['group']),
+    LightGbmRanker(feature=['Class', 'dep_day', 'duration'],
+                   label='rank', group_id='group')])
+
+# train the model
+ranking_model = ranking_pipeline.fit(ranking_data)
+
+# get permutation feature importance
+ranking_pfi = ranking_model.permutation_feature_importance(ranking_data)
+
+# Print PFI for each feature, ordered by most important features w.r.t. DCG@1.
+# Since DCG is an increasing metric, the highest negative changes indicate the
+# most important features.
+print("===================== PFI for Ranking Model =====================")
+print(ranking_pfi.sort_values('DCG@1').head())
+#     Feature     DCG@1     DCG@2     DCG@3  ...    NDCG@1    NDCG@2  ...
+# 0     Class -4.869096 -7.030914 -5.948893  ... -0.420238 -0.407281  ...
+# 2  duration -2.344379 -3.595958 -3.956632  ... -0.232143 -0.231539  ...
+# 1   dep_day  0.000000  0.000000  0.000000  ...  0.000000  0.000000  ...
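The sign convention in the example flips with the metric direction, which is easy to get backwards when scanning PFI tables. The sketch below is an illustrative helper, not part of the NimbusML API: `rank_features` and the toy dataframe are hypothetical, and it assumes only that the PFI result is a pandas DataFrame with a feature-name column (`FeatureName` here; the ranker output above uses `Feature`) and one numeric column per metric.

```python
import pandas as pd


def rank_features(pfi, metric, increasing=True):
    # Hypothetical helper: order features from most to least important.
    # For increasing metrics (AUC, accuracy, NDCG) the most important
    # features show the largest *negative* mean change, so sort ascending;
    # for decreasing metrics (MAE, log loss) they show the largest
    # *positive* change, so sort descending.
    return pfi.sort_values(metric, ascending=increasing)[
        ['FeatureName', metric]].reset_index(drop=True)


# Toy frame mimicking the binary-classification PFI output above.
toy_pfi = pd.DataFrame({
    'FeatureName': ['age', 'education.HS-grad', 'education.Doctorate'],
    'AreaUnderRocCurve': [-0.081604, -0.005918, -0.012863]})

# AUC is an increasing metric, so 'age' should rank first.
print(rank_features(toy_pfi, 'AreaUnderRocCurve', increasing=True))
```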
diff --git a/src/python/nimbusml/internal/entrypoints/transforms_permutationfeatureimportance.py b/src/python/nimbusml/internal/entrypoints/transforms_permutationfeatureimportance.py
new file mode 100644
index 00000000..18ff2e51
--- /dev/null
+++ b/src/python/nimbusml/internal/entrypoints/transforms_permutationfeatureimportance.py
@@ -0,0 +1,81 @@
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+Transforms.PermutationFeatureImportance
+"""
+
+import numbers
+
+from ..utils.entrypoints import EntryPoint
+from ..utils.utils import try_set, unlist
+
+
+def transforms_permutationfeatureimportance(
+        data,
+        predictor_model,
+        metrics=None,
+        use_feature_weight_filter=False,
+        number_of_examples_to_use=None,
+        permutation_count=1,
+        **params):
+    """
+    **Description**
+        Permutation Feature Importance (PFI)
+
+    :param data: Input dataset (inputs).
+    :param predictor_model: The path to the model file (inputs).
+    :param use_feature_weight_filter: Use feature weights to pre-
+        filter features (inputs).
+    :param number_of_examples_to_use: Limit the number of examples to
+        evaluate on (inputs).
+    :param permutation_count: The number of permutations to perform
+        (inputs).
+    :param metrics: The PFI metrics (outputs).
+    """
+
+    entrypoint_name = 'Transforms.PermutationFeatureImportance'
+    inputs = {}
+    outputs = {}
+
+    if data is not None:
+        inputs['Data'] = try_set(
+            obj=data,
+            none_acceptable=False,
+            is_of_type=str)
+    if predictor_model is not None:
+        inputs['PredictorModel'] = try_set(
+            obj=predictor_model,
+            none_acceptable=False,
+            is_of_type=str)
+    if use_feature_weight_filter is not None:
+        inputs['UseFeatureWeightFilter'] = try_set(
+            obj=use_feature_weight_filter,
+            none_acceptable=True,
+            is_of_type=bool)
+    if number_of_examples_to_use is not None:
+        inputs['NumberOfExamplesToUse'] = try_set(
+            obj=number_of_examples_to_use,
+            none_acceptable=True,
+            is_of_type=numbers.Real)
+    if permutation_count is not None:
+        inputs['PermutationCount'] = try_set(
+            obj=permutation_count,
+            none_acceptable=True,
+            is_of_type=numbers.Real)
+    if metrics is not None:
+        outputs['Metrics'] = try_set(
+            obj=metrics,
+            none_acceptable=False,
+            is_of_type=str)
+
+    input_variables = {
+        x for x in unlist(inputs.values())
+        if isinstance(x, str) and x.startswith("$")}
+    output_variables = {
+        x for x in unlist(outputs.values())
+        if isinstance(x, str) and x.startswith("$")}
+
+    entrypoint = EntryPoint(
+        name=entrypoint_name, inputs=inputs, outputs=outputs,
+        input_variables=input_variables,
+        output_variables=output_variables)
+    return entrypoint
diff --git a/src/python/nimbusml/pipeline.py b/src/python/nimbusml/pipeline.py
index b3be72f8..fa2542c2 100644
--- a/src/python/nimbusml/pipeline.py
+++ b/src/python/nimbusml/pipeline.py
@@ -60,6 +60,8 @@
     transforms_modelcombiner
 from .internal.entrypoints.transforms_optionalcolumncreator import \
     transforms_optionalcolumncreator
+from .internal.entrypoints.transforms_permutationfeatureimportance import \
+    transforms_permutationfeatureimportance
 from .internal.entrypoints \
     .transforms_predictedlabelcolumnoriginalvalueconverter import \
     transforms_predictedlabelcolumnoriginalvalueconverter
@@ -1738,7 +1740,7 @@ def get_feature_contributions(self, X, top=10, bottom=10, verbose=0,
             to report.
         :param bottom: The number of negative contributions with highest
             magnitude to report.
-        :return: dataframe of containing the raw data, predicted label, score,
+        :return: dataframe containing the raw data, predicted label, score,
             probabilities, and feature contributions.
         """
         self.verbose = verbose
@@ -1855,6 +1857,181 @@ def get_schema(self, verbose=0, **params):
 
         return out_data
 
+    @trace
+    def permutation_feature_importance(self, X, number_of_examples=None,
+                                       permutation_count=1,
+                                       filter_zero_weight_features=False,
+                                       verbose=0, as_binary_data_stream=False,
+                                       **params):
+        """
+        Permutation feature importance (PFI) is a technique to determine the
+        global importance of features in a trained machine learning model. PFI
+        is a simple yet powerful technique motivated by Breiman in section 10
+        of his Random Forests paper (Machine Learning, 2001). The advantage of
+        the PFI method is that it is model agnostic - it works with any model
+        that can be evaluated - and it can use any dataset, not just the
+        training set, to compute feature importance metrics.
+
+        PFI works by taking a labeled dataset, choosing a feature, and
+        permuting the values for that feature across all the examples, so that
+        each example now has a random value for the feature and the original
+        values for all other features. The evaluation metric (e.g. NDCG) is
+        then calculated for this modified dataset, and the change in the
+        evaluation metric from the original dataset is computed. The larger the
+        change in the evaluation metric, the more important the feature is to
+        the model, i.e. the most important features are those that the model is
+        most sensitive to. PFI works by performing this permutation analysis
+        across all the features of a model, one after another.
+
+        Note that for increasing metrics (e.g. AUC, accuracy, R-Squared, NDCG),
+        the most important features will be those with the highest negative
+        mean change in the metric. Conversely, for decreasing metrics (e.g.
+        Mean Squared Error, Log loss), the most important features will be
+        those with the highest positive mean change in the metric.
+
+        PFI is supported for binary classifiers, classifiers, regressors, and
+        rankers.
+
+        The mean change and standard error of the mean are evaluated for the
+        following metrics:
+
+        * Binary Classification:
+
+          * Area under ROC curve (AUC)
+          * Accuracy
+          * Positive precision
+          * Positive recall
+          * Negative precision
+          * Negative recall
+          * F1 score
+          * Area under Precision-Recall curve (AUPRC)
+
+        * Multiclass classification:
+
+          * Macro accuracy
+          * Micro accuracy
+          * Log loss
+          * Log loss reduction
+          * Top k accuracy
+          * Per-class log loss
+
+        * Regression:
+
+          * Mean absolute error (MAE)
+          * Mean squared error (MSE)
+          * Root mean squared error (RMSE)
+          * Loss function
+          * R-Squared
+
+        * Ranking:
+
+          * Discounted cumulative gains (DCG) @1, @2, and @3
+          * Normalized discounted cumulative gains (NDCG) @1, @2, and @3
+
+        **Reference**
+
+        `Breiman, L. Random Forests. Machine Learning (2001) 45: 5.
+        <https://doi.org/10.1023/A:1010933404324>`_
+
+        :param X: {array-like [n_samples, n_features],
+            :py:class:`nimbusml.FileDataStream` }
+        :param number_of_examples: Limit the number of examples to evaluate
+            on. ``None`` means all examples in the dataset are used.
+        :param permutation_count: The number of permutations to perform.
+        :param filter_zero_weight_features: Pre-filter features with zero
+            weight. PFI will not be evaluated on these features.
+        :return: dataframe containing the mean change in evaluation metrics
+            and standard error of the mean for each feature.
+            Features with the largest change in a metric are the most
+            important in the model with respect to that metric.
+        """
+        self.verbose = verbose
+
+        if not self._is_fitted:
+            raise ValueError(
+                "Model is not fitted. Train or load a model before "
+                "calling permutation_feature_importance().")
+
+        X, _, _, _, _, schema, _, _ = self._preprocess_X_y(X)
+
+        all_nodes = []
+        inputs = dict([('data', ''), ('predictor_model', self.model)])
+        if isinstance(X, FileDataStream):
+            importtext_node = data_customtextloader(
+                input_file="$file",
+                data="$data",
+                custom_schema=schema.to_string(
+                    add_sep=True))
+            all_nodes = [importtext_node]
+            inputs = dict([('file', ''), ('predictor_model', self.model)])
+
+        pfi_node = transforms_permutationfeatureimportance(
+            data="$data",
+            predictor_model="$predictor_model",
+            metrics="$output_data",
+            permutation_count=permutation_count,
+            number_of_examples_to_use=number_of_examples,
+            use_feature_weight_filter=filter_zero_weight_features)
+
+        all_nodes.extend([pfi_node])
+
+        outputs = dict(output_data="")
+
+        data_output_format = DataOutputFormat.IDV if as_binary_data_stream \
+            else DataOutputFormat.DF
+
+        graph = Graph(
+            inputs,
+            outputs,
+            data_output_format,
+            *all_nodes)
+
+        class_name = type(self).__name__
+        method_name = inspect.currentframe().f_code.co_name
+        telemetry_info = ".".join([class_name, method_name])
+
+        try:
+            (out_model, out_data, out_metrics, _) = graph.run(
+                X=X,
+                random_state=self.random_state,
+                model=self.model,
+                verbose=verbose,
+                telemetry_info=telemetry_info,
+                **params)
+        except RuntimeError as e:
+            raise e
+
+        out_data = self._fix_pfi_columns(out_data)
+
+        return out_data
+
+    def _fix_pfi_columns(self, data):
+        # Rename the raw PFI metric columns to friendlier names, e.g.
+        # 'AccuracyStdErr' -> 'Accuracy.StdErr' and
+        # 'DiscountedCumulativeGains.0' -> 'DCG@1'.
+        cols = []
+        for i in range(len(data.columns)):
+            if 'StdErr' in data.columns.values[i]:
+                if data.columns.values[i][:15] == 'PerClassLogLoss':
+                    # e.g. 'PerClassLogLossStdErr.0' ->
+                    # 'PerClassLogLoss.0.StdErr'
+                    cols.append('PerClassLogLoss' +
+                                data.columns.values[i][21:] + '.StdErr')
+                elif data.columns.values[i][:10] == 'Discounted':
+                    # e.g. 'DiscountedCumulativeGainsStdErr.0' ->
+                    # 'DCG@1.StdErr'
+                    pos = int(data.columns.values[i][-1]) + 1
+                    cols.append('DCG@' + str(pos) + '.StdErr')
+                elif data.columns.values[i][:10] == 'Normalized':
+                    pos = int(data.columns.values[i][-1]) + 1
+                    cols.append('NDCG@' + str(pos) + '.StdErr')
+                else:
+                    # e.g. 'AccuracyStdErr' -> 'Accuracy.StdErr'
+                    cols.append(data.columns.values[i][:-6] + '.StdErr')
+            else:
+                if data.columns.values[i][:10] == 'Discounted':
+                    # e.g. 'DiscountedCumulativeGains.0' -> 'DCG@1'
+                    pos = int(data.columns.values[i][26]) + 1
+                    cols.append('DCG@' + str(pos))
+                elif data.columns.values[i][:10] == 'Normalized':
+                    # e.g. 'NormalizedDiscountedCumulativeGains.0' -> 'NDCG@1'
+                    pos = int(data.columns.values[i][36]) + 1
+                    cols.append('NDCG@' + str(pos))
+                else:
+                    cols.append(data.columns.values[i])
+        data.columns = cols
+
+        return data
+
     @trace
     def _predict(self, X, y=None,
                  evaltype='auto', group_id=None,
diff --git a/src/python/nimbusml/tests/pipeline/test_permutation_feature_importance.py b/src/python/nimbusml/tests/pipeline/test_permutation_feature_importance.py
new file mode 100644
index 00000000..347b2798
--- /dev/null
+++ b/src/python/nimbusml/tests/pipeline/test_permutation_feature_importance.py
@@ -0,0 +1,125 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------------------------- +import os +import unittest + +from nimbusml import FileDataStream +from nimbusml import Pipeline +from nimbusml.datasets import get_dataset +from nimbusml.ensemble import LightGbmRanker +from nimbusml.feature_extraction.categorical import OneHotVectorizer +from nimbusml.linear_model import LogisticRegressionBinaryClassifier, \ + FastLinearClassifier, FastLinearRegressor +from nimbusml.preprocessing import ToKey +from numpy.testing import assert_almost_equal +from pandas.testing import assert_frame_equal + +class TestPermutationFeatureImportance(unittest.TestCase): + + @classmethod + def setUpClass(self): + adult_path = get_dataset('uciadult_train').as_filepath() + self.classification_data = FileDataStream.read_csv(adult_path) + binary_pipeline = Pipeline([ + OneHotVectorizer(columns=['education']), + LogisticRegressionBinaryClassifier( + feature=['age', 'education'], label='label', + number_of_threads=1)]) + self.binary_model = binary_pipeline.fit(self.classification_data) + self.binary_pfi = self.binary_model.permutation_feature_importance(self.classification_data) + classifier_pipeline = Pipeline([ + OneHotVectorizer(columns=['education']), + FastLinearClassifier(feature=['age', 'education'], label='label', + number_of_threads=1, shuffle=False)]) + self.classifier_model = classifier_pipeline.fit(self.classification_data) + self.classifier_pfi = self.classifier_model.permutation_feature_importance(self.classification_data) + + infert_path = get_dataset('infert').as_filepath() + self.regression_data = FileDataStream.read_csv(infert_path) + regressor_pipeline = Pipeline([ + OneHotVectorizer(columns=['education']), + FastLinearRegressor(feature=['induced', 'education'], label='age', + number_of_threads=1, shuffle=False)]) + self.regressor_model = regressor_pipeline.fit(self.regression_data) + self.regressor_pfi = self.regressor_model.permutation_feature_importance(self.regression_data) + + ticket_path = get_dataset('gen_tickettrain').as_filepath() + self.ranking_data = FileDataStream.read_csv(ticket_path) + ranker_pipeline = Pipeline([ + ToKey(columns=['group']), + LightGbmRanker(feature=['Class', 'dep_day', 'duration'], + label='rank', group_id='group', + random_state=0, number_of_threads=1)]) + self.ranker_model = ranker_pipeline.fit(self.ranking_data) + self.ranker_pfi = self.ranker_model.permutation_feature_importance(self.ranking_data) + + def test_binary_classifier(self): + assert_almost_equal(self.binary_pfi['AreaUnderRocCurve'].sum(), -0.140824, 6) + assert_almost_equal(self.binary_pfi['PositivePrecision'].sum(), -0.482143, 6) + assert_almost_equal(self.binary_pfi['PositiveRecall'].sum(), -0.0695652, 6) + assert_almost_equal(self.binary_pfi['NegativePrecision'].sum(), -0.0139899, 6) + assert_almost_equal(self.binary_pfi['NegativeRecall'].sum(), -0.00779221, 6) + assert_almost_equal(self.binary_pfi['F1Score'].sum(), -0.126983, 6) + assert_almost_equal(self.binary_pfi['AreaUnderPrecisionRecallCurve'].sum(), -0.19365, 5) + + def test_binary_classifier_from_loaded_model(self): + model_path = "model.zip" + self.binary_model.save_model(model_path) + loaded_model = Pipeline() + loaded_model.load_model(model_path) + pfi_from_loaded = loaded_model.permutation_feature_importance(self.classification_data) + assert_frame_equal(self.binary_pfi, pfi_from_loaded) + os.remove(model_path) + + def test_clasifier(self): + assert_almost_equal(self.classifier_pfi['MacroAccuracy'].sum(), -0.0256352, 6) + 
+        assert_almost_equal(self.classifier_pfi['LogLoss'].sum(),
+                            0.158811, 6)
+        assert_almost_equal(self.classifier_pfi['LogLossReduction'].sum(),
+                            -0.29449, 5)
+        assert_almost_equal(self.classifier_pfi['PerClassLogLoss.0'].sum(),
+                            0.0808459, 6)
+        assert_almost_equal(self.classifier_pfi['PerClassLogLoss.1'].sum(),
+                            0.419826, 6)
+
+    def test_classifier_from_loaded_model(self):
+        model_path = "model.zip"
+        self.classifier_model.save_model(model_path)
+        loaded_model = Pipeline()
+        loaded_model.load_model(model_path)
+        pfi_from_loaded = loaded_model.permutation_feature_importance(
+            self.classification_data)
+        assert_frame_equal(self.classifier_pfi, pfi_from_loaded)
+        os.remove(model_path)
+
+    def test_regressor(self):
+        assert_almost_equal(self.regressor_pfi['MeanAbsoluteError'].sum(),
+                            0.504701, 6)
+        assert_almost_equal(self.regressor_pfi['MeanSquaredError'].sum(),
+                            5.59277, 5)
+        assert_almost_equal(self.regressor_pfi['RootMeanSquaredError'].sum(),
+                            0.553048, 6)
+        assert_almost_equal(self.regressor_pfi['RSquared'].sum(),
+                            -0.203612, 6)
+
+    def test_regressor_from_loaded_model(self):
+        model_path = "model.zip"
+        self.regressor_model.save_model(model_path)
+        loaded_model = Pipeline()
+        loaded_model.load_model(model_path)
+        pfi_from_loaded = loaded_model.permutation_feature_importance(
+            self.regression_data)
+        assert_frame_equal(self.regressor_pfi, pfi_from_loaded)
+        os.remove(model_path)
+
+    def test_ranker(self):
+        assert_almost_equal(self.ranker_pfi['DCG@1'].sum(), -2.16404, 5)
+        assert_almost_equal(self.ranker_pfi['DCG@2'].sum(), -3.5294, 4)
+        assert_almost_equal(self.ranker_pfi['DCG@3'].sum(), -4.9721, 4)
+        assert_almost_equal(self.ranker_pfi['NDCG@1'].sum(), -0.114286, 6)
+        assert_almost_equal(self.ranker_pfi['NDCG@2'].sum(), -0.198631, 6)
+        assert_almost_equal(self.ranker_pfi['NDCG@3'].sum(), -0.236544, 6)
+
+    def test_ranker_from_loaded_model(self):
+        model_path = "model.zip"
+        self.ranker_model.save_model(model_path)
+        loaded_model = Pipeline()
+        loaded_model.load_model(model_path)
+        pfi_from_loaded = loaded_model.permutation_feature_importance(
+            self.ranking_data)
+        assert_frame_equal(self.ranker_pfi, pfi_from_loaded)
+        os.remove(model_path)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py
index e47ce638..7dfd5eb8 100644
--- a/src/python/tests/test_estimator_checks.py
+++ b/src/python/tests/test_estimator_checks.py
@@ -16,6 +16,7 @@
 from nimbusml.ensemble import LightGbmRegressor
 from nimbusml.feature_extraction.text import NGramFeaturizer
 from nimbusml.internal.entrypoints._ngramextractor_ngram import n_gram
+from nimbusml.linear_model import SgdBinaryClassifier
 from nimbusml.preprocessing import TensorFlowScorer
 from nimbusml.preprocessing.filter import SkipFilter, TakeFilter
 from nimbusml.timeseries import (IidSpikeDetector, IidChangePointDetector,
@@ -201,6 +202,7 @@
     'LightGbmRanker': LightGbmRanker(
         minimum_example_count_per_group=1, minimum_example_count_per_leaf=1),
     'NGramFeaturizer': NGramFeaturizer(word_feature_extractor=n_gram()),
+    'SgdBinaryClassifier': SgdBinaryClassifier(number_of_threads=1, shuffle=False),
     'SkipFilter': SkipFilter(count=5),
     'TakeFilter': TakeFilter(count=100000),
     'IidSpikeDetector': IidSpikeDetector(columns=['F0']),
diff --git a/src/python/tools/manifest.json b/src/python/tools/manifest.json
index a68890fe..c8e6d6e5 100644
--- a/src/python/tools/manifest.json
+++ b/src/python/tools/manifest.json
@@ -21710,6 +21710,79 @@
         "ITransformOutput"
       ]
     },
+    {
+      "Name": "Transforms.PermutationFeatureImportance",
"Transforms.PermutationFeatureImportance", + "Desc": "Permutation Feature Importance (PFI)", + "FriendlyName": "PFI", + "ShortName": "PFI", + "Inputs": [ + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "PredictorModel", + "Type": "PredictorModel", + "Desc": "The path to the model file", + "Aliases": [ + "path" + ], + "Required": true, + "SortOrder": 150.0, + "IsNullable": false + }, + { + "Name": "UseFeatureWeightFilter", + "Type": "Bool", + "Desc": "Use feature weights to pre-filter features", + "Aliases": [ + "usefw" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "NumberOfExamplesToUse", + "Type": "Int", + "Desc": "Limit the number of examples to evaluate on", + "Aliases": [ + "numexamples" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "PermutationCount", + "Type": "Int", + "Desc": "The number of permutations to perform", + "Aliases": [ + "permutations" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 1 + } + ], + "Outputs": [ + { + "Name": "Metrics", + "Type": "DataView", + "Desc": "The PFI metrics" + } + ], + "InputKind": [ + "ITransformInput" + ] + }, { "Name": "Transforms.PredictedLabelColumnOriginalValueConverter", "Desc": "Transforms a predicted label column to its original values, unless it is of type bool.",