Skip to content

Commit dfcd014

Browse files
author
asarb
committed
add unit test for changes to change training related params like loss to only be validated during fitting phase
1 parent 5907806 commit dfcd014

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ private[regression] trait LinearRegressionParams extends PredictorParams
115115
"normal solver, please change solver to auto or l-bfgs.")
116116
require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " +
117117
s"L2 regularization, but got elasticNetParam = $getElasticNetParam.")
118-
119118
}
120119
}
121120
super.validateAndTransformSchema(schema, fitting, featuresDataType)

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,18 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
187187
assert(model.numFeatures === numFeatures)
188188
}
189189

190+
test("linear regression: can transform data with LinearRegressionModel") {
191+
withClue("training related params like loss are only validated during fitting phase") {
192+
val original = new LinearRegression().fit(datasetWithDenseFeature)
193+
194+
val deserialized = new LinearRegressionModel(uid = original.uid,
195+
coefficients = original.coefficients,
196+
intercept = original.intercept)
197+
val output = deserialized.transform(datasetWithDenseFeature)
198+
assert(output.collect().size > 0) // simple assertion to ensure no exception thrown
199+
}
200+
}
201+
190202
test("linear regression: illegal params") {
191203
withClue("LinearRegression with huber loss only supports L2 regularization") {
192204
intercept[IllegalArgumentException] {

0 commit comments

Comments
 (0)