From 1f31d312c89c4634afbd0118b5d7e09221666d55 Mon Sep 17 00:00:00 2001
From: "pieths.dev@gmail.com" <pieths.dev@gmail.com>
Date: Mon, 22 Jul 2019 14:42:46 -0700
Subject: [PATCH 1/4] Add classes_ to Pipeline and/or predictor when calling
 predict_proba.

---
 src/python/nimbusml/base_predictor.py        | 11 ---
 src/python/nimbusml/pipeline.py              | 29 ++++++--
 .../test_predict_proba_decision_function.py  | 68 +++++++++++++++++++
 3 files changed, 91 insertions(+), 17 deletions(-)

diff --git a/src/python/nimbusml/base_predictor.py b/src/python/nimbusml/base_predictor.py
index cbae2195..f33f746c 100644
--- a/src/python/nimbusml/base_predictor.py
+++ b/src/python/nimbusml/base_predictor.py
@@ -11,7 +11,6 @@
 import os
 
 from sklearn.base import BaseEstimator
-from sklearn.utils.multiclass import unique_labels
 
 from . import Pipeline
 from .internal.core.base_pipeline_item import BasePipelineItem
@@ -39,16 +38,6 @@ def fit(self, X, y=None, **params):
         :param y: array-like with shape=[n_samples]
         :return: self
         """
-        if y is not None and not isinstance(y, (
-                str, tuple)) and self.type in set(
-                ['classifier', 'anomaly']):
-            unique_classes = unique_labels(y)
-            if len(unique_classes) < 2:
-                raise ValueError(
-                    "Classifier can't train when only one class is "
-                    "present.")
-            self.classes_ = unique_classes
-
         # Clear cached summary since it should not
         # retain its value after a new call to fit
         if hasattr(self, 'model_summary_'):
diff --git a/src/python/nimbusml/pipeline.py b/src/python/nimbusml/pipeline.py
index 93384176..477e32e5 100644
--- a/src/python/nimbusml/pipeline.py
+++ b/src/python/nimbusml/pipeline.py
@@ -18,6 +18,7 @@
 from pandas import DataFrame, Series
 from scipy.sparse import csr_matrix
 from sklearn.utils.validation import check_X_y, check_array
+from sklearn.utils.multiclass import unique_labels
 
 from .internal.core.base_pipeline_item import BasePipelineItem
 from .internal.entrypoints.data_customtextloader import \
@@ -1111,6 +1112,8 @@ def fit(self, X, y=None, verbose=1, **params):
                         i, n.__class__.__name__), TrainedWarning)
                 break
 
+        self._extract_classes(y)
+
         graph, X, y, weights, start_time, schema, telemetry_info, \
             learner_features, _, max_slots = self._fit_graph(
                 X, y, verbose, **params)
@@ -1923,6 +1926,24 @@ def _predict(self, X, y=None,
             self._write_csv_time = graph._write_csv_time
         return out_data, out_metrics
 
+    def _extract_classes(self, y):
+        if ((len(self.steps) > 0) and
+                (self.last_node.type in ['classifier', 'anomaly']) and
+                (y is not None) and
+                (not isinstance(y, (str, tuple)))):
+
+            unique_classes = unique_labels(y)
+            if len(unique_classes) < 2:
+                raise ValueError(
+                    "Classifier can't train when only one class is "
+                    "present.")
+            self._add_classes(unique_classes)
+
+    def _extract_classes_from_headers(self, headers):
+        classes = [x.replace('Score.', '') for x in headers]
+        classes = np.array(classes).astype(self.last_node.classes_.dtype)
+        self._add_classes(classes)
+
     def _add_classes(self, classes):
         # Create classes_ attribute similar to scikit
         # Add both to pipeline and ending classifier
@@ -1974,11 +1995,7 @@ def predict_proba(self, X, verbose=0, **params):
         # for multiclass, scores are probabilities
         pcols = [i for i in scores.columns if i.startswith('Score.')]
         if len(pcols) > 0:
-            # [todo]: this is a bug, predict_proba should not change
-            # internal state of pipeline.
-            # test check_dict_unchanged() detects that, commenting line
-            # for now
-            # self._add_classes([x.replace('Score.', '') for x in pcols])
+            self._extract_classes_from_headers(pcols)
             return scores.loc[:, pcols].values
 
         raise ValueError(
@@ -2019,7 +2036,7 @@ def decision_function(self, X, verbose=0, **params):
 
         # for multiclass with n_classes > 2
         if len(scols) > 2:
-            self._add_classes([x.replace('Score.', '') for x in scols])
+            self._extract_classes_from_headers(scols)
             return scores.loc[:, scols].values
 
         raise ValueError(
diff --git a/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py b/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py
index 21aa24a0..5da37807 100644
--- a/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py
+++ b/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py
@@ -28,6 +28,14 @@
 X_train, X_test, y_train, y_test = \
     train_test_split(features, labels)
 
+# 3 class dataset with integer labels
+np.random.seed(0)
+df = get_dataset("iris").as_df()
+df.drop(['Species'], inplace=True, axis=1)
+features_3class_int, labels_3class_int = split_features_and_label(df, 'Label')
+X_train_3class_int, X_test_3class_int, y_train_3class_int, y_test_3class_int = \
+    train_test_split(features_3class_int, labels_3class_int)
+
 # 3 class dataset with string labels
 np.random.seed(0)
 df = get_dataset("iris").as_df()
@@ -115,6 +123,36 @@ def test_pass_predict_proba_multiclass_3class(self):
             err_msg=invalid_decision_function_output)
         assert_equal(set(clf.classes_), {'Blue', 'Green', 'Red'})
 
+    def test_pass_predict_proba_multiclass_with_pipeline_adds_classes(self):
+        clf = FastLinearClassifier(number_of_threads=1)
+        pipeline = Pipeline([clf])
+        pipeline.fit(X_train_3class, y_train_3class)
+
+        expected_classes = {'Blue', 'Green', 'Red'}
+        assert_equal(set(clf.classes_), expected_classes)
+        assert_equal(set(pipeline.classes_), expected_classes)
+
+        s = pipeline.predict_proba(X_test_3class).sum()
+        assert_almost_equal(
+            s,
+            38.0,
+            decimal=4,
+            err_msg=invalid_decision_function_output)
+
+        assert_equal(set(clf.classes_), expected_classes)
+        assert_equal(set(pipeline.classes_), expected_classes)
+
+    def test_pass_predict_proba_multiclass_3class_retains_classes_type(self):
+        clf = FastLinearClassifier(number_of_threads=1)
+        clf.fit(X_train_3class_int, y_train_3class_int)
+        s = clf.predict_proba(X_test_3class_int).sum()
+        assert_almost_equal(
+            s,
+            38.0,
+            decimal=4,
+            err_msg=invalid_decision_function_output)
+        assert_equal(set(clf.classes_), {0, 1, 2})
+
     def test_fail_predict_proba_multiclass_with_pipeline(self):
         check_unsupported_predict_proba(self, Pipeline(
             [NaiveBayesClassifier()]), X_train, y_train, X_test)
@@ -174,6 +212,36 @@ def test_pass_decision_function_multiclass_3class(self):
             err_msg=invalid_decision_function_output)
         assert_equal(set(clf.classes_), {'Blue', 'Green', 'Red'})
 
+    def test_pass_decision_function_multiclass_with_pipeline_adds_classes(self):
+        clf = FastLinearClassifier(number_of_threads=1)
+        pipeline = Pipeline([clf])
+        pipeline.fit(X_train_3class, y_train_3class)
+
+        expected_classes = {'Blue', 'Green', 'Red'}
+        assert_equal(set(clf.classes_), expected_classes)
+        assert_equal(set(pipeline.classes_), expected_classes)
+
+        s = pipeline.decision_function(X_test_3class).sum()
+        assert_almost_equal(
+            s,
+            38.0,
+            decimal=4,
+            err_msg=invalid_decision_function_output)
+
+        assert_equal(set(clf.classes_), expected_classes)
+        assert_equal(set(pipeline.classes_), expected_classes)
+
+    def test_pass_decision_function_multiclass_3class_retains_classes_type(self):
+        clf = FastLinearClassifier(number_of_threads=1)
+        clf.fit(X_train_3class_int, y_train_3class_int)
+        s = clf.decision_function(X_test_3class_int).sum()
+        assert_almost_equal(
+            s,
+            38.0,
+            decimal=4,
+            err_msg=invalid_decision_function_output)
+        assert_equal(set(clf.classes_), {0, 1, 2})
+
     def test_fail_decision_function_multiclass(self):
         check_unsupported_decision_function(
             self, LogisticRegressionClassifier(), X_train, y_train, X_test)

From 1fd3f738d679713df2b2daffab994b669d49cd21 Mon Sep 17 00:00:00 2001
From: "pieths.dev@gmail.com" <pieths.dev@gmail.com>
Date: Mon, 22 Jul 2019 17:00:32 -0700
Subject: [PATCH 2/4] Exclude LogisticRegressionClassifier and
 FastLinearClassifier from being run through the check_dict_unchanged test in
 test_estimator_checks.

---
 src/python/tests/test_estimator_checks.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py
index f101b1ec..07886a77 100644
--- a/src/python/tests/test_estimator_checks.py
+++ b/src/python/tests/test_estimator_checks.py
@@ -71,7 +71,12 @@
         'check_regressors_int',
     # bug decision function shape should be 1
    # dimensional arrays, tolerance
-    'FastLinearClassifier': 'check_classifiers_train',
+    'FastLinearClassifier': 'check_classifiers_train'
+    # Everything is working as expected. Comparing numpy
+    # arrays doesn't work the same way as comparing python lists.
+    # The truth value of an array with more than one element
+    # is ambiguous. Use a.any() or a.all()
+    'check_dict_unchanged',
     'FastForestRegressor': 'check_fit_score_takes_y',  # bug
     # bug in decision_function
     'FastTreesBinaryClassifier':
@@ -86,7 +91,12 @@
 
     'Indicator': 'check_estimators_dtypes',
     # tolerance
-    'LogisticRegressionClassifier': 'check_classifiers_train',
+    'LogisticRegressionClassifier': 'check_classifiers_train,'
+    # Everything is working as expected. Comparing numpy
+    # arrays doesn't work the same way as comparing python lists.
+    # The truth value of an array with more than one element
+    # is ambiguous. Use a.any() or a.all()
+    'check_dict_unchanged',
     # bug decision function shape, prediction bug
     'NaiveBayesClassifier':
         'check_classifiers_train, check_classifiers_classes',

From 5fcceffd4079beff096e75a2b167af31b62e59a1 Mon Sep 17 00:00:00 2001
From: "pieths.dev@gmail.com" <pieths.dev@gmail.com>
Date: Wed, 24 Jul 2019 09:23:41 -0700
Subject: [PATCH 3/4] Update test_estimator_checks.py to skip the
 check_dict_unchanged test for any estimator which supports predict_proba or
 decision_function.

---
 src/python/tests/test_estimator_checks.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py
index 07886a77..ae826b6e 100644
--- a/src/python/tests/test_estimator_checks.py
+++ b/src/python/tests/test_estimator_checks.py
@@ -71,12 +71,7 @@
         'check_regressors_int',
     # bug decision function shape should be 1
     # dimensional arrays, tolerance
-    'FastLinearClassifier': 'check_classifiers_train'
-    # Everything is working as expected. Comparing numpy
-    # arrays doesn't work the same way as comparing python lists.
-    # The truth value of an array with more than one element
-    # is ambiguous. Use a.any() or a.all()
-    'check_dict_unchanged',
+    'FastLinearClassifier': 'check_classifiers_train',
     'FastForestRegressor': 'check_fit_score_takes_y',  # bug
     # bug in decision_function
     'FastTreesBinaryClassifier':
@@ -91,12 +86,7 @@
 
     'Indicator': 'check_estimators_dtypes',
     # tolerance
-    'LogisticRegressionClassifier': 'check_classifiers_train,'
-    # Everything is working as expected. Comparing numpy
-    # arrays doesn't work the same way as comparing python lists.
-    # The truth value of an array with more than one element
-    # is ambiguous. Use a.any() or a.all()
-    'check_dict_unchanged',
+    'LogisticRegressionClassifier': 'check_classifiers_train,',
     # bug decision function shape, prediction bug
     'NaiveBayesClassifier':
         'check_classifiers_train, check_classifiers_classes',
@@ -300,6 +290,14 @@ def load_json(file_path):
         estimator = estimator << 'F0'
 
     for check in _yield_all_checks(class_name, estimator):
+        # Skip check_dict_unchanged for estimators which
+        # update the classes_ attribute. For more details
+        # see https://github.com/microsoft/NimbusML/pull/200
+        if (check.__name__ == 'check_dict_unchanged') and \
+                (hasattr(estimator, 'predict_proba') or
+                 hasattr(estimator, 'decision_function')):
+            continue
+
         if check.__name__ in OMITTED_CHECKS_ALWAYS:
             continue
         if 'Binary' in class_name and check.__name__ in NOBINARY_CHECKS:

From 6a7bb1b79809b2ad7c5d901e8aced734d42da2e0 Mon Sep 17 00:00:00 2001
From: "pieths.dev@gmail.com" <pieths.dev@gmail.com>
Date: Wed, 24 Jul 2019 09:31:52 -0700
Subject: [PATCH 4/4] Remove unnecessary comma in test_estimator_checks.

---
 src/python/tests/test_estimator_checks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py
index ae826b6e..6c8ef557 100644
--- a/src/python/tests/test_estimator_checks.py
+++ b/src/python/tests/test_estimator_checks.py
@@ -86,7 +86,7 @@
 
     'Indicator': 'check_estimators_dtypes',
     # tolerance
-    'LogisticRegressionClassifier': 'check_classifiers_train,',
+    'LogisticRegressionClassifier': 'check_classifiers_train',
     # bug decision function shape, prediction bug
     'NaiveBayesClassifier':
         'check_classifiers_train, check_classifiers_classes',
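
Note on the mechanism: below is a minimal standalone sketch of the
header-to-classes conversion that patch 1 adds as
Pipeline._extract_classes_from_headers. The function name and the sample
header/label values are illustrative only, not part of the patches:

    import numpy as np

    def classes_from_score_headers(headers, fit_time_classes):
        # Score columns are named 'Score.<label>'; stripping the prefix
        # recovers the class labels, but only as strings.
        labels = [h.replace('Score.', '') for h in headers]
        # Casting back to the dtype captured from y at fit time keeps
        # integer labels as integers instead of strings like '0'.
        return np.array(labels).astype(fit_time_classes.dtype)

    # Integer labels survive the round trip through the column names:
    print(classes_from_score_headers(
        ['Score.0', 'Score.1', 'Score.2'], np.array([0, 1, 2])))
    # -> [0 1 2]

    # String labels pass through unchanged:
    print(classes_from_score_headers(
        ['Score.Blue', 'Score.Green', 'Score.Red'],
        np.array(['Blue', 'Green', 'Red'])))
    # -> ['Blue' 'Green' 'Red']

Without the dtype cast, an estimator fit on integer labels would report
classes_ as {'0', '1', '2'} after predict_proba or decision_function, which
is exactly what the *_retains_classes_type tests in patch 1 guard against.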