[SPARK-7861][ML] PySpark OneVsRest #12124
Changes from all commits
@@ -15,18 +15,21 @@
# limitations under the License.
#

import operator
import warnings

from pyspark import since
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param import TypeConverters
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
    RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel

__all__ = ['LogisticRegression', 'LogisticRegressionModel',
           'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',

@@ -35,7 +38,8 @@
           'GBTClassifier', 'GBTClassificationModel',
           'RandomForestClassifier', 'RandomForestClassificationModel',
           'NaiveBayes', 'NaiveBayesModel',
           'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel']
           'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
           'OneVsRest', 'OneVsRestModel']


@inherit_doc
@@ -1145,6 +1149,214 @@ def weights(self):
        return self._call_java("weights")


@inherit_doc
class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
    """
    Reduction of Multiclass Classification to Binary Classification.
    Performs reduction using one against all strategy.
    For a multiclass classification with k classes, train k models (one per class).
    Each example is scored against all k models and the model with highest score
    is picked to label the example.

    >>> from pyspark.sql import Row
    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sc.parallelize([
    ...     Row(label=0.0, features=Vectors.dense(1.0, 0.8)),
    ...     Row(label=1.0, features=Vectors.sparse(2, [], [])),
    ...     Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF()
    >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
    >>> ovr = OneVsRest(classifier=lr)
    >>> model = ovr.fit(df)
    >>> [x.coefficients for x in model.models]
    [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])]
    >>> [x.intercept for x in model.models]
    [-3.6474708290602034, 2.5507881951814495, -1.1016513228162115]
    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF()
    >>> model.transform(test0).head().prediction
    1.0
    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
    >>> model.transform(test1).head().prediction
    0.0
    >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
    >>> model.transform(test2).head().prediction
    2.0

    .. versionadded:: 2.0.0
    """

    classifier = Param(Params._dummy(), "classifier", "base binary classifier")

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 classifier=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 classifier=None)
        """
        super(OneVsRest, self).__init__()
        kwargs = self.__init__._input_kwargs
        self._set(**kwargs)

    @keyword_only
    @since("2.0.0")
    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
        """
        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
        Sets params for OneVsRest.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setClassifier(self, value):
        """
        Sets the value of :py:attr:`classifier`.

        .. note:: Only LogisticRegression and NaiveBayes are supported now.
        """
        self._set(classifier=value)
        return self

    @since("2.0.0")
    def getClassifier(self):
        """
        Gets the value of classifier or its default value.
        """
        return self.getOrDefault(self.classifier)

    def _fit(self, dataset):
        labelCol = self.getLabelCol()
        featuresCol = self.getFeaturesCol()
        predictionCol = self.getPredictionCol()
        classifier = self.getClassifier()
        assert isinstance(classifier, HasRawPredictionCol),\
            "Classifier %s doesn't extend from HasRawPredictionCol." % type(classifier)

        numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

        multiclassLabeled = dataset.select(labelCol, featuresCol)

        # persist if underlying dataset is not persistent.
        handlePersistence = \
            dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False)
        if handlePersistence:
            multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)

        def trainSingleClass(index):
            binaryLabelCol = "mc2b$" + str(index)
            trainingDataset = multiclassLabeled.withColumn(
                binaryLabelCol,
                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0))
Member
Uh oh, I just realized this will only work with LogisticRegression and NaiveBayes. With trees, there is no good way to set the metadata from PySpark. We'll need to document that.

Member
But I'm hoping to fix trees to not need metadata for 2.0, if we have time.

Contributor (Author)
Yeah, that's absolutely a problem since PySpark cannot handle metadata for now. I'll document it.
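As context for this thread, here is a minimal sketch of the kind of guard being discussed: warning when the base classifier is not one of the estimators that work without label metadata. The helper name and the warning text are illustrative assumptions, not part of this patch.

```python
import warnings

from pyspark.ml.classification import LogisticRegression, NaiveBayes


def _warn_if_unsupported_classifier(classifier):
    # Hypothetical guard: tree-based classifiers need label metadata on the
    # binary label column, which PySpark cannot set, so only these two are safe.
    if not isinstance(classifier, (LogisticRegression, NaiveBayes)):
        warnings.warn("OneVsRest in PySpark currently supports only "
                      "LogisticRegression and NaiveBayes as the base classifier.")
    return classifier
```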
            paramMap = dict([(classifier.labelCol, binaryLabelCol),
                             (classifier.featuresCol, featuresCol),
                             (classifier.predictionCol, predictionCol)])
            return classifier.fit(trainingDataset, paramMap)

        # TODO: Parallel training for all classes.
        models = [trainSingleClass(i) for i in range(numClasses)]

        if handlePersistence:
            multiclassLabeled.unpersist()

        return self._copyValues(OneVsRestModel(models=models))

    @since("2.0.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This creates a deep copy of the embedded paramMap,
        and copies the embedded and extra parameters over.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        newOvr = Params.copy(self, extra)
        if self.isSet(self.classifier):
            newOvr.setClassifier(self.getClassifier().copy(extra))
        return newOvr


class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
    """
    Model fitted by OneVsRest.
    This stores the models resulting from training k binary classifiers: one for each class.
    Each example is scored against all k models, and the model with the highest score
    is picked to label the example.

    .. versionadded:: 2.0.0
    """

    def __init__(self, models):
        super(OneVsRestModel, self).__init__()
        self.models = models

    def _transform(self, dataset):
        # determine the input columns: these need to be passed through
        origCols = dataset.columns

        # add an accumulator column to store predictions of all the models
        accColName = "mbc$acc" + str(uuid.uuid4())
        initUDF = udf(lambda _: [], ArrayType(DoubleType()))
        newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))

        # persist if underlying dataset is not persistent.
        handlePersistence = \
            dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False)
        if handlePersistence:
            newDataset.persist(StorageLevel.MEMORY_AND_DISK)

        # update the accumulator column with the result of prediction of models
        aggregatedDataset = newDataset
        for index, model in enumerate(self.models):
            rawPredictionCol = model._call_java("getRawPredictionCol")
            columns = origCols + [rawPredictionCol, accColName]

            # add temporary column to store intermediate scores and update
            tmpColName = "mbc$tmp" + str(uuid.uuid4())
            updateUDF = udf(
                lambda predictions, prediction: predictions + [prediction.tolist()[1]],
                ArrayType(DoubleType()))
            transformedDataset = model.transform(aggregatedDataset).select(*columns)
            updatedDataset = transformedDataset.withColumn(
                tmpColName,
                updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
            newColumns = origCols + [tmpColName]

            # switch out the intermediate column with the accumulator column
            aggregatedDataset = updatedDataset\
                .select(*newColumns).withColumnRenamed(tmpColName, accColName)

        if handlePersistence:
            newDataset.unpersist()

        # output the index of the classifier with highest confidence as prediction
        labelUDF = udf(
            lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
            DoubleType())

        # output label and label metadata as prediction
        return aggregatedDataset.withColumn(
            self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)

    @since("2.0.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This creates a deep copy of the embedded paramMap,
        and copies the embedded and extra parameters over.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        newModel = Params.copy(self, extra)
        newModel.models = [model.copy(extra) for model in self.models]
        return newModel


if __name__ == "__main__":
    import doctest
    import pyspark.ml.classification
Could you ensure this is a valid classifier here? You should be able to assert that it has a rawPredictionCol.
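For reference, the `_fit` hunk above already contains the requested assertion; a minimal standalone sketch of that check follows (the wrapper function name is illustrative, not part of the patch):

```python
from pyspark.ml.param.shared import HasRawPredictionCol


def _assert_has_raw_prediction_col(classifier):
    # OneVsRest reads the raw prediction of each binary model, so the base
    # classifier must expose a rawPredictionCol param.
    assert isinstance(classifier, HasRawPredictionCol), \
        "Classifier %s doesn't extend from HasRawPredictionCol." % type(classifier)
    return classifier
```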