diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index cd0b7d3ff8aa..09f3f94d346b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -109,12 +109,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if ($(loss) == Huber) { - require($(solver)!= Normal, "LinearRegression with huber loss doesn't support " + - "normal solver, please change solver to auto or l-bfgs.") - require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " + - s"L2 regularization, but got elasticNetParam = $getElasticNetParam.") - + if (fitting) { + if ($(loss) == Huber) { + require($(solver)!= Normal, "LinearRegression with huber loss doesn't support " + + "normal solver, please change solver to auto or l-bfgs.") + require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " + + s"L2 regularization, but got elasticNetParam = $getElasticNetParam.") + } } super.validateAndTransformSchema(schema, fitting, featuresDataType) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index d3df0e5b4448..82d984933d81 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -187,6 +187,18 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe assert(model.numFeatures === numFeatures) } + test("linear regression: can transform data with LinearRegressionModel") { + withClue("training related params like loss are only validated during fitting phase") { + val original = new LinearRegression().fit(datasetWithDenseFeature) + + val deserialized = new LinearRegressionModel(uid = original.uid, + coefficients = original.coefficients, + intercept = original.intercept) + val output = deserialized.transform(datasetWithDenseFeature) + assert(output.collect().size > 0) // simple assertion to ensure no exception thrown + } + } + test("linear regression: illegal params") { withClue("LinearRegression with huber loss only supports L2 regularization") { intercept[IllegalArgumentException] {