Skip to content

Commit 03fbecb

Browse files
committed
remove duplicated code
1 parent 2aa4be0 commit 03fbecb

File tree

1 file changed

+17
-17
lines changed
  • mllib/src/main/scala/org/apache/spark/ml/feature

1 file changed

+17
-17
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,20 @@ private[feature] trait IDFParams extends Params with HasInputCol with HasOutputC
40 40
def setMinDocFreq(value: Int): this.type = {
41 41
set(minDocFreq, value)
42 42
}
43+
44+
/**
45+
* Validate and transform the input schema.
46+
*/
47+
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
48+
val map = this.paramMap ++ paramMap
49+
val inputType = schema(map(inputCol)).dataType
50+
require(inputType.isInstanceOf[VectorUDT],
51+
s"Input column ${map(inputCol)} must be a vector column")
52+
require(!schema.fieldNames.contains(map(outputCol)),
53+
s"Output column ${map(outputCol)} already exists.")
54+
val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
55+
StructType(outputFields)
56+
}
43 57
}
44 58

45 59
/**
@@ -66,14 +80,7 @@ class IDF extends Estimator[IDFModel] with IDFParams {
66 80
}
67 81

68 82
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
69-
val map = this.paramMap ++ paramMap
70-
val inputType = schema(map(inputCol)).dataType
71-
require(inputType.isInstanceOf[VectorUDT],
72-
s"Input column ${map(inputCol)} must be a vector column")
73-
require(!schema.fieldNames.contains(map(outputCol)),
74-
s"Output column ${map(outputCol)} already exists.")
75-
val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
76-
StructType(outputFields)
83+
validateAndTransformSchema(schema, paramMap)
77 84
}
78 85
}
79 86

@@ -97,18 +104,11 @@ class IDFModel private[ml] (
97 104
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
98 105
transformSchema(dataset.schema, paramMap, logging = true)
99 106
val map = this.paramMap ++ paramMap
100-
val idf = udf((v: Vector) => { idfModel.transform(v) } : Vector)
107+
val idf = udf((v: Vector) => { idfModel.transform(v) })
101 108
dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
102 109
}
103 110

104 111
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
105-
val map = this.paramMap ++ paramMap
106-
val inputType = schema(map(inputCol)).dataType
107-
require(inputType.isInstanceOf[VectorUDT],
108-
s"Input column ${map(inputCol)} must be a vector column")
109-
require(!schema.fieldNames.contains(map(outputCol)),
110-
s"Output column ${map(outputCol)} already exists.")
111-
val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
112-
StructType(outputFields)
112+
validateAndTransformSchema(schema, paramMap)
113 113
}
114 114
}

0 commit comments

Comments
 (0)