mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -16,17 +16,19 @@
*/
package org.apache.spark.ml.feature

-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.DataFrame
import org.apache.spark.util.collection.OpenHashMap

/**
@@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
*/
@Experimental
class CountVectorizer(override val uid: String)
-  extends Estimator[CountVectorizerModel] with CountVectorizerParams {
+  extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable {

def this() = this(Identifiable.randomUID("cntVec"))

@@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String)
}

override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object CountVectorizer extends Readable[CountVectorizer] {

@Since("1.6.0")
override def read: Reader[CountVectorizer] = new DefaultParamsReader

@Since("1.6.0")
override def load(path: String): CountVectorizer = super.load(path)
}

/**
@@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String)
*/
@Experimental
class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
-  extends Model[CountVectorizerModel] with CountVectorizerParams {
+  extends Model[CountVectorizerModel] with CountVectorizerParams with Writable {

import CountVectorizerModel._

def this(vocabulary: Array[String]) = {
this(Identifiable.randomUID("cntVecModel"), vocabulary)
@@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
copyValues(copied, extra)
}

@Since("1.6.0")
override def write: Writer = new CountVectorizerModelWriter(this)
}

@Since("1.6.0")
object CountVectorizerModel extends Readable[CountVectorizerModel] {

private[CountVectorizerModel]
class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer {

private case class Data(vocabulary: Seq[String])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.vocabulary)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class CountVectorizerModelReader extends Reader[CountVectorizerModel] {

private val className = "org.apache.spark.ml.feature.CountVectorizerModel"

override def load(path: String): CountVectorizerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("vocabulary")
.head()
val vocabulary = data.getAs[Seq[String]](0).toArray
val model = new CountVectorizerModel(metadata.uid, vocabulary)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader

@Since("1.6.0")
override def load(path: String): CountVectorizerModel = super.load(path)
}
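
To see what the new Writable/Readable plumbing on CountVectorizer buys a user, here is a minimal round-trip sketch. It is not part of the patch: the SparkContext setup, column names, and /tmp paths are invented, and it assumes Writer.save and Readable.load behave as the DefaultParamsWriter/DefaultParamsReader code above suggests.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cv-save-load"))
val sqlContext = new SQLContext(sc)
val docs = sqlContext.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b", "b", "c", "a"))
)).toDF("id", "words")

val cv = new CountVectorizer().setInputCol("words").setOutputCol("features").setVocabSize(3)
val cvModel = cv.fit(docs)
// Persists the params via DefaultParamsWriter.saveMetadata and the vocabulary as a one-row Parquet file under <path>/data.
cvModel.write.save("/tmp/spark-ml/cv-model")
// Reads the metadata and vocabulary back through the new Readable companion object.
val restored = CountVectorizerModel.load("/tmp/spark-ml/cv-model")
assert(restored.vocabulary.sameElements(cvModel.vocabulary))
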
71 changes: 67 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -17,11 +17,13 @@

package org.apache.spark.ml.feature

-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
@@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Compute the Inverse Document Frequency (IDF) given a collection of documents.
*/
@Experimental
-final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
+final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable {

def this() = this(Identifiable.randomUID("idf"))

@@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
}

override def copy(extra: ParamMap): IDF = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object IDF extends Readable[IDF] {

@Since("1.6.0")
override def read: Reader[IDF] = new DefaultParamsReader

@Since("1.6.0")
override def load(path: String): IDF = super.load(path)
}

/**
@@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
class IDFModel private[ml] (
override val uid: String,
idfModel: feature.IDFModel)
-  extends Model[IDFModel] with IDFBase {
+  extends Model[IDFModel] with IDFBase with Writable {

import IDFModel._

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -117,4 +134,50 @@ class IDFModel private[ml] (
val copied = new IDFModel(uid, idfModel)
copyValues(copied, extra).setParent(parent)
}

/** Returns the IDF vector. */
@Since("1.6.0")
def idf: Vector = idfModel.idf

@Since("1.6.0")
override def write: Writer = new IDFModelWriter(this)
}

@Since("1.6.0")
object IDFModel extends Readable[IDFModel] {

private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer {

private case class Data(idf: Vector)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.idf)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class IDFModelReader extends Reader[IDFModel] {

private val className = "org.apache.spark.ml.feature.IDFModel"

override def load(path: String): IDFModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("idf")
.head()
val idf = data.getAs[Vector](0)
val model = new IDFModel(metadata.uid, new feature.IDFModel(idf))
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: Reader[IDFModel] = new IDFModelReader

@Since("1.6.0")
override def load(path: String): IDFModel = super.load(path)
}
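
The IDF changes follow the same pattern; the sketch below shows the intended save/load round trip for both the estimator and the fitted model. As above, the data and paths are illustrative only, and the behavior of the Writer/Reader entry points is assumed from the wiring in this diff.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.{IDF, IDFModel}
import org.apache.spark.mllib.linalg.Vectors

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("idf-save-load"))
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.0, 3.0)),
  (1, Vectors.dense(0.0, 1.0, 4.0))
)).toDF("id", "tf")

val idf = new IDF().setInputCol("tf").setOutputCol("tfidf")
// The estimator has only Params, so DefaultParamsWriter is sufficient to round-trip it.
idf.write.save("/tmp/spark-ml/idf")
val model = idf.fit(df)
// The model writer additionally stores the idf vector as Parquet under <path>/data.
model.write.save("/tmp/spark-ml/idf-model")
val restored = IDFModel.load("/tmp/spark-ml/idf-model")
assert(restored.idf == model.idf)
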
mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -17,11 +17,14 @@

package org.apache.spark.ml.feature

-import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.sql._
@@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
*/
@Experimental
class MinMaxScaler(override val uid: String)
-  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
+  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable {

def this() = this(Identifiable.randomUID("minMaxScal"))

@@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String)
}

override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object MinMaxScaler extends Readable[MinMaxScaler] {

@Since("1.6.0")
override def read: Reader[MinMaxScaler] = new DefaultParamsReader

@Since("1.6.0")
override def load(path: String): MinMaxScaler = super.load(path)
}

/**
@@ -131,7 +147,9 @@ class MinMaxScaler(override val uid: String)
override val uid: String,
val originalMin: Vector,
val originalMax: Vector)
-  extends Model[MinMaxScalerModel] with MinMaxScalerParams {
+  extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable {

import MinMaxScalerModel._

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] (
val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
copyValues(copied, extra).setParent(parent)
}

@Since("1.6.0")
override def write: Writer = new MinMaxScalerModelWriter(this)
}

@Since("1.6.0")
object MinMaxScalerModel extends Readable[MinMaxScalerModel] {

private[MinMaxScalerModel]
class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer {

private case class Data(originalMin: Vector, originalMax: Vector)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = new Data(instance.originalMin, instance.originalMax)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] {

private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"

override def load(path: String): MinMaxScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath)
.select("originalMin", "originalMax")
.head()
val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader

@Since("1.6.0")
override def load(path: String): MinMaxScalerModel = super.load(path)
}
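
And the same round trip for MinMaxScaler, where the model writer keeps originalMin and originalMax alongside the params. Again, this is only an illustrative sketch with invented data and paths, not code from the patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.{MinMaxScaler, MinMaxScalerModel}
import org.apache.spark.mllib.linalg.Vectors

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("minmax-save-load"))
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -1.0)),
  (1, Vectors.dense(2.0, 1.1, 1.0)),
  (2, Vectors.dense(3.0, 10.1, 3.0))
)).toDF("id", "features")

val model = new MinMaxScaler().setInputCol("features").setOutputCol("scaled").fit(df)
// Stores the params plus the originalMin / originalMax vectors under <path>/data.
model.write.save("/tmp/spark-ml/minmax-model")
val restored = MinMaxScalerModel.load("/tmp/spark-ml/minmax-model")
// The reloaded model should rescale exactly like the original.
restored.transform(df).show()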