[SPARK-15339] [ML] ML 2.0 QA: Scala APIs and code audit for regression #13129
Changes from all commits
254313c
645f6c4
374e610
d38b1eb
1fbd1dc
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -159,9 +159,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| override protected def train(dataset: Dataset[_]): LinearRegressionModel = { | ||
| // Extract the number of features before deciding optimization solver. | ||
| - val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { | ||
| - case Row(features: Vector) => features.size | ||
| - }.first() | ||
| + val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size | ||
Contributor
same here, can we do
| val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) | ||
| if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && | ||
| @@ -240,7 +238,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| val coefficients = Vectors.sparse(numFeatures, Seq()) | ||
| val intercept = yMean | ||
| - val model = new LinearRegressionModel(uid, coefficients, intercept) | ||
| + val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) | ||
| // Handle possible missing or invalid prediction columns | ||
| val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() | ||
| @@ -252,7 +250,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| model, | ||
| Array(0D), | ||
| Array(0D)) | ||
| - return copyValues(model.setSummary(trainingSummary)) | ||
| + return model.setSummary(trainingSummary) | ||
Contributor
Author
This is a minor bug of
Contributor
So test cases didn't pick this up? We should look into why and amend the tests accordingly.
Contributor
Author
Yes, this is because we don't have thorough test coverage ...
Contributor
Author
@MLnick I added a test case for this scenario and updated the other test cases to ensure the prediction column (and other params) are copied correctly in all situations.
Member
We also need to setParent
Contributor
Author
@jkbradley It is not necessary to
| } else { | ||
| require($(regParam) == 0.0, "The standard deviation of the label is zero. " + | ||
| "Model cannot be regularized.") | ||
can we not do dataset.select(col($(featuresCol))).as[Vector].first().size?
You mean dataset.select(col($(featuresCol))).as[Vector].first().size? I think it's OK.
It looks like Spark does not provide an encoder for Vector. If I change it to use as[Vector], the compiler will complain:
Ah right, we would need to add an implicit encoder, implicit def encoder: Encoder[Vector] = ExpressionEncoder(), e.g. see here #12718 (comment). However, let's leave that change for #12718.
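For readers following the copyValues reordering in the diff above, the sketch below may help show why the order matters. It is a toy model in plain Scala, not Spark's code: ToyModel, ToyEstimator, fitCopyLate and fitCopyEarly are hypothetical names. The assumption, inferred from the thread's mention of the prediction column, is that the summary logic reads the model's params, so copying them only at the end lets the summary see defaults instead of the user's settings.

```scala
// Toy stand-in for a fitted model; NOT Spark's LinearRegressionModel.
final case class ToyModel(intercept: Double, predictionCol: String = "prediction") {
  // Loosely mimics findSummaryModelAndPredictionCol: reads the model's current param.
  def summaryPredictionCol: String = predictionCol
}

// Toy stand-in for an estimator that carries a user-set param.
final case class ToyEstimator(predictionCol: String) {
  // Loosely mimics copyValues: transfers the estimator's params onto the model.
  def copyValues(m: ToyModel): ToyModel = m.copy(predictionCol = this.predictionCol)

  // Old order: the summary column is chosen before params are copied.
  def fitCopyLate(yMean: Double): (ToyModel, String) = {
    val model = ToyModel(yMean)
    val summaryCol = model.summaryPredictionCol
    (copyValues(model), summaryCol)
  }

  // New order: params are copied first, so the summary sees the user's value.
  def fitCopyEarly(yMean: Double): (ToyModel, String) = {
    val model = copyValues(ToyModel(yMean))
    (model, model.summaryPredictionCol)
  }
}
```

With ToyEstimator("myPrediction"), fitCopyLate returns a summary column of "prediction" while fitCopyEarly returns "myPrediction". Under the stated assumption, this is the kind of mismatch the added test case is meant to catch.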