Skip to content

Commit b656e61

Browse files
yanboliang authored and mengxr committed
[SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark
LinearRegression and LogisticRegression lack of some Params for Python, and some Params are not shared classes which lead we need to write them for each class. These kinds of Params are list here: ```scala HasElasticNetParam HasFitIntercept HasStandardization HasThresholds ``` Here we implement them in shared params at Python side and make LinearRegression/LogisticRegression parameters peer with Scala one. Author: Yanbo Liang <[email protected]> Closes #8508 from yanboliang/spark-10026.
1 parent c268ca4 commit b656e61

File tree

4 files changed

+143
-96
lines changed

4 files changed

+143
-96
lines changed

python/pyspark/ml/classification.py

Lines changed: 11 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131

3232
@inherit_doc
3333
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
34-
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol):
34+
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
35+
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds):
3536
"""
3637
Logistic regression.
3738
Currently, this class only supports binary classification.
@@ -65,72 +66,44 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
6566
"""
6667

6768
# a placeholder to make it appear in the generated doc
68-
elasticNetParam = \
69-
Param(Params._dummy(), "elasticNetParam",
70-
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
71-
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
72-
fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
73-
thresholds = Param(Params._dummy(), "thresholds",
74-
"Thresholds in multi-class classification" +
75-
" to adjust the probability of predicting each class." +
76-
" Array must have length equal to the number of classes, with values >= 0." +
77-
" The class with largest value p/t is predicted, where p is the original" +
78-
" probability of that class and t is the class' threshold.")
7969
threshold = Param(Params._dummy(), "threshold",
8070
"Threshold in binary classification prediction, in range [0, 1]." +
8171
" If threshold and thresholds are both set, they must match.")
8272

8373
@keyword_only
8474
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
8575
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
86-
threshold=0.5, thresholds=None,
87-
probabilityCol="probability", rawPredictionCol="rawPrediction"):
76+
threshold=0.5, thresholds=None, probabilityCol="probability",
77+
rawPredictionCol="rawPrediction", standardization=True):
8878
"""
8979
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
9080
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
91-
threshold=0.5, thresholds=None, \
92-
probabilityCol="probability", rawPredictionCol="rawPrediction")
81+
threshold=0.5, thresholds=None, probabilityCol="probability", \
82+
rawPredictionCol="rawPrediction", standardization=True)
9383
If the threshold and thresholds Params are both set, they must be equivalent.
9484
"""
9585
super(LogisticRegression, self).__init__()
9686
self._java_obj = self._new_java_obj(
9787
"org.apache.spark.ml.classification.LogisticRegression", self.uid)
98-
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
99-
# is an L2 penalty. For alpha = 1, it is an L1 penalty.
100-
self.elasticNetParam = \
101-
Param(self, "elasticNetParam",
102-
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
103-
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
104-
#: param for whether to fit an intercept term.
105-
self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
10688
#: param for threshold in binary classification, in range [0, 1].
10789
self.threshold = Param(self, "threshold",
10890
"Threshold in binary classification prediction, in range [0, 1]." +
10991
" If threshold and thresholds are both set, they must match.")
110-
#: param for thresholds or cutoffs in binary or multiclass classification
111-
self.thresholds = \
112-
Param(self, "thresholds",
113-
"Thresholds in multi-class classification" +
114-
" to adjust the probability of predicting each class." +
115-
" Array must have length equal to the number of classes, with values >= 0." +
116-
" The class with largest value p/t is predicted, where p is the original" +
117-
" probability of that class and t is the class' threshold.")
118-
self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
119-
fitIntercept=True, threshold=0.5)
92+
self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5)
12093
kwargs = self.__init__._input_kwargs
12194
self.setParams(**kwargs)
12295
self._checkThresholdConsistency()
12396

12497
@keyword_only
12598
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
12699
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
127-
threshold=0.5, thresholds=None,
128-
probabilityCol="probability", rawPredictionCol="rawPrediction"):
100+
threshold=0.5, thresholds=None, probabilityCol="probability",
101+
rawPredictionCol="rawPrediction", standardization=True):
129102
"""
130103
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
131104
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
132-
threshold=0.5, thresholds=None, \
133-
probabilityCol="probability", rawPredictionCol="rawPrediction")
105+
threshold=0.5, thresholds=None, probabilityCol="probability", \
106+
rawPredictionCol="rawPrediction", standardization=True)
134107
Sets params for logistic regression.
135108
If the threshold and thresholds Params are both set, they must be equivalent.
136109
"""
@@ -142,32 +115,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
142115
def _create_model(self, java_model):
143116
return LogisticRegressionModel(java_model)
144117

145-
def setElasticNetParam(self, value):
146-
"""
147-
Sets the value of :py:attr:`elasticNetParam`.
148-
"""
149-
self._paramMap[self.elasticNetParam] = value
150-
return self
151-
152-
def getElasticNetParam(self):
153-
"""
154-
Gets the value of elasticNetParam or its default value.
155-
"""
156-
return self.getOrDefault(self.elasticNetParam)
157-
158-
def setFitIntercept(self, value):
159-
"""
160-
Sets the value of :py:attr:`fitIntercept`.
161-
"""
162-
self._paramMap[self.fitIntercept] = value
163-
return self
164-
165-
def getFitIntercept(self):
166-
"""
167-
Gets the value of fitIntercept or its default value.
168-
"""
169-
return self.getOrDefault(self.fitIntercept)
170-
171118
def setThreshold(self, value):
172119
"""
173120
Sets the value of :py:attr:`threshold`.

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,16 @@ def get$Name(self):
124124
("stepSize", "Step size to be used for each iteration of optimization.", None),
125125
("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
126126
"out rows with bad values), or error (which will throw an errror). More options may be " +
127-
"added later.", None)]
127+
"added later.", None),
128+
("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
129+
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"),
130+
("fitIntercept", "whether to fit an intercept term.", "True"),
131+
("standardization", "whether to standardize the training features before fitting the " +
132+
"model.", "True"),
133+
("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
134+
"predicting each class. Array must have length equal to the number of classes, with " +
135+
"values >= 0. The class with largest value p/t is predicted, where p is the original " +
136+
"probability of that class and t is the class' threshold.", None)]
128137
code = []
129138
for name, doc, defaultValueStr in shared:
130139
param_code = _gen_param_header(name, doc, defaultValueStr)

python/pyspark/ml/param/shared.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,117 @@ def getHandleInvalid(self):
459459
return self.getOrDefault(self.handleInvalid)
460460

461461

462+
class HasElasticNetParam(Params):
463+
"""
464+
Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty..
465+
"""
466+
467+
# a placeholder to make it appear in the generated doc
468+
elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
469+
470+
def __init__(self):
471+
super(HasElasticNetParam, self).__init__()
472+
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
473+
self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
474+
self._setDefault(elasticNetParam=0.0)
475+
476+
def setElasticNetParam(self, value):
477+
"""
478+
Sets the value of :py:attr:`elasticNetParam`.
479+
"""
480+
self._paramMap[self.elasticNetParam] = value
481+
return self
482+
483+
def getElasticNetParam(self):
484+
"""
485+
Gets the value of elasticNetParam or its default value.
486+
"""
487+
return self.getOrDefault(self.elasticNetParam)
488+
489+
490+
class HasFitIntercept(Params):
491+
"""
492+
Mixin for param fitIntercept: whether to fit an intercept term..
493+
"""
494+
495+
# a placeholder to make it appear in the generated doc
496+
fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
497+
498+
def __init__(self):
499+
super(HasFitIntercept, self).__init__()
500+
#: param for whether to fit an intercept term.
501+
self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
502+
self._setDefault(fitIntercept=True)
503+
504+
def setFitIntercept(self, value):
505+
"""
506+
Sets the value of :py:attr:`fitIntercept`.
507+
"""
508+
self._paramMap[self.fitIntercept] = value
509+
return self
510+
511+
def getFitIntercept(self):
512+
"""
513+
Gets the value of fitIntercept or its default value.
514+
"""
515+
return self.getOrDefault(self.fitIntercept)
516+
517+
518+
class HasStandardization(Params):
519+
"""
520+
Mixin for param standardization: whether to standardize the training features before fitting the model..
521+
"""
522+
523+
# a placeholder to make it appear in the generated doc
524+
standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.")
525+
526+
def __init__(self):
527+
super(HasStandardization, self).__init__()
528+
#: param for whether to standardize the training features before fitting the model.
529+
self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.")
530+
self._setDefault(standardization=True)
531+
532+
def setStandardization(self, value):
533+
"""
534+
Sets the value of :py:attr:`standardization`.
535+
"""
536+
self._paramMap[self.standardization] = value
537+
return self
538+
539+
def getStandardization(self):
540+
"""
541+
Gets the value of standardization or its default value.
542+
"""
543+
return self.getOrDefault(self.standardization)
544+
545+
546+
class HasThresholds(Params):
547+
"""
548+
Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold..
549+
"""
550+
551+
# a placeholder to make it appear in the generated doc
552+
thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
553+
554+
def __init__(self):
555+
super(HasThresholds, self).__init__()
556+
#: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
557+
self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
558+
559+
def setThresholds(self, value):
560+
"""
561+
Sets the value of :py:attr:`thresholds`.
562+
"""
563+
self._paramMap[self.thresholds] = value
564+
return self
565+
566+
def getThresholds(self):
567+
"""
568+
Gets the value of thresholds or its default value.
569+
"""
570+
return self.getOrDefault(self.thresholds)
571+
572+
462573
class DecisionTreeParams(Params):
463574
"""
464575
Mixin for Decision Tree parameters.

python/pyspark/ml/regression.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
@inherit_doc
3030
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
31-
HasRegParam, HasTol):
31+
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
32+
HasStandardization):
3233
"""
3334
Linear regression.
3435
@@ -63,38 +64,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
6364
TypeError: Method setParams forces keyword arguments.
6465
"""
6566

66-
# a placeholder to make it appear in the generated doc
67-
elasticNetParam = \
68-
Param(Params._dummy(), "elasticNetParam",
69-
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
70-
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
71-
7267
@keyword_only
7368
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
74-
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6):
69+
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
70+
standardization=True):
7571
"""
7672
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
77-
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
73+
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
74+
standardization=True)
7875
"""
7976
super(LinearRegression, self).__init__()
8077
self._java_obj = self._new_java_obj(
8178
"org.apache.spark.ml.regression.LinearRegression", self.uid)
82-
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
83-
# is an L2 penalty. For alpha = 1, it is an L1 penalty.
84-
self.elasticNetParam = \
85-
Param(self, "elasticNetParam",
86-
"the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " +
87-
"is an L2 penalty. For alpha = 1, it is an L1 penalty.")
88-
self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
79+
self._setDefault(maxIter=100, regParam=0.0, tol=1e-6)
8980
kwargs = self.__init__._input_kwargs
9081
self.setParams(**kwargs)
9182

9283
@keyword_only
9384
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
94-
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6):
85+
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
86+
standardization=True):
9587
"""
9688
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
97-
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
89+
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
90+
standardization=True)
9891
Sets params for linear regression.
9992
"""
10093
kwargs = self.setParams._input_kwargs
@@ -103,19 +96,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
10396
def _create_model(self, java_model):
10497
return LinearRegressionModel(java_model)
10598

106-
def setElasticNetParam(self, value):
107-
"""
108-
Sets the value of :py:attr:`elasticNetParam`.
109-
"""
110-
self._paramMap[self.elasticNetParam] = value
111-
return self
112-
113-
def getElasticNetParam(self):
114-
"""
115-
Gets the value of elasticNetParam or its default value.
116-
"""
117-
return self.getOrDefault(self.elasticNetParam)
118-
11999

120100
class LinearRegressionModel(JavaModel):
121101
"""

0 commit comments

Comments
 (0)