Skip to content

Commit 2480dc1

Browse files
committed
added test for the case when results are produced without training (when label is constant)
1 parent c0744d8 commit 2480dc1

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 50 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ class LinearRegressionSuite
3838
@transient var datasetWithSparseFeature: DataFrame = _
3939
@transient var datasetWithWeight: DataFrame = _
4040
@transient var datasetWithWeightConstantLabel: DataFrame = _
41+
@transient var datasetWithWeightZeroLabel: DataFrame = _
4142

4243
/*
4344
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -109,6 +110,13 @@ class LinearRegressionSuite
109110
Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
110111
Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
111112
), 2))
113+
datasetWithWeightZeroLabel = sqlContext.createDataFrame(
114+
sc.parallelize(Seq(
115+
Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
116+
Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
117+
Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
118+
Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
119+
), 2))
112120
}
113121

114122
test("params") {
@@ -592,21 +600,31 @@ class LinearRegressionSuite
592600
Seq("auto", "l-bfgs", "normal").foreach { solver =>
593601
var idx = 0
594602
for (fitIntercept <- Seq(false, true)) {
595-
val model = new LinearRegression()
603+
val model1 = new LinearRegression()
596604
.setFitIntercept(fitIntercept)
597605
.setWeightCol("weight")
598606
.setSolver(solver)
599607
.fit(datasetWithWeightConstantLabel)
600-
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
601-
assert(actual ~== expected(idx) absTol 1e-4)
608+
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
609+
model1.coefficients(1))
610+
assert(actual1 ~== expected(idx) absTol 1e-4)
611+
612+
val model2 = new LinearRegression()
613+
.setFitIntercept(fitIntercept)
614+
.setWeightCol("weight")
615+
.setSolver(solver)
616+
.fit(datasetWithWeightZeroLabel)
617+
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
618+
model2.coefficients(1))
619+
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
602620
idx += 1
603621
}
604622
}
605623
}
606624

607625
test("regularized linear regression through origin with constant label") {
608-
// The problem is ill-defined if fitIntercept=false, regParam is non-zero and
609-
// standardization=true. An exception is thrown in this case.
626+
// The problem is ill-defined if fitIntercept=false and regParam is non-zero.
627+
// An exception is thrown in this case.
610628
Seq("auto", "l-bfgs", "normal").foreach { solver =>
611629
for (standardization <- Seq(false, true)) {
612630
val model = new LinearRegression().setFitIntercept(false)
@@ -618,6 +636,33 @@ class LinearRegressionSuite
618636
}
619637
}
620638

639+
test("linear regression with l-bfgs when training is not needed") {
640+
// When label is constant, l-bfgs solver returns results without training.
641+
// There are two possibilities: If the label is non-zero but constant,
642+
// and fitIntercept is true, then the model returns yMean as intercept without training.
643+
// If label is all zeros, then all coefficients are zero regardless of fitIntercept, so
644+
// no training is needed.
645+
for (fitIntercept <- Seq(false, true)) {
646+
for (standardization <- Seq(false, true)) {
647+
val model1 = new LinearRegression()
648+
.setFitIntercept(fitIntercept)
649+
.setStandardization(standardization)
650+
.setWeightCol("weight")
651+
.setSolver("l-bfgs")
652+
.fit(datasetWithWeightConstantLabel)
653+
if (fitIntercept) {
654+
assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
655+
}
656+
val model2 = new LinearRegression()
657+
.setFitIntercept(fitIntercept)
658+
.setWeightCol("weight")
659+
.setSolver("l-bfgs")
660+
.fit(datasetWithWeightZeroLabel)
661+
assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
662+
}
663+
}
664+
}
665+
621666
test("linear regression model training summary") {
622667
Seq("auto", "l-bfgs", "normal").foreach { solver =>
623668
val trainer = new LinearRegression().setSolver(solver)

0 commit comments

Comments
 (0)