Commit 5867c09

refine IDF transformer with new interfaces

1 parent 7727cae

File tree

  • mllib/src/main/scala/org/apache/spark/ml/feature

1 file changed: 16 additions, 18 deletions

mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala

Lines changed: 16 additions & 18 deletions
@@ -20,26 +20,29 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Params for [[IDF]] and [[IDFModel]].
  */
-private[feature] trait IDFParams extends Params with HasInputCol with HasOutputCol {
+private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {
 
   /**
    * The minimum of documents in which a term should appear.
    * @group param
    */
   val minDocFreq = new IntParam(
-    this, "minDocFreq", "minimum of documents in which a term should appear for filtering", Some(0))
+    this, "minDocFreq", "minimum of documents in which a term should appear for filtering")
+  setDefault(minDocFreq -> 0)
 
   /** @group getParam */
-  def getMinDocFreq: Int = get(minDocFreq)
+  def getMinDocFreq: Int = getOrDefault(minDocFreq)
 
   /** @group setParam */
   def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
@@ -48,14 +51,9 @@ private[feature] trait IDFParams extends Params with HasInputCol with HasOutputC
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
-    val inputType = schema(map(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${map(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains(map(outputCol)),
-      s"Output column ${map(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    val map = extractParamMap(paramMap)
+    SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
   }
 }
 
@@ -64,7 +62,7 @@ private[feature] trait IDFParams extends Params with HasInputCol with HasOutputC
  * Compute the Inverse Document Frequency (IDF) given a collection of documents.
  */
 @AlphaComponent
-class IDF extends Estimator[IDFModel] with IDFParams {
+class IDF extends Estimator[IDFModel] with IDFBase {
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -74,7 +72,7 @@ class IDF extends Estimator[IDFModel] with IDFParams {
 
   override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)
     val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
     val idf = new feature.IDF(getMinDocFreq).fit(input)
     val model = new IDFModel(this, map, idf)
@@ -96,7 +94,7 @@ class IDFModel private[ml] (
     override val parent: IDF,
     override val fittingParamMap: ParamMap,
     idfModel: feature.IDFModel)
-  extends Model[IDFModel] with IDFParams {
+  extends Model[IDFModel] with IDFBase {
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -106,9 +104,9 @@ class IDFModel private[ml] (
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
-    val idf: Vector => Vector = (vec) => idfModel.transform(vec)
-    dataset.withColumn(map(outputCol), callUDF(idf, new VectorUDT, col(map(inputCol))))
+    val map = extractParamMap(paramMap)
+    val idf = udf { vec: Vector => idfModel.transform(vec) }
+    dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
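
For context, a minimal usage sketch of the refactored estimator written against this 1.x-era spark.ml API; the DataFrame `docs`, its Vector column name "features", the output column name, and the minDocFreq value are illustrative assumptions, not part of this commit:

import org.apache.spark.ml.feature.IDF

// Configure the estimator; if minDocFreq is left unset, it falls back to the
// setDefault value (0) via getOrDefault, as introduced in this change.
val idf = new IDF()
  .setInputCol("features")      // assumed column of term-frequency vectors
  .setOutputCol("idfFeatures")  // illustrative output column name
  .setMinDocFreq(2)             // illustrative filtering threshold

// fit() and transform() resolve user-supplied and default params via extractParamMap,
// and the model applies the rescaling through a udf-based column expression.
val idfModel = idf.fit(docs)
val rescaled = idfModel.transform(docs)
rescaled.select("idfFeatures").show()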
