Skip to content

Commit 2cf2ed0

Browse files
committed
snapshot
1 parent 291814f commit 2cf2ed0

File tree

57 files changed

+293
-52
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+293
-52
lines changed

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
7878
paramMaps.map(fit(dataset, _))
7979
}
8080

81-
override def copy(extra: ParamMap): Estimator[M] = {
82-
super.copy(extra).asInstanceOf[Estimator[M]]
83-
}
81+
override def copy(extra: ParamMap): Estimator[M]
8482
}

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer {
4545
/** Indicates whether this [[Model]] has a corresponding parent. */
4646
def hasParent: Boolean = parent != null
4747

48-
override def copy(extra: ParamMap): M = {
49-
// The default implementation of Params.copy doesn't work for models.
50-
throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
51-
}
48+
override def copy(extra: ParamMap): M
5249
}

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ abstract class PipelineStage extends Params with Logging {
6363
outputSchema
6464
}
6565

66-
override def copy(extra: ParamMap): PipelineStage = {
67-
super.copy(extra).asInstanceOf[PipelineStage]
68-
}
66+
override def copy(extra: ParamMap): PipelineStage
6967
}
7068

7169
/**

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ abstract class Predictor[
9090
copyValues(train(dataset).setParent(this))
9191
}
9292

93-
override def copy(extra: ParamMap): Learner = {
94-
super.copy(extra).asInstanceOf[Learner]
95-
}
93+
override def copy(extra: ParamMap): Learner
9694

9795
/**
9896
* Train a model using the given dataset and parameters.

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage {
6767
*/
6868
def transform(dataset: DataFrame): DataFrame
6969

70-
override def copy(extra: ParamMap): Transformer = {
71-
super.copy(extra).asInstanceOf[Transformer]
72-
}
70+
override def copy(extra: ParamMap): Transformer
7371
}
7472

7573
/**
@@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
120118
dataset.withColumn($(outputCol),
121119
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
122120
}
121+
122+
override def copy(extra: ParamMap): T = defaultCopyWithParams(extra)
123123
}

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.ml.classification
1919

2020
import org.apache.spark.annotation.DeveloperApi
21+
import org.apache.spark.ml.param.ParamMap
2122
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
2223
import org.apache.spark.ml.param.shared.HasRawPredictionCol
2324
import org.apache.spark.ml.util.SchemaUtils

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String)
8686
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
8787
subsamplingRate = 1.0)
8888
}
89+
90+
override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopyWithParams(extra)
8991
}
9092

9193
@Experimental

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String)
141141
val oldModel = oldGBT.run(oldDataset)
142142
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
143143
}
144+
145+
override def copy(extra: ParamMap): GBTClassifier = defaultCopyWithParams(extra)
144146
}
145147

146148
@Experimental

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String)
220220

221221
new LogisticRegressionModel(uid, weights.compressed, intercept)
222222
}
223+
224+
override def copy(extra: ParamMap): LogisticRegression = defaultCopyWithParams(extra)
223225
}
224226

225227
/**

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.language.existentials
2424
import org.apache.spark.annotation.Experimental
2525
import org.apache.spark.ml._
2626
import org.apache.spark.ml.attribute._
27-
import org.apache.spark.ml.param.{ParamMap, Param}
27+
import org.apache.spark.ml.param.{Param, ParamMap}
2828
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
2929
import org.apache.spark.mllib.linalg.Vector
3030
import org.apache.spark.sql.{DataFrame, Row}
@@ -217,7 +217,7 @@ final class OneVsRest(override val uid: String)
217217
}
218218

219219
override def copy(extra: ParamMap): OneVsRest = {
220-
val copied = super.copy(extra).asInstanceOf[OneVsRest]
220+
val copied = defaultCopyWithParams(extra).asInstanceOf[OneVsRest]
221221
if (isDefined(classifier)) {
222222
copied.setClassifier($(classifier).copy(extra))
223223
}

0 commit comments

Comments
 (0)