Skip to content

Commit a5fed34

Browse files
jkbradley authored and mengxr committed
[SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml
For users to implement their own PipelineStages, we need to make PipelineStage.transformSchema be public instead of private to ml. This would be nice to include in Spark 1.3 CC: mengxr Author: Joseph K. Bradley <[email protected]> Closes apache#4682 from jkbradley/SPARK-5902 and squashes the following commits: 6f02357 [Joseph K. Bradley] Made transformSchema public 0e6d0a0 [Joseph K. Bradley] made implementations of transformSchema protected as well fdaf26a [Joseph K. Bradley] Made PipelineStage.transformSchema protected instead of private[ml]
1 parent 8ca3418 commit a5fed34

File tree

5 files changed

+20
-12
lines changed

5 files changed

+20
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml
2020
import scala.collection.mutable.ListBuffer
2121

2222
import org.apache.spark.Logging
23-
import org.apache.spark.annotation.AlphaComponent
23+
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
2424
import org.apache.spark.ml.param.{Param, ParamMap}
2525
import org.apache.spark.sql.DataFrame
2626
import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
3333
abstract class PipelineStage extends Serializable with Logging {
3434

3535
/**
36+
* :: DeveloperAPI ::
37+
*
3638
* Derives the output schema from the input schema and parameters.
39+
* The schema describes the columns and types of the data.
40+
*
41+
* @param schema Input schema to this stage
42+
* @param paramMap Parameters passed to this stage
43+
* @return Output schema from this stage
3744
*/
38-
private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
45+
@DeveloperApi
46+
def transformSchema(schema: StructType, paramMap: ParamMap): StructType
3947

4048
/**
4149
* Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
126134
new PipelineModel(this, map, transformers.toArray)
127135
}
128136

129-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
137+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
130138
val map = this.paramMap ++ paramMap
131139
val theStages = map(stages)
132140
require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
171179
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
172180
}
173181

174-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
182+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
175183
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
176184
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
177185
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
5555
model
5656
}
5757

58-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
58+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
5959
val map = this.paramMap ++ paramMap
6060
val inputType = schema(map(inputCol)).dataType
6161
require(inputType.isInstanceOf[VectorUDT],
@@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
9191
dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
9292
}
9393

94-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
94+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
9595
val map = this.paramMap ++ paramMap
9696
val inputType = schema(map(inputCol)).dataType
9797
require(inputType.isInstanceOf[VectorUDT],

mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
132132
@DeveloperApi
133133
protected def featuresDataType: DataType = new VectorUDT
134134

135-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
135+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
136136
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
137137
}
138138

@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
184184
@DeveloperApi
185185
protected def featuresDataType: DataType = new VectorUDT
186186

187-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
187+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
188188
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
189189
}
190190

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ class ALSModel private[ml] (
188188
.select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
189189
}
190190

191-
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
191+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
192192
validateAndTransformSchema(schema, paramMap)
193193
}
194194
}
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
292292
model
293293
}
294294

295-
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
295+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
296296
validateAndTransformSchema(schema, paramMap)
297297
}
298298
}

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
129129
cvModel
130130
}
131131

132-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
132+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
133133
val map = this.paramMap ++ paramMap
134134
map(estimator).transformSchema(schema, paramMap)
135135
}
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
150150
bestModel.transform(dataset, paramMap)
151151
}
152152

153-
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
153+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
154154
bestModel.transformSchema(schema, paramMap)
155155
}
156156
}

0 commit comments

Comments
 (0)