@@ -38,6 +38,7 @@ class LinearRegressionSuite
   @transient var datasetWithSparseFeature: DataFrame = _
   @transient var datasetWithWeight: DataFrame = _
   @transient var datasetWithWeightConstantLabel: DataFrame = _
+  @transient var datasetWithWeightZeroLabel: DataFrame = _
 
   /*
      In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -109,6 +110,13 @@ class LinearRegressionSuite
         Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
         Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
       ), 2))
+    datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+      sc.parallelize(Seq(
+        Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+        Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
+        Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
+        Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
+      ), 2))
   }
 
   test("params") {
@@ -592,21 +600,31 @@ class LinearRegressionSuite
     Seq("auto", "l-bfgs", "normal").foreach { solver =>
       var idx = 0
       for (fitIntercept <- Seq(false, true)) {
-        val model = new LinearRegression()
+        val model1 = new LinearRegression()
           .setFitIntercept(fitIntercept)
           .setWeightCol("weight")
           .setSolver(solver)
           .fit(datasetWithWeightConstantLabel)
-        val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
-        assert(actual ~== expected(idx) absTol 1e-4)
+        val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
+          model1.coefficients(1))
+        assert(actual1 ~== expected(idx) absTol 1e-4)
+
+        val model2 = new LinearRegression()
+          .setFitIntercept(fitIntercept)
+          .setWeightCol("weight")
+          .setSolver(solver)
+          .fit(datasetWithWeightZeroLabel)
+        val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
+          model2.coefficients(1))
+        assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
         idx += 1
       }
     }
   }
 
   test("regularized linear regression through origin with constant label") {
-    // The problem is ill-defined if fitIntercept=false, regParam is non-zero and
-    // standardization=true. An exception is thrown in this case.
+    // The problem is ill-defined if fitIntercept=false and regParam is non-zero.
+    // An exception is thrown in this case.
     Seq("auto", "l-bfgs", "normal").foreach { solver =>
       for (standardization <- Seq(false, true)) {
         val model = new LinearRegression().setFitIntercept(false)
@@ -618,6 +636,33 @@ class LinearRegressionSuite
     }
   }
 
+  test("linear regression with l-bfgs when training is not needed") {
+    // When the label is constant, the l-bfgs solver returns results without training.
+    // There are two possibilities: if the label is non-zero but constant and
+    // fitIntercept is true, the model returns yMean as the intercept without training.
+    // If the label is all zeros, then all coefficients are zero regardless of
+    // fitIntercept, so no training is needed.
+    for (fitIntercept <- Seq(false, true)) {
+      for (standardization <- Seq(false, true)) {
+        val model1 = new LinearRegression()
+          .setFitIntercept(fitIntercept)
+          .setStandardization(standardization)
+          .setWeightCol("weight")
+          .setSolver("l-bfgs")
+          .fit(datasetWithWeightConstantLabel)
+        if (fitIntercept) {
+          assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+        }
+        val model2 = new LinearRegression()
+          .setFitIntercept(fitIntercept)
+          .setWeightCol("weight")
+          .setSolver("l-bfgs")
+          .fit(datasetWithWeightZeroLabel)
+        assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+      }
+    }
+  }
+
   test("linear regression model training summary") {
     Seq("auto", "l-bfgs", "normal").foreach { solver =>
       val trainer = new LinearRegression().setSolver(solver)
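
As a reference for what the new tests assert, here is a minimal sketch (not part of the patch), assuming a Spark 2.x-style SparkSession in scope as `spark` and the `org.apache.spark.ml.linalg` vector package; the suite itself builds its DataFrames through the older `sqlContext`/`sc.parallelize` API. Fitting on an all-zero label column is expected to produce a zero intercept and zero coefficients without any optimization iterations.

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

// Assumes an active SparkSession named `spark`.
import spark.implicits._

// Same data layout as datasetWithWeightZeroLabel: (label, weight, features).
val zeroLabelDF = Seq(
  (0.0, 1.0, Vectors.dense(0.0, 5.0)),
  (0.0, 2.0, Vectors.dense(1.0, 7.0)),
  (0.0, 3.0, Vectors.dense(2.0, 11.0)),
  (0.0, 4.0, Vectors.dense(3.0, 13.0))
).toDF("label", "weight", "features")

val model = new LinearRegression()
  .setWeightCol("weight")
  .setSolver("l-bfgs")
  .fit(zeroLabelDF)

// With an all-zero label the zero model is already optimal, so training
// terminates immediately: intercept 0.0, coefficients [0.0, 0.0].
println(model.intercept)
println(model.coefficients)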