[SPARK-12230][ML] WeightedLeastSquares.fit() should handle division by zero properly if standard deviation of target variable is zero. #10274
```diff
@@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares(
     val aaBar = summary.aaBar
     val aaValues = aaBar.values

+    if (bStd == 0) {
+      if (fitIntercept) {
+        logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
+          s"zeros and the intercept will be the mean of the label; as a result, " +
+          s"training is not needed.")
+        val coefficients = new DenseVector(Array.ofDim(k-1))
+        val intercept = bBar
+        val diagInvAtWA = new DenseVector(Array(0D))
+        return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA)
+      } else {
+        require(!(regParam > 0.0 && standardizeLabel),
+          "The standard deviation of the label is zero. " +
+            "Model cannot be regularized with standardization=true")
+        logWarning(s"The standard deviation of the label is zero. " +
+          "Consider setting fitIntercept=true.")
+      }
+    }
+
     // add regularization to diagonals
     var i = 0
     var j = 2
```
|
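For intuition on the short-circuit branch above, here is a small standalone sketch (a hypothetical helper, not Spark code; the function name and data are made up for illustration). With ordinary least squares on a single feature, a label whose standard deviation is zero gives a zero slope and an intercept equal to the label mean, which is exactly what the patch returns without running the solver:

```scala
// Standalone illustration, not part of Spark: closed-form simple OLS.
// When every label equals the same constant, cov(x, y) is zero, so the
// slope is zero and the intercept is the label mean -- the degenerate
// case the patch handles without training.
def simpleOls(x: Array[Double], y: Array[Double]): (Double, Double) = {
  val n = x.length.toDouble
  val xBar = x.sum / n
  val yBar = y.sum / n
  val cov = x.zip(y).map { case (xi, yi) => (xi - xBar) * (yi - yBar) }.sum
  val varX = x.map(xi => (xi - xBar) * (xi - xBar)).sum
  val slope = cov / varX
  (slope, yBar - slope * xBar)
}

val (slope, intercept) =
  simpleOls(Array(1.0, 2.0, 3.0, 4.0), Array(5.0, 5.0, 5.0, 5.0))
println(s"slope=$slope intercept=$intercept") // prints slope=0.0 intercept=5.0
```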
```diff
@@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares(
       if (standardizeFeatures) {
         lambda *= aVar(j - 2)
       }
-      if (standardizeLabel) {
-        // TODO: handle the case when bStd = 0
+      if (standardizeLabel && bStd != 0) {
```
|
Member

Can you check when
Contributor
Author
@dbtsai The problem here is that for regularized regression in R, I need to use

Note that in this example, I expect the same results from both

Right now
Member
Thanks. As you said, we expect non-zero coefficients in this case, so we don't have to match glmnet. However, we may want to throw an exception when standardizeLabel is true and yStd is zero, since the problem is not well defined. Thanks.
Contributor
Author
The

We can throw an exception when

The option could be to simply log a warning when we don't standardize the label here. Let me know what you think.
Member
It's interesting to see that when regularization is zero, standardization (or not) of labels and features will not change the solution of Linear Regression, which you can verify by experiment. As a result, the only case where the model will be non-interpretable is
Member

Here is the exercise:

```scala
test("WLS against lm") {
  /*
     R code:

     df <- as.data.frame(cbind(A, b))
     for (formula in c(b ~ . -1, b ~ .)) {
       model <- lm(formula, data=df, weights=w)
       print(as.vector(coef(model)))
     }

     [1] -3.727121  3.009983
     [1] 18.08  6.08 -0.60
   */
  val expected = Seq(
    Vectors.dense(0.0, -3.727121, 3.009983),
    Vectors.dense(18.08, 6.08, -0.60))
  var idx = 0
  for (fitIntercept <- Seq(false, true)) {
    for (standardization <- Seq(false, true)) {
      val wls = new WeightedLeastSquares(
        fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
        standardizeLabel = standardization).fit(instances)
      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
      assert(actual ~== expected(idx) absTol 1e-4)
    }
    idx += 1
  }
}
```
```diff
         lambda /= bStd
       }
       aaValues(i) += lambda
```
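As a side note on the guarded line, here is a minimal standalone sketch (the function name is made up; the parameter names mirror the diff, but this is an illustration, not the Spark implementation) of how the effective per-coordinate penalty is computed, and why the `bStd != 0` guard matters: without it, a zero label standard deviation would divide lambda by zero and push Infinity onto the normal-equation diagonal.

```scala
// Hypothetical sketch of the diagonal-penalty scaling from the diff above.
def effectiveLambda(regParam: Double, aVarJ: Double, bStd: Double,
                    standardizeFeatures: Boolean,
                    standardizeLabel: Boolean): Double = {
  var lambda = regParam
  if (standardizeFeatures) lambda *= aVarJ          // scale by feature variance
  if (standardizeLabel && bStd != 0) lambda /= bStd // patched guard
  lambda
}

println(effectiveLambda(1.0, 4.0, 2.0,
  standardizeFeatures = true, standardizeLabel = true)) // prints 2.0
println(effectiveLambda(1.0, 4.0, 0.0,
  standardizeFeatures = true, standardizeLabel = true)) // prints 4.0 (guard skips division)
```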
The `LinearRegression` has a bug related to this:
https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala#L226
When fitIntercept is false, the code should still train the model. Can you fix it in either a separate PR or here?
Thanks.
Let's fix it in a separate PR to make things easier.
I did notice that bug. I was planning to create a separate JIRA for this.
@dbtsai I just created a PR for this bug with a separate JIRA (SPARK-12732).