Skip to content

Commit 99c6893

Browse files
yanboliangMarcelo Vanzin
authored andcommitted
[SPARK-10470] [ML] ml.IsotonicRegressionModel.copy should set parent
Copied model must have the same parent, but ml.IsotonicRegressionModel.copy did not set parent. Here fix it and add test case. Author: Yanbo Liang <[email protected]> Closes apache#8637 from yanboliang/spark-10470. (cherry picked from commit f7b55db) Signed-off-by: Xiangrui Meng <[email protected]> (cherry picked from commit 34d417e)
1 parent 969d6ca commit 99c6893

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ class IsotonicRegressionModel private[ml] (
202202
def predictions: Vector = Vectors.dense(oldModel.predictions)
203203

204204
override def copy(extra: ParamMap): IsotonicRegressionModel = {
205-
copyValues(new IsotonicRegressionModel(uid, oldModel), extra)
205+
copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
206206
}
207207

208208
override def transform(dataset: DataFrame): DataFrame = {

mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.ml.param.ParamsSuite
22+
import org.apache.spark.ml.util.MLTestingUtils
2223
import org.apache.spark.mllib.linalg.Vectors
2324
import org.apache.spark.mllib.util.MLlibTestSparkContext
2425
import org.apache.spark.sql.{DataFrame, Row}
@@ -89,6 +90,10 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
8990
assert(ir.getFeatureIndex === 0)
9091

9192
val model = ir.fit(dataset)
93+
94+
// copied model must have the same parent.
95+
MLTestingUtils.checkCopy(model)
96+
9297
model.transform(dataset)
9398
.select("label", "features", "prediction", "weight")
9499
.collect()

0 commit comments

Comments
 (0)