Skip to content

Commit 0802ff9

Browse files
YY-OnCall authored and Nick Pentreath committed
[SPARK-15668][ML] ml.feature: update check schema to avoid confusion when user use MLlib.vector as input type
## What changes were proposed in this pull request?

ml.feature: update check schema to avoid confusion when user use MLlib.vector as input type

## How was this patch tested?

Existing unit tests.

Author: Yuhao Yang <[email protected]>

Closes #13411 from hhbyyh/schemaCheck.

(cherry picked from commit 5855e00)
Signed-off-by: Nick Pentreath <[email protected]>
1 parent 698b6f6 commit 0802ff9

File tree

4 files changed

+25
-36
lines changed

4 files changed

+25
-36
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
3939

4040
/** Validates and transforms the input schema. */
4141
protected def validateAndTransformSchema(schema: StructType): StructType = {
42-
val inputType = schema($(inputCol)).dataType
43-
require(inputType.isInstanceOf[VectorUDT],
44-
s"Input column ${$(inputCol)} must be a vector column")
42+
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
4543
require(!schema.fieldNames.contains($(outputCol)),
4644
s"Output column ${$(outputCol)} already exists.")
4745
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)

mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
6363
/** Validates and transforms the input schema. */
6464
protected def validateAndTransformSchema(schema: StructType): StructType = {
6565
require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
66-
val inputType = schema($(inputCol)).dataType
67-
require(inputType.isInstanceOf[VectorUDT],
68-
s"Input column ${$(inputCol)} must be a vector column")
66+
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
6967
require(!schema.fieldNames.contains($(outputCol)),
7068
s"Output column ${$(outputCol)} already exists.")
7169
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)

mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,16 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
4949
/** @group getParam */
5050
def getK: Int = $(k)
5151

52-
}
52+
/** Validates and transforms the input schema. */
53+
protected def validateAndTransformSchema(schema: StructType): StructType = {
54+
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
55+
require(!schema.fieldNames.contains($(outputCol)),
56+
s"Output column ${$(outputCol)} already exists.")
57+
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
58+
StructType(outputFields)
59+
}
5360

61+
}
5462
/**
5563
* :: Experimental ::
5664
* PCA trains a model to project vectors to a lower dimensional space of the top [[PCA!.k]]
@@ -86,13 +94,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
8694
}
8795

8896
override def transformSchema(schema: StructType): StructType = {
89-
val inputType = schema($(inputCol)).dataType
90-
require(inputType.isInstanceOf[VectorUDT],
91-
s"Input column ${$(inputCol)} must be a vector column")
92-
require(!schema.fieldNames.contains($(outputCol)),
93-
s"Output column ${$(outputCol)} already exists.")
94-
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
95-
StructType(outputFields)
97+
validateAndTransformSchema(schema)
9698
}
9799

98100
override def copy(extra: ParamMap): PCA = defaultCopy(extra)
@@ -148,13 +150,7 @@ class PCAModel private[ml] (
148150
}
149151

150152
override def transformSchema(schema: StructType): StructType = {
151-
val inputType = schema($(inputCol)).dataType
152-
require(inputType.isInstanceOf[VectorUDT],
153-
s"Input column ${$(inputCol)} must be a vector column")
154-
require(!schema.fieldNames.contains($(outputCol)),
155-
s"Output column ${$(outputCol)} already exists.")
156-
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
157-
StructType(outputFields)
153+
validateAndTransformSchema(schema)
158154
}
159155

160156
override def copy(extra: ParamMap): PCAModel = {
@@ -201,7 +197,7 @@ object PCAModel extends MLReadable[PCAModel] {
201197
val versionRegex = "([0-9]+)\\.([0-9]+).*".r
202198
val hasExplainedVariance = metadata.sparkVersion match {
203199
case versionRegex(major, minor) =>
204-
(major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6))
200+
major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
205201
case _ => false
206202
}
207203

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
6262
/** @group getParam */
6363
def getWithStd: Boolean = $(withStd)
6464

65+
/** Validates and transforms the input schema. */
66+
protected def validateAndTransformSchema(schema: StructType): StructType = {
67+
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
68+
require(!schema.fieldNames.contains($(outputCol)),
69+
s"Output column ${$(outputCol)} already exists.")
70+
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
71+
StructType(outputFields)
72+
}
73+
6574
setDefault(withMean -> false, withStd -> true)
6675
}
6776

@@ -105,13 +114,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
105114
}
106115

107116
override def transformSchema(schema: StructType): StructType = {
108-
val inputType = schema($(inputCol)).dataType
109-
require(inputType.isInstanceOf[VectorUDT],
110-
s"Input column ${$(inputCol)} must be a vector column")
111-
require(!schema.fieldNames.contains($(outputCol)),
112-
s"Output column ${$(outputCol)} already exists.")
113-
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
114-
StructType(outputFields)
117+
validateAndTransformSchema(schema)
115118
}
116119

117120
override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
@@ -159,13 +162,7 @@ class StandardScalerModel private[ml] (
159162
}
160163

161164
override def transformSchema(schema: StructType): StructType = {
162-
val inputType = schema($(inputCol)).dataType
163-
require(inputType.isInstanceOf[VectorUDT],
164-
s"Input column ${$(inputCol)} must be a vector column")
165-
require(!schema.fieldNames.contains($(outputCol)),
166-
s"Output column ${$(outputCol)} already exists.")
167-
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
168-
StructType(outputFields)
165+
validateAndTransformSchema(schema)
169166
}
170167

171168
override def copy(extra: ParamMap): StandardScalerModel = {

0 commit comments

Comments (0)