@@ -40,6 +40,20 @@ private[feature] trait IDFParams extends Params with HasInputCol with HasOutputC
4040 def setMinDocFreq (value : Int ): this .type = {
4141 set(minDocFreq, value)
4242 }
43+
44+ /**
45+ * Validate and transform the input schema.
46+ */
47+ protected def validateAndTransformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
48+ val map = this .paramMap ++ paramMap
49+ val inputType = schema(map(inputCol)).dataType
50+ require(inputType.isInstanceOf [VectorUDT ],
51+ s " Input column ${map(inputCol)} must be a vector column " )
52+ require(! schema.fieldNames.contains(map(outputCol)),
53+ s " Output column ${map(outputCol)} already exists. " )
54+ val outputFields = schema.fields :+ StructField (map(outputCol), new VectorUDT , false )
55+ StructType (outputFields)
56+ }
4357}
4458
4559/**
@@ -66,14 +80,7 @@ class IDF extends Estimator[IDFModel] with IDFParams {
6680 }
6781
6882 override def transformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
69- val map = this .paramMap ++ paramMap
70- val inputType = schema(map(inputCol)).dataType
71- require(inputType.isInstanceOf [VectorUDT ],
72- s " Input column ${map(inputCol)} must be a vector column " )
73- require(! schema.fieldNames.contains(map(outputCol)),
74- s " Output column ${map(outputCol)} already exists. " )
75- val outputFields = schema.fields :+ StructField (map(outputCol), new VectorUDT , false )
76- StructType (outputFields)
83+ validateAndTransformSchema(schema, paramMap)
7784 }
7885}
7986
@@ -97,18 +104,11 @@ class IDFModel private[ml] (
97104 override def transform (dataset : DataFrame , paramMap : ParamMap ): DataFrame = {
98105 transformSchema(dataset.schema, paramMap, logging = true )
99106 val map = this .paramMap ++ paramMap
100- val idf = udf((v : Vector ) => { idfModel.transform(v) } : Vector )
107+ val idf = udf((v : Vector ) => { idfModel.transform(v) })
101108 dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
102109 }
103110
104111 override def transformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
105- val map = this .paramMap ++ paramMap
106- val inputType = schema(map(inputCol)).dataType
107- require(inputType.isInstanceOf [VectorUDT ],
108- s " Input column ${map(inputCol)} must be a vector column " )
109- require(! schema.fieldNames.contains(map(outputCol)),
110- s " Output column ${map(outputCol)} already exists. " )
111- val outputFields = schema.fields :+ StructField (map(outputCol), new VectorUDT , false )
112- StructType (outputFields)
112+ validateAndTransformSchema(schema, paramMap)
113113 }
114114}
0 commit comments