Skip to content

Commit 43c7ec6

Browse files
committed
[SPARK-8151] [MLLIB] pipeline components should correctly implement copy
Otherwise, extra params get ignored in `PipelineModel.transform`. jkbradley Author: Xiangrui Meng <[email protected]> Closes #6622 from mengxr/SPARK-8087 and squashes the following commits: 0e4c8c4 [Xiangrui Meng] fix merge issues 26fc1f0 [Xiangrui Meng] address comments e607a04 [Xiangrui Meng] merge master b85b57e [Xiangrui Meng] fix examples/compile d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy 84ec278 [Xiangrui Meng] remove setter checks due to generics 2cf2ed0 [Xiangrui Meng] snapshot 291814f [Xiangrui Meng] OneVsRest.copy 1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages
1 parent 47af7c1 commit 43c7ec6

File tree

62 files changed

+350
-55
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+350
-55
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
156156
// Create a model, and return it.
157157
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
158158
}
159+
160+
@Override
161+
public MyJavaLogisticRegression copy(ParamMap extra) {
162+
return defaultCopy(extra);
163+
}
159164
}
160165

161166
/**

examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String)
130130
// Create a model, and return it.
131131
new MyLogisticRegressionModel(uid, weights).setParent(this)
132132
}
133+
134+
override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra)
133135
}
134136

135137
/**

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
7878
paramMaps.map(fit(dataset, _))
7979
}
8080

81-
override def copy(extra: ParamMap): Estimator[M] = {
82-
super.copy(extra).asInstanceOf[Estimator[M]]
83-
}
81+
override def copy(extra: ParamMap): Estimator[M]
8482
}

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer {
4545
/** Indicates whether this [[Model]] has a corresponding parent. */
4646
def hasParent: Boolean = parent != null
4747

48-
override def copy(extra: ParamMap): M = {
49-
// The default implementation of Params.copy doesn't work for models.
50-
throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
51-
}
48+
override def copy(extra: ParamMap): M
5249
}

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ abstract class PipelineStage extends Params with Logging {
6666
outputSchema
6767
}
6868

69-
override def copy(extra: ParamMap): PipelineStage = {
70-
super.copy(extra).asInstanceOf[PipelineStage]
71-
}
69+
override def copy(extra: ParamMap): PipelineStage
7270
}
7371

7472
/**
@@ -198,6 +196,6 @@ class PipelineModel private[ml] (
198196
}
199197

200198
override def copy(extra: ParamMap): PipelineModel = {
201-
new PipelineModel(uid, stages)
199+
new PipelineModel(uid, stages.map(_.copy(extra)))
202200
}
203201
}

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ abstract class Predictor[
9090
copyValues(train(dataset).setParent(this))
9191
}
9292

93-
override def copy(extra: ParamMap): Learner = {
94-
super.copy(extra).asInstanceOf[Learner]
95-
}
93+
override def copy(extra: ParamMap): Learner
9694

9795
/**
9896
* Train a model using the given dataset and parameters.

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage {
6767
*/
6868
def transform(dataset: DataFrame): DataFrame
6969

70-
override def copy(extra: ParamMap): Transformer = {
71-
super.copy(extra).asInstanceOf[Transformer]
72-
}
70+
override def copy(extra: ParamMap): Transformer
7371
}
7472

7573
/**
@@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
120118
dataset.withColumn($(outputCol),
121119
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
122120
}
121+
122+
override def copy(extra: ParamMap): T = defaultCopy(extra)
123123
}

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.ml.classification
1919

2020
import org.apache.spark.annotation.DeveloperApi
21+
import org.apache.spark.ml.param.ParamMap
2122
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
2223
import org.apache.spark.ml.param.shared.HasRawPredictionCol
2324
import org.apache.spark.ml.util.SchemaUtils

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String)
8686
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
8787
subsamplingRate = 1.0)
8888
}
89+
90+
override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
8991
}
9092

9193
@Experimental

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String)
141141
val oldModel = oldGBT.run(oldDataset)
142142
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
143143
}
144+
145+
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
144146
}
145147

146148
@Experimental

0 commit comments

Comments
 (0)