You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark
LinearRegression and LogisticRegression lack some Params in Python, and some Params are not shared classes, which means we need to write them separately for each class. These Params are listed here:
```scala
HasElasticNetParam
HasFitIntercept
HasStandardization
HasThresholds
```
Here we implement them as shared params on the Python side and bring the LinearRegression/LogisticRegression parameters to parity with the Scala ones.
Author: Yanbo Liang <[email protected]>
Closes #8508 from yanboliang/spark-10026.
Copy file name to clipboardExpand all lines: python/pyspark/ml/param/shared.py
+111Lines changed: 111 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -459,6 +459,117 @@ def getHandleInvalid(self):
459
459
returnself.getOrDefault(self.handleInvalid)
460
460
461
461
462
+
class HasElasticNetParam(Params):
    """
    Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.

    A default of 0.0 is registered for the param when the mixin is
    initialised.
    """

    # a placeholder to make it appear in the generated doc
    elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")

    def __init__(self):
        super(HasElasticNetParam, self).__init__()
        #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
        # Rebind the param with this instance as its parent, replacing the
        # class-level placeholder for this object.
        self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
        self._setDefault(elasticNetParam=0.0)

    def setElasticNetParam(self, value):
        """
        Sets the value of :py:attr:`elasticNetParam`.

        The value is stored as given (no range check is performed here).
        Returns this instance.
        """
        self._paramMap[self.elasticNetParam] = value
        return self

    def getElasticNetParam(self):
        """
        Gets the value of elasticNetParam or its default value.
        """
        return self.getOrDefault(self.elasticNetParam)
488
+
489
+
490
+
class HasFitIntercept(Params):
    """
    Mixin for param fitIntercept: whether to fit an intercept term.

    A default of True is registered for the param when the mixin is
    initialised.
    """

    # a placeholder to make it appear in the generated doc
    fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")

    def __init__(self):
        super(HasFitIntercept, self).__init__()
        #: param for whether to fit an intercept term.
        # Rebind the param with this instance as its parent, replacing the
        # class-level placeholder for this object.
        self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
        self._setDefault(fitIntercept=True)

    def setFitIntercept(self, value):
        """
        Sets the value of :py:attr:`fitIntercept`.

        The value is stored as given. Returns this instance.
        """
        self._paramMap[self.fitIntercept] = value
        return self

    def getFitIntercept(self):
        """
        Gets the value of fitIntercept or its default value.
        """
        return self.getOrDefault(self.fitIntercept)
516
+
517
+
518
+
class HasStandardization(Params):
    """
    Mixin for param standardization: whether to standardize the training features before fitting the model.

    A default of True is registered for the param when the mixin is
    initialised.
    """

    # a placeholder to make it appear in the generated doc
    standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.")

    def __init__(self):
        super(HasStandardization, self).__init__()
        #: param for whether to standardize the training features before fitting the model.
        # Rebind the param with this instance as its parent, replacing the
        # class-level placeholder for this object.
        self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.")
        self._setDefault(standardization=True)

    def setStandardization(self, value):
        """
        Sets the value of :py:attr:`standardization`.

        The value is stored as given. Returns this instance.
        """
        self._paramMap[self.standardization] = value
        return self

    def getStandardization(self):
        """
        Gets the value of standardization or its default value.
        """
        return self.getOrDefault(self.standardization)
544
+
545
+
546
+
classHasThresholds(Params):
547
+
"""
548
+
Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold..
549
+
"""
550
+
551
+
# a placeholder to make it appear in the generated doc
552
+
thresholds=Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
553
+
554
+
def__init__(self):
555
+
super(HasThresholds, self).__init__()
556
+
#: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
557
+
self.thresholds=Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
558
+
559
+
defsetThresholds(self, value):
560
+
"""
561
+
Sets the value of :py:attr:`thresholds`.
562
+
"""
563
+
self._paramMap[self.thresholds] =value
564
+
returnself
565
+
566
+
defgetThresholds(self):
567
+
"""
568
+
Gets the value of thresholds or its default value.
0 commit comments