@@ -20,26 +20,29 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Params for [[IDF]] and [[IDFModel]].
  */
-private[feature] trait IDFParams extends Params with HasInputCol with HasOutputCol {
+private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {
 
   /**
    * The minimum of documents in which a term should appear.
    * @group param
    */
   val minDocFreq = new IntParam(
-    this, "minDocFreq", "minimum of documents in which a term should appear for filtering", Some(0))
+    this, "minDocFreq", "minimum of documents in which a term should appear for filtering")
+  setDefault(minDocFreq -> 0)
 
   /** @group getParam */
-  def getMinDocFreq: Int = get(minDocFreq)
+  def getMinDocFreq: Int = getOrDefault(minDocFreq)
 
   /** @group setParam */
   def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
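The default for minDocFreq now lives in the param map via setDefault rather than in the IntParam constructor, and getOrDefault resolves it. A small usage sketch of that behaviour, assuming the Spark 1.x spark.ml API this patch targets:

// Hedged sketch, assuming the Spark 1.x spark.ml API shown in this diff.
val idf = new IDF()    // minDocFreq left unset
idf.getMinDocFreq      // getOrDefault falls back to the declared default: 0
idf.setMinDocFreq(2)   // an explicit value overrides the default
idf.getMinDocFreq      // now 2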
@@ -48,14 +51,9 @@ private[feature] trait IDFParams extends Params with HasInputCol with HasOutputCol {
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
-    val inputType = schema(map(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${map(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains(map(outputCol)),
-      s"Output column ${map(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    val map = extractParamMap(paramMap)
+    SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
   }
 }
 
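The two SchemaUtils calls replace the inline checks deleted above. As a rough standalone sketch of what they cover, reconstructed from the removed lines; checkAndAppend is a hypothetical helper name, and it reuses the StructField import that the patch drops:

import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{StructField, StructType}

// Illustrative stand-in for SchemaUtils.checkColumnType + SchemaUtils.appendColumn.
def checkAndAppend(schema: StructType, inputCol: String, outputCol: String): StructType = {
  require(schema(inputCol).dataType.isInstanceOf[VectorUDT],
    s"Input column $inputCol must be a vector column")
  require(!schema.fieldNames.contains(outputCol), s"Output column $outputCol already exists.")
  StructType(schema.fields :+ StructField(outputCol, new VectorUDT, nullable = false))
}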
@@ -64,7 +62,7 @@ private[feature] trait IDFParams extends Params with HasInputCol with HasOutputCol {
  * Compute the Inverse Document Frequency (IDF) given a collection of documents.
  */
 @AlphaComponent
-class IDF extends Estimator[IDFModel] with IDFParams {
+class IDF extends Estimator[IDFModel] with IDFBase {
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -74,7 +72,7 @@ class IDF extends Estimator[IDFModel] with IDFParams {
 
   override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)
     val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
     val idf = new feature.IDF(getMinDocFreq).fit(input)
     val model = new IDFModel(this, map, idf)
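fit still delegates to the RDD-based implementation in org.apache.spark.mllib.feature; a minimal sketch of that delegation on its own, with fitIdf as a hypothetical helper name:

import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Illustrative only; mirrors the `new feature.IDF(getMinDocFreq).fit(input)` call above.
def fitIdf(vectors: RDD[Vector], minDocFreq: Int): feature.IDFModel =
  new feature.IDF(minDocFreq).fit(vectors)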
@@ -96,7 +94,7 @@ class IDFModel private[ml] (
     override val parent: IDF,
     override val fittingParamMap: ParamMap,
     idfModel: feature.IDFModel)
-  extends Model[IDFModel] with IDFParams {
+  extends Model[IDFModel] with IDFBase {
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -106,9 +104,9 @@ class IDFModel private[ml] (
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
-    val idf: Vector => Vector = (vec) => idfModel.transform(vec)
-    dataset.withColumn(map(outputCol), callUDF(idf, new VectorUDT, col(map(inputCol))))
+    val map = extractParamMap(paramMap)
+    val idf = udf { vec: Vector => idfModel.transform(vec) }
+    dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
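The transform hunk swaps callUDF for the udf helper from org.apache.spark.sql.functions. A self-contained sketch of that pattern; the rescale helper and the "features"/"tfidf" column names are made up for illustration:

import org.apache.spark.mllib.feature.IDFModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical helper, not part of the patch.
def rescale(df: DataFrame, model: IDFModel): DataFrame = {
  val idfUdf = udf { vec: Vector => model.transform(vec) }  // wrap the mllib model as a SQL UDF
  df.withColumn("tfidf", idfUdf(col("features")))           // append the rescaled vectors
}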