From 5907806aacde016b0be1fd0f3b39e4f2d2306ccd Mon Sep 17 00:00:00 2001 From: asarb Date: Wed, 1 May 2019 20:04:55 +0100 Subject: [PATCH 1/2] only check training specific params when validateAndTransformSchema is called for training, ignore them during scoring --- .../spark/ml/regression/LinearRegression.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index cd0b7d3ff8aa..cbebd9c7589b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -109,12 +109,14 @@ private[regression] trait LinearRegressionParams extends PredictorParams schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if ($(loss) == Huber) { - require($(solver)!= Normal, "LinearRegression with huber loss doesn't support " + - "normal solver, please change solver to auto or l-bfgs.") - require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " + - s"L2 regularization, but got elasticNetParam = $getElasticNetParam.") + if (fitting) { + if ($(loss) == Huber) { + require($(solver)!= Normal, "LinearRegression with huber loss doesn't support " + + "normal solver, please change solver to auto or l-bfgs.") + require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " + + s"L2 regularization, but got elasticNetParam = $getElasticNetParam.") + } } super.validateAndTransformSchema(schema, fitting, featuresDataType) } From dfcd014c52a243eb8c2d2a357b1bf18064e4b11e Mon Sep 17 00:00:00 2001 From: asarb Date: Wed, 1 May 2019 21:10:58 +0100 Subject: [PATCH 2/2] add unit test for changes to change training related params like loss to only be validated during fitting phase --- .../spark/ml/regression/LinearRegression.scala | 1 - .../spark/ml/regression/LinearRegressionSuite.scala | 12 ++++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index cbebd9c7589b..09f3f94d346b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -115,7 +115,6 @@ private[regression] trait LinearRegressionParams extends PredictorParams "normal solver, please change solver to auto or l-bfgs.") require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " + s"L2 regularization, but got elasticNetParam = $getElasticNetParam.") - } } super.validateAndTransformSchema(schema, fitting, featuresDataType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index d3df0e5b4448..82d984933d81 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -187,6 +187,18 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe assert(model.numFeatures === numFeatures) } + test("linear regression: can transform data with LinearRegressionModel") { + withClue("training related params like loss are only validated during fitting phase") { + val original = new LinearRegression().fit(datasetWithDenseFeature) + + val deserialized = new LinearRegressionModel(uid = original.uid, + coefficients = original.coefficients, + intercept = original.intercept) + val output = deserialized.transform(datasetWithDenseFeature) + assert(output.collect().size > 0) // simple assertion to ensure no exception thrown + } + } + test("linear regression: illegal params") { withClue("LinearRegression with huber loss only supports L2 regularization") { intercept[IllegalArgumentException] {