Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
if ($(loss) == Huber) {
require($(solver)!= Normal, "LinearRegression with huber loss doesn't support " +
"normal solver, please change solver to auto or l-bfgs.")
require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " +
s"L2 regularization, but got elasticNetParam = $getElasticNetParam.")

if (fitting) {
if ($(loss) == Huber) {
require($(solver)!= Normal, "LinearRegression with huber loss doesn't support " +
"normal solver, please change solver to auto or l-bfgs.")
require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " +
s"L2 regularization, but got elasticNetParam = $getElasticNetParam.")
}
}
super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
assert(model.numFeatures === numFeatures)
}

test("linear regression: can transform data with LinearRegressionModel") {
withClue("training related params like loss are only validated during fitting phase") {
val original = new LinearRegression().fit(datasetWithDenseFeature)

val deserialized = new LinearRegressionModel(uid = original.uid,
coefficients = original.coefficients,
intercept = original.intercept)
val output = deserialized.transform(datasetWithDenseFeature)
assert(output.collect().size > 0) // simple assertion to ensure no exception thrown
}
}

test("linear regression: illegal params") {
withClue("LinearRegression with huber loss only supports L2 regularization") {
intercept[IllegalArgumentException] {
Expand Down