Skip to content

Commit c94b34e

Browse files
yanboliang authored and mengxr committed
[SPARK-15339][ML] ML 2.0 QA: Scala APIs and code audit for regression
## What changes were proposed in this pull request?

* ```GeneralizedLinearRegression``` API docs enhancement.
* The default value of ```GeneralizedLinearRegression``` ```linkPredictionCol``` is now left unset rather than set to an empty string. This is consistent with other similar params such as ```weightCol```.
* Make some methods more private.
* Fix a minor bug in LinearRegression.
* Fix some other issues.

## How was this patch tested?

Existing tests.

Author: Yanbo Liang <[email protected]>

Closes #13129 from yanboliang/spark-15339.
1 parent 5e20350 commit c94b34e

File tree

5 files changed

+58
-47
lines changed

5 files changed

+58
-47
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
8989
def getQuantilesCol: String = $(quantilesCol)
9090

9191
/** Checks whether the input has quantiles column name. */
92-
protected[regression] def hasQuantilesCol: Boolean = {
93-
isDefined(quantilesCol) && $(quantilesCol) != ""
92+
private[regression] def hasQuantilesCol: Boolean = {
93+
isDefined(quantilesCol) && $(quantilesCol).nonEmpty
9494
}
9595

9696
/**

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 38 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
4343
with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
4444
with HasSolver with Logging {
4545

46+
import GeneralizedLinearRegression._
47+
4648
/**
4749
* Param for the name of family which is a description of the error distribution
4850
* to be used in the model.
@@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
5456
@Since("2.0.0")
5557
final val family: Param[String] = new Param(this, "family",
5658
"The name of family which is a description of the error distribution to be used in the " +
57-
"model. Supported options: gaussian(default), binomial, poisson and gamma.",
58-
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray))
59+
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
60+
ParamValidators.inArray[String](supportedFamilyNames.toArray))
5961

6062
/** @group getParam */
6163
@Since("2.0.0")
@@ -71,29 +73,32 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
7173
@Since("2.0.0")
7274
final val link: Param[String] = new Param(this, "link", "The name of link function " +
7375
"which provides the relationship between the linear predictor and the mean of the " +
74-
"distribution function. Supported options: identity, log, inverse, logit, probit, " +
75-
"cloglog and sqrt.",
76-
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray))
76+
s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
77+
ParamValidators.inArray[String](supportedLinkNames.toArray))
7778

7879
/** @group getParam */
7980
@Since("2.0.0")
8081
def getLink: String = $(link)
8182

8283
/**
8384
* Param for link prediction (linear predictor) column name.
84-
* Default is empty, which means we do not output link prediction.
85+
* Default is not set, which means we do not output link prediction.
8586
*
8687
* @group param
8788
*/
8889
@Since("2.0.0")
8990
final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
9091
"link prediction (linear predictor) column name")
91-
setDefault(linkPredictionCol, "")
9292

9393
/** @group getParam */
9494
@Since("2.0.0")
9595
def getLinkPredictionCol: String = $(linkPredictionCol)
9696

97+
/** Checks whether we should output link prediction. */
98+
private[regression] def hasLinkPredictionCol: Boolean = {
99+
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
100+
}
101+
97102
import GeneralizedLinearRegression._
98103

99104
@Since("2.0.0")
@@ -107,7 +112,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
107112
s"with ${$(family)} family does not support ${$(link)} link function.")
108113
}
109114
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
110-
if ($(linkPredictionCol).nonEmpty) {
115+
if (hasLinkPredictionCol) {
111116
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
112117
} else {
113118
newSchema
@@ -205,7 +210,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
205210
/**
206211
* Sets the value of param [[weightCol]].
207212
* If this is not set or empty, we treat all instance weights as 1.0.
208-
* Default is empty, so all instances have weight one.
213+
* Default is not set, so all instances have weight one.
209214
*
210215
* @group setParam
211216
*/
@@ -214,7 +219,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
214219

215220
/**
216221
* Sets the solver algorithm used for optimization.
217-
* Currently only support "irls" which is also the default solver.
222+
* Currently only supports "irls" which is also the default solver.
218223
*
219224
* @group setParam
220225
*/
@@ -239,10 +244,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
239244
}
240245
val familyAndLink = new FamilyAndLink(familyObj, linkObj)
241246

242-
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd
243-
.map { case Row(features: Vector) =>
244-
features.size
245-
}.first()
247+
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
246248
if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
247249
val msg = "Currently, GeneralizedLinearRegression only supports number of features" +
248250
s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset."
@@ -294,25 +296,25 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
294296
override def load(path: String): GeneralizedLinearRegression = super.load(path)
295297

296298
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
297-
private[ml] lazy val supportedFamilyAndLinkPairs = Set(
299+
private[regression] lazy val supportedFamilyAndLinkPairs = Set(
298300
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
299301
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
300302
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
301303
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
302304
)
303305

304306
/** Set of family names that GeneralizedLinearRegression supports. */
305-
private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
307+
private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
306308

307309
/** Set of link names that GeneralizedLinearRegression supports. */
308-
private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
310+
private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
309311

310-
private[ml] val epsilon: Double = 1E-16
312+
private[regression] val epsilon: Double = 1E-16
311313

312314
/**
313315
* Wrapper of family and link combination used in the model.
314316
*/
315-
private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
317+
private[regression] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
316318

317319
/** Linear predictor based on given mu. */
318320
def predict(mu: Double): Double = link.link(family.project(mu))
@@ -359,7 +361,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
359361
*
360362
* @param name the name of the family.
361363
*/
362-
private[ml] abstract class Family(val name: String) extends Serializable {
364+
private[regression] abstract class Family(val name: String) extends Serializable {
363365

364366
/** The default link instance of this family. */
365367
val defaultLink: Link
@@ -391,7 +393,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
391393
def project(mu: Double): Double = mu
392394
}
393395

394-
private[ml] object Family {
396+
private[regression] object Family {
395397

396398
/**
397399
* Gets the [[Family]] object from its name.
@@ -412,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
412414
* Gaussian exponential family distribution.
413415
* The default link for the Gaussian family is the identity link.
414416
*/
415-
private[ml] object Gaussian extends Family("gaussian") {
417+
private[regression] object Gaussian extends Family("gaussian") {
416418

417419
val defaultLink: Link = Identity
418420

@@ -448,7 +450,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
448450
* Binomial exponential family distribution.
449451
* The default link for the Binomial family is the logit link.
450452
*/
451-
private[ml] object Binomial extends Family("binomial") {
453+
private[regression] object Binomial extends Family("binomial") {
452454

453455
val defaultLink: Link = Logit
454456

@@ -492,7 +494,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
492494
* Poisson exponential family distribution.
493495
* The default link for the Poisson family is the log link.
494496
*/
495-
private[ml] object Poisson extends Family("poisson") {
497+
private[regression] object Poisson extends Family("poisson") {
496498

497499
val defaultLink: Link = Log
498500

@@ -533,7 +535,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
533535
* Gamma exponential family distribution.
534536
* The default link for the Gamma family is the inverse link.
535537
*/
536-
private[ml] object Gamma extends Family("gamma") {
538+
private[regression] object Gamma extends Family("gamma") {
537539

538540
val defaultLink: Link = Inverse
539541

@@ -578,7 +580,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
578580
*
579581
* @param name the name of link function.
580582
*/
581-
private[ml] abstract class Link(val name: String) extends Serializable {
583+
private[regression] abstract class Link(val name: String) extends Serializable {
582584

583585
/** The link function. */
584586
def link(mu: Double): Double
@@ -590,7 +592,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
590592
def unlink(eta: Double): Double
591593
}
592594

593-
private[ml] object Link {
595+
private[regression] object Link {
594596

595597
/**
596598
* Gets the [[Link]] object from its name.
@@ -611,7 +613,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
611613
}
612614
}
613615

614-
private[ml] object Identity extends Link("identity") {
616+
private[regression] object Identity extends Link("identity") {
615617

616618
override def link(mu: Double): Double = mu
617619

@@ -620,7 +622,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
620622
override def unlink(eta: Double): Double = eta
621623
}
622624

623-
private[ml] object Logit extends Link("logit") {
625+
private[regression] object Logit extends Link("logit") {
624626

625627
override def link(mu: Double): Double = math.log(mu / (1.0 - mu))
626628

@@ -629,7 +631,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
629631
override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
630632
}
631633

632-
private[ml] object Log extends Link("log") {
634+
private[regression] object Log extends Link("log") {
633635

634636
override def link(mu: Double): Double = math.log(mu)
635637

@@ -638,7 +640,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
638640
override def unlink(eta: Double): Double = math.exp(eta)
639641
}
640642

641-
private[ml] object Inverse extends Link("inverse") {
643+
private[regression] object Inverse extends Link("inverse") {
642644

643645
override def link(mu: Double): Double = 1.0 / mu
644646

@@ -647,7 +649,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
647649
override def unlink(eta: Double): Double = 1.0 / eta
648650
}
649651

650-
private[ml] object Probit extends Link("probit") {
652+
private[regression] object Probit extends Link("probit") {
651653

652654
override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu)
653655

@@ -658,7 +660,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
658660
override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta)
659661
}
660662

661-
private[ml] object CLogLog extends Link("cloglog") {
663+
private[regression] object CLogLog extends Link("cloglog") {
662664

663665
override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu))
664666

@@ -667,7 +669,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
667669
override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta))
668670
}
669671

670-
private[ml] object Sqrt extends Link("sqrt") {
672+
private[regression] object Sqrt extends Link("sqrt") {
671673

672674
override def link(mu: Double): Double = math.sqrt(mu)
673675

@@ -732,7 +734,7 @@ class GeneralizedLinearRegressionModel private[ml] (
732734
if ($(predictionCol).nonEmpty) {
733735
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
734736
}
735-
if ($(linkPredictionCol).nonEmpty) {
737+
if (hasLinkPredictionCol) {
736738
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
737739
}
738740
output.toDF()
@@ -860,7 +862,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
860862
*/
861863
@Since("2.0.0")
862864
val predictionCol: String = {
863-
if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") {
865+
if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty) {
864866
origModel.getPredictionCol
865867
} else {
866868
"prediction_" + java.util.UUID.randomUUID.toString

mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
6969
setDefault(isotonic -> true, featureIndex -> 0)
7070

7171
/** Checks whether the input has weight column. */
72-
protected[ml] def hasWeightCol: Boolean = {
73-
isDefined(weightCol) && $(weightCol) != ""
72+
private[regression] def hasWeightCol: Boolean = {
73+
isDefined(weightCol) && $(weightCol).nonEmpty
7474
}
7575

7676
/**

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -161,9 +161,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
161161

162162
override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
163163
// Extract the number of features before deciding optimization solver.
164-
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
165-
case Row(features: Vector) => features.size
166-
}.first()
164+
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
167165
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
168166

169167
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
@@ -242,7 +240,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
242240
val coefficients = Vectors.sparse(numFeatures, Seq())
243241
val intercept = yMean
244242

245-
val model = new LinearRegressionModel(uid, coefficients, intercept)
243+
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
246244
// Handle possible missing or invalid prediction columns
247245
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
248246

@@ -254,7 +252,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
254252
model,
255253
Array(0D),
256254
Array(0D))
257-
return copyValues(model.setSummary(trainingSummary))
255+
return model.setSummary(trainingSummary)
258256
} else {
259257
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
260258
"Model cannot be regularized.")

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -610,20 +610,31 @@ class LinearRegressionSuite
610610
val model1 = new LinearRegression()
611611
.setFitIntercept(fitIntercept)
612612
.setWeightCol("weight")
613+
.setPredictionCol("myPrediction")
613614
.setSolver(solver)
614615
.fit(datasetWithWeightConstantLabel)
615616
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
616617
model1.coefficients(1))
617618
assert(actual1 ~== expected(idx) absTol 1e-4)
618619

620+
// Schema of summary.predictions should be a superset of the input dataset
621+
assert((datasetWithWeightConstantLabel.schema.fieldNames.toSet + model1.getPredictionCol)
622+
.subsetOf(model1.summary.predictions.schema.fieldNames.toSet))
623+
619624
val model2 = new LinearRegression()
620625
.setFitIntercept(fitIntercept)
621626
.setWeightCol("weight")
627+
.setPredictionCol("myPrediction")
622628
.setSolver(solver)
623629
.fit(datasetWithWeightZeroLabel)
624630
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
625631
model2.coefficients(1))
626632
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
633+
634+
// Schema of summary.predictions should be a superset of the input dataset
635+
assert((datasetWithWeightZeroLabel.schema.fieldNames.toSet + model2.getPredictionCol)
636+
.subsetOf(model2.summary.predictions.schema.fieldNames.toSet))
637+
627638
idx += 1
628639
}
629640
}
@@ -672,7 +683,7 @@ class LinearRegressionSuite
672683

673684
test("linear regression model training summary") {
674685
Seq("auto", "l-bfgs", "normal").foreach { solver =>
675-
val trainer = new LinearRegression().setSolver(solver)
686+
val trainer = new LinearRegression().setSolver(solver).setPredictionCol("myPrediction")
676687
val model = trainer.fit(datasetWithDenseFeature)
677688
val trainerNoPredictionCol = trainer.setPredictionCol("")
678689
val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature)
@@ -682,7 +693,7 @@ class LinearRegressionSuite
682693
assert(modelNoPredictionCol.hasSummary)
683694

684695
// Schema should be a superset of the input dataset
685-
assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf(
696+
assert((datasetWithDenseFeature.schema.fieldNames.toSet + model.getPredictionCol).subsetOf(
686697
model.summary.predictions.schema.fieldNames.toSet))
687698
// Validate that we re-insert a prediction column for evaluation
688699
val modelNoPredictionColFieldNames

0 commit comments

Comments
 (0)