@@ -219,33 +219,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
}

val yMean = ySummarizer.mean(0)
val yStd = math.sqrt(ySummarizer.variance(0))

// If the yStd is zero, then the intercept is yMean with zero coefficient;
// as a result, training is not needed.
if (yStd == 0.0) {
logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
s"zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
if (handlePersistence) instances.unpersist()
val coefficients = Vectors.sparse(numFeatures, Seq())
val intercept = yMean

val model = new LinearRegressionModel(uid, coefficients, intercept)
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()

val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
model,
Array(0D),
$(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
val rawYStd = math.sqrt(ySummarizer.variance(0))
if (rawYStd == 0.0) {
if ($(fitIntercept) || yMean==0.0) {
// If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with
// zero coefficient; as a result, training is not needed.
// Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of
// the fitIntercept
if (yMean == 0.0) {
logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
s"and the intercept will all be zero; as a result, training is not needed.")
} else {
logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
s"zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
}
if (handlePersistence) instances.unpersist()
val coefficients = Vectors.sparse(numFeatures, Seq())
val intercept = yMean

val model = new LinearRegressionModel(uid, coefficients, intercept)
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()

val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
model,
Array(0D),
$(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")
logWarning(s"The standard deviation of the label is zero. " +
"Consider setting fitIntercept=true.")
}
}

// if y is constant (rawYStd is zero), then y cannot be scaled. In this case
// setting yStd=abs(yMean) ensures that y is not scaled anymore in l-bfgs algorithm.
val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
val featuresMean = featuresSummarizer.mean.toArray
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)

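For orientation, here is a minimal usage sketch of the behavior this hunk introduces. The DataFrame `df` is hypothetical (a constant label of 17.0 with arbitrary features), analogous to the `datasetWithWeightConstantLabel` fixture added to the test suite below; the printed values follow from the new constant-label shortcut.

import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame

// df: hypothetical DataFrame with a "features" column and a constant "label" column (all 17.0).
def fitConstantLabel(df: DataFrame): Unit = {
  val model = new LinearRegression()
    .setFitIntercept(true)
    .fit(df)

  // With rawYStd == 0 and fitIntercept == true, training is skipped:
  // the intercept is the label mean and every coefficient is zero.
  println(model.intercept)                          // 17.0
  println(model.coefficients)                       // sparse vector of zeros
  println(model.summary.objectiveHistory.mkString)  // single 0.0 entry
}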
@@ -37,6 +37,8 @@ class LinearRegressionSuite
@transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
@transient var datasetWithSparseFeature: DataFrame = _
@transient var datasetWithWeight: DataFrame = _
@transient var datasetWithWeightConstantLabel: DataFrame = _
@transient var datasetWithWeightZeroLabel: DataFrame = _

/*
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -92,6 +94,29 @@ class LinearRegressionSuite
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))

/*
R code:

A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
b.const <- c(17, 17, 17, 17)
w <- c(1, 2, 3, 4)
df.const.label <- as.data.frame(cbind(A, b.const))
*/
datasetWithWeightConstantLabel = sqlContext.createDataFrame(
sc.parallelize(Seq(
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
datasetWithWeightZeroLabel = sqlContext.createDataFrame(
sc.parallelize(Seq(
Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
}

test("params") {
@@ -558,6 +583,86 @@ class LinearRegressionSuite
}
}

test("linear regression model with constant label") {
/*
R code:
for (formula in c(b.const ~ . -1, b.const ~ .)) {
model <- lm(formula, data=df.const.label, weights=w)
print(as.vector(coef(model)))
}
[1] -9.221298 3.394343
[1] 17 0 0
*/
val expected = Seq(
Vectors.dense(0.0, -9.221298, 3.394343),
Vectors.dense(17.0, 0.0, 0.0))

Seq("auto", "l-bfgs", "normal").foreach { solver =>
var idx = 0
for (fitIntercept <- Seq(false, true)) {
val model1 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
.setSolver(solver)
.fit(datasetWithWeightConstantLabel)
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
model1.coefficients(1))
assert(actual1 ~== expected(idx) absTol 1e-4)

val model2 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
.setSolver(solver)
.fit(datasetWithWeightZeroLabel)
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
model2.coefficients(1))
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
idx += 1
Member:
When fitIntercept = true, check that the size of the loss history is zero (since the solution is returned without any optimization).
It would also be nice to add one small test where labelStd = 0 and labelMean = 0 and fitIntercept = false.

Contributor Author:
I'm not sure how to check the size of the loss history. Could you please point me to an example?

Member:
In LinearRegressionTrainingSummary, you can get it from objectiveHistory.

Contributor Author:
I've added tests as suggested.

}
}
}
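Following up on the review thread above: the loss history is available as LinearRegressionTrainingSummary.objectiveHistory. Below is a minimal sketch of the check, assuming the datasetWithWeightConstantLabel fixture defined in the suite setup; when fitIntercept = true and the label is constant, the solver returns immediately and the summary carries a single zero objective value.

val model = new LinearRegression()
  .setFitIntercept(true)
  .setWeightCol("weight")
  .setSolver("l-bfgs")
  .fit(datasetWithWeightConstantLabel)

assert(model.summary.objectiveHistory.length == 1)
assert(model.summary.objectiveHistory(0) == 0.0)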

test("regularized linear regression through origin with constant label") {
// The problem is ill-defined if fitIntercept=false and regParam is non-zero.
// An exception is thrown in this case.
Seq("auto", "l-bfgs", "normal").foreach { solver =>
for (standardization <- Seq(false, true)) {
val model = new LinearRegression().setFitIntercept(false)
.setRegParam(0.1).setStandardization(standardization).setSolver(solver)
intercept[IllegalArgumentException] {
model.fit(datasetWithWeightConstantLabel)
}
}
}
}
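A minimal sketch of the failure mode this test exercises, again assuming the datasetWithWeightConstantLabel fixture: with fitIntercept = false, a constant non-zero label, and a non-zero regParam, the require added in LinearRegression.train rejects the fit.

val lr = new LinearRegression()
  .setFitIntercept(false)
  .setRegParam(0.1)
  .setSolver("l-bfgs")

try {
  lr.fit(datasetWithWeightConstantLabel)
} catch {
  case e: IllegalArgumentException =>
    // Message comes from the new require(...):
    // "The standard deviation of the label is zero. Model cannot be regularized."
    println(e.getMessage)
}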

test("linear regression with l-bfgs when training is not needed") {
// When the label is constant, the l-bfgs solver returns results without training.
// There are two possibilities: if the label is non-zero but constant
// and fitIntercept is true, then the model returns yMean as the intercept without training.
// If the label is all zeros, then all coefficients are zero regardless of fitIntercept, so
// no training is needed.
for (fitIntercept <- Seq(false, true)) {
for (standardization <- Seq(false, true)) {
val model1 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setStandardization(standardization)
.setWeightCol("weight")
.setSolver("l-bfgs")
.fit(datasetWithWeightConstantLabel)
if (fitIntercept) {
assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
}
val model2 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
.setSolver("l-bfgs")
.fit(datasetWithWeightZeroLabel)
assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
}
}
}

test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)