diff --git a/src/python/nimbusml/base_predictor.py b/src/python/nimbusml/base_predictor.py index bfa2813f..588971b3 100644 --- a/src/python/nimbusml/base_predictor.py +++ b/src/python/nimbusml/base_predictor.py @@ -12,7 +12,6 @@ from sklearn.base import BaseEstimator from sklearn.utils.multiclass import unique_labels -from sklearn.utils.validation import check_is_fitted from . import Pipeline from .internal.core.base_pipeline_item import BasePipelineItem @@ -49,8 +48,6 @@ def fit(self, X, y=None, **params): "Classifier can't train when only one class is " "present.") self.classes_ = unique_classes - self.X_ = X - self.y_ = y # Clear cached summary since it should not # retain its value after a new call to fit @@ -69,13 +66,24 @@ def fit(self, X, y=None, **params): set_shape(self, X) return self + @property + def _is_fitted(self): + """ + Tells if the predictor was trained. + """ + return (hasattr(self, 'model_') and + self.model_ and + os.path.isfile(self.model_)) + @trace def _invoke_inference_method(self, method, X, **params): """ Returns predictions. Can be predicted labels, probabilities or else decision values. """ - check_is_fitted(self, ["X_", "y_"]) + if not self._is_fitted: + raise ValueError("Model is not fitted. " + "fit() must be called before {}.".format(method)) # Check that the input is of the same shape as the one passed # during