Closed
82 commits
01e4cdf
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 13, 2015
6835704
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 14, 2015
9180687
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 14, 2015
b38a21e
SPARK-11633
gatorsmile Nov 17, 2015
d2b84af
Merge remote-tracking branch 'upstream/master' into joinMakeCopy
gatorsmile Nov 17, 2015
fda8025
Merge remote-tracking branch 'upstream/master'
gatorspark Nov 17, 2015
ac0dccd
Merge branch 'master' of https://github.com/gatorsmile/spark
gatorspark Nov 17, 2015
6e0018b
Merge remote-tracking branch 'upstream/master'
Nov 20, 2015
0546772
converge
gatorsmile Nov 20, 2015
b37a64f
converge
gatorsmile Nov 20, 2015
c2a872c
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 6, 2016
ab6dbd7
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 6, 2016
4276356
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 6, 2016
2dab708
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 7, 2016
0458770
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 8, 2016
1debdfa
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 9, 2016
763706d
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 14, 2016
4de6ec1
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 18, 2016
9422a4f
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 19, 2016
52bdf48
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 20, 2016
1e95df3
Merge remote-tracking branch 'upstream/master'
gatorsmile Jan 23, 2016
fab24cf
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 1, 2016
8b2e33b
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 5, 2016
2ee1876
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 11, 2016
b9f0090
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 12, 2016
ade6f7e
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 15, 2016
9fd63d2
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 19, 2016
5199d49
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 22, 2016
404214c
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 23, 2016
c001dd9
Merge remote-tracking branch 'upstream/master'
gatorsmile Feb 25, 2016
59daa48
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 5, 2016
41d5f64
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 7, 2016
472a6e3
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 10, 2016
0fba10a
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 12, 2016
cbf73b3
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 21, 2016
c08f561
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 22, 2016
474df88
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 22, 2016
3d9828d
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 24, 2016
72d2361
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 26, 2016
07afea5
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 29, 2016
8bf2007
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 30, 2016
87a165b
Merge remote-tracking branch 'upstream/master'
gatorsmile Mar 31, 2016
b9359cd
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 1, 2016
65bd090
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 5, 2016
babf2da
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 5, 2016
9e09469
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 6, 2016
50a8e4a
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 6, 2016
f3337fa
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 10, 2016
09cc36d
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 12, 2016
83a1915
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 14, 2016
0483145
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 19, 2016
236a5f4
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 20, 2016
08aaa4d
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 21, 2016
64f704e
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 24, 2016
006ea2d
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 26, 2016
0c0dc8a
Merge remote-tracking branch 'upstream/master'
gatorsmile Apr 27, 2016
7c4b2f0
Merge remote-tracking branch 'upstream/master'
gatorsmile May 1, 2016
38f3af9
Merge remote-tracking branch 'upstream/master'
gatorsmile May 1, 2016
8089c6f
Merge remote-tracking branch 'upstream/master'
gatorsmile May 4, 2016
a6c7518
Merge remote-tracking branch 'upstream/master'
gatorsmile May 4, 2016
546c1db
Merge remote-tracking branch 'upstream/master'
gatorsmile May 4, 2016
e2ece35
Merge remote-tracking branch 'upstream/master'
gatorsmile May 5, 2016
13c04be
Merge remote-tracking branch 'upstream/master'
gatorsmile May 6, 2016
ac88fc1
Merge remote-tracking branch 'upstream/master'
gatorsmile May 6, 2016
154d3df
Merge remote-tracking branch 'upstream/master'
gatorsmile May 10, 2016
412e88a
Merge remote-tracking branch 'upstream/master'
gatorsmile May 10, 2016
c570065
Merge remote-tracking branch 'upstream/master'
gatorsmile May 11, 2016
ac03674
Merge remote-tracking branch 'upstream/master'
gatorsmile May 11, 2016
650cdcc
Merge remote-tracking branch 'upstream/master'
gatorsmile May 15, 2016
29d16c1
Merge remote-tracking branch 'upstream/master'
gatorsmile May 20, 2016
8d02eea
Merge remote-tracking branch 'upstream/master'
gatorsmile May 22, 2016
c752518
Merge remote-tracking branch 'upstream/master'
gatorsmile May 26, 2016
db0f48c
Merge remote-tracking branch 'upstream/master'
gatorsmile May 27, 2016
acc4a1c
fix
gatorsmile May 28, 2016
2e907db
Merge remote-tracking branch 'upstream/master'
gatorsmile May 28, 2016
dffc628
Merge branch 'sqlcontextMLTest' into sqlContextML
gatorsmile May 28, 2016
5bdc447
revert it back
gatorsmile May 29, 2016
7dcaaa4
address comments
gatorsmile Jun 2, 2016
42a1ed9
name change
gatorsmile Jun 17, 2016
0f316f9
add deprecate
gatorsmile Jun 17, 2016
f0efc3e
Merge remote-tracking branch 'upstream/master' into sqlContextML
gatorsmile Jun 22, 2016
65534a0
address comments
gatorsmile Jun 22, 2016
@@ -243,7 +243,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
 DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
 val (nodeData, _) = NodeData.build(instance.rootNode, 0)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
+sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
 }
 }

@@ -258,7 +258,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
 val numClasses = (metadata.metadata \ "numClasses").extract[Int]
-val root = loadTreeNodes(path, metadata, sqlContext)
+val root = loadTreeNodes(path, metadata, sparkSession)
 val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
 DefaultParamsReader.getAndSetParams(model, metadata)
 model
@@ -270,7 +270,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
 val extraMetadata: JObject = Map(
 "numFeatures" -> instance.numFeatures,
 "numTrees" -> instance.getNumTrees)
-EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
 }
 }

@@ -283,7 +283,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
 override def load(path: String): GBTClassificationModel = {
 implicit val format = DefaultFormats
 val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
 val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
 val numTrees = (metadata.metadata \ "numTrees").extract[Int]
@@ -660,7 +660,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
 val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
 instance.coefficients)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -674,7 +674,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.format("parquet").load(dataPath)
+val data = sparkSession.read.format("parquet").load(dataPath)
 .select("numClasses", "numFeatures", "intercept", "coefficients").head()
 // We will need numClasses, numFeatures in the future for multinomial logreg support.
 // val numClasses = data.getInt(0)
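
For reference, the public save/load entry points that exercise these patched readers and writers are unchanged; a short usage sketch (the model variable and path are illustrative):

    import org.apache.spark.ml.classification.LogisticRegressionModel

    // model: a fitted LogisticRegressionModel
    model.write.overwrite().save("/tmp/lr-model")                 // runs the saveImpl above
    val restored = LogisticRegressionModel.load("/tmp/lr-model")  // runs the load above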
@@ -356,7 +356,7 @@ object MultilayerPerceptronClassificationModel
 // Save model data: layers, weights
 val data = Data(instance.layers, instance.weights)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -370,7 +370,7 @@
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head()
+val data = sparkSession.read.parquet(dataPath).select("layers", "weights").head()
 val layers = data.getAs[Seq[Int]](0).toArray
 val weights = data.getAs[Vector](1)
 val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
@@ -262,7 +262,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
 // Save model data: pi, theta
 val data = Data(instance.pi, instance.theta)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -275,7 +275,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
 val pi = data.getAs[Vector](0)
 val theta = data.getAs[Matrix](1)
 val model = new NaiveBayesModel(metadata.uid, pi, theta)
@@ -282,7 +282,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
 "numFeatures" -> instance.numFeatures,
 "numClasses" -> instance.numClasses,
 "numTrees" -> instance.getNumTrees)
-EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
 }
 }

@@ -296,7 +296,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
 override def load(path: String): RandomForestClassificationModel = {
 implicit val format = DefaultFormats
 val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
-EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
 val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
 val numClasses = (metadata.metadata \ "numClasses").extract[Int]
 val numTrees = (metadata.metadata \ "numTrees").extract[Int]
@@ -195,7 +195,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
 val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
 val data = Data(weights, mus, sigmas)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -208,7 +208,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

 val dataPath = new Path(path, "data").toString
-val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
+val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
 val weights = row.getSeq[Double](0).toArray
 val mus = row.getSeq[OldVector](1).toArray
 val sigmas = row.getSeq[OldMatrix](2).toArray
10 changes: 5 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -211,7 +211,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
 Data(idx, center)
 }
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
 }
 }

@@ -222,8 +222,8 @@ object KMeansModel extends MLReadable[KMeansModel] {

 override def load(path: String): KMeansModel = {
 // Import implicits for Dataset Encoder
-val sqlContext = super.sqlContext
-import sqlContext.implicits._
+val sparkSession = super.sparkSession
+import sparkSession.implicits._

 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString

@@ -232,11 +232,11 @@ object KMeansModel extends MLReadable[KMeansModel] {
 val versionRegex(major, _) = metadata.sparkVersion

 val clusterCenters = if (major.toInt >= 2) {
-val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
 data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
 } else {
 // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
-sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
 }
 val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
 DefaultParamsReader.getAndSetParams(model, metadata)
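
The KMeansModel reader is the one place where the change is more than a rename: it needs an implicit Encoder to read the data back as Dataset[Data], so the implicits import moves from sqlContext.implicits._ to sparkSession.implicits._. A self-contained sketch of that encoder pattern (the local session setup and the Data shape are illustrative stand-ins, not the model's actual schema):

    import org.apache.spark.sql.{Dataset, SparkSession}

    case class Data(clusterIdx: Int, clusterCenter: Array[Double])

    object EncoderSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
        import spark.implicits._  // brings Encoder[Data] into scope for .as[Data]

        val path = "/tmp/kmeans-data-sketch"
        spark.createDataFrame(Seq(Data(0, Array(1.0, 2.0)))).write.mode("overwrite").parquet(path)
        val back: Dataset[Data] = spark.read.parquet(path).as[Data]
        back.collect().foreach(d => println(d.clusterIdx))
        spark.stop()
      }
    }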
@@ -201,7 +201,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = Data(instance.selectedFeatures.toSeq)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -212,7 +212,7 @@
 override def load(path: String): ChiSqSelectorModel = {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
+val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
 val selectedFeatures = data.getAs[Seq[Int]](0).toArray
 val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
 val model = new ChiSqSelectorModel(metadata.uid, oldModel)
@@ -297,7 +297,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = Data(instance.vocabulary)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -308,7 +308,7 @@
 override def load(path: String): CountVectorizerModel = {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath)
+val data = sparkSession.read.parquet(dataPath)
 .select("vocabulary")
 .head()
 val vocabulary = data.getAs[Seq[String]](0).toArray
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -168,7 +168,7 @@ object IDFModel extends MLReadable[IDFModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = Data(instance.idf)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -179,7 +179,7 @@
 override def load(path: String): IDFModel = {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath)
+val data = sparkSession.read.parquet(dataPath)
 .select("idf")
 .head()
 val idf = data.getAs[Vector](0)
@@ -161,7 +161,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = new Data(instance.maxAbs)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -172,7 +172,7 @@
 override def load(path: String): MaxAbsScalerModel = {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString
-val Row(maxAbs: Vector) = sqlContext.read.parquet(dataPath)
+val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath)
 .select("maxAbs")
 .head()
 val model = new MaxAbsScalerModel(metadata.uid, maxAbs)
@@ -221,7 +221,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = new Data(instance.originalMin, instance.originalMax)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -232,7 +232,7 @@
 override def load(path: String): MinMaxScalerModel = {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString
-val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath)
+val Row(originalMin: Vector, originalMax: Vector) = sparkSession.read.parquet(dataPath)
 .select("originalMin", "originalMax")
 .head()
 val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
6 changes: 3 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -186,7 +186,7 @@ object PCAModel extends MLReadable[PCAModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = Data(instance.pc, instance.explainedVariance)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -217,12 +217,12 @@ object PCAModel extends MLReadable[PCAModel] {
 val dataPath = new Path(path, "data").toString
 val model = if (hasExplainedVariance) {
 val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
-sqlContext.read.parquet(dataPath)
+sparkSession.read.parquet(dataPath)
 .select("pc", "explainedVariance")
 .head()
 new PCAModel(metadata.uid, pc, explainedVariance)
 } else {
-val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head()
+val Row(pc: DenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
 new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
 }
 DefaultParamsReader.getAndSetParams(model, metadata)
12 changes: 6 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -297,7 +297,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 // Save model data: resolvedFormula
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+sparkSession.createDataFrame(Seq(instance.resolvedFormula))
 .repartition(1).write.parquet(dataPath)
 // Save pipeline model
 val pmPath = new Path(path, "pipelineModel").toString

@@ -314,7 +314,7 @@
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+val data = sparkSession.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
 val label = data.getString(0)
 val terms = data.getAs[Seq[Seq[String]]](1)
 val hasIntercept = data.getBoolean(2)

@@ -372,7 +372,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
 // Save model data: columnsToPrune
 val data = Data(instance.columnsToPrune.toSeq)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -385,7 +385,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+val data = sparkSession.read.parquet(dataPath).select("columnsToPrune").head()
 val columnsToPrune = data.getAs[Seq[String]](0).toSet
 val pruner = new ColumnPruner(metadata.uid, columnsToPrune)

@@ -463,7 +463,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
 // Save model data: vectorCol, prefixesToRewrite
 val data = Data(instance.vectorCol, instance.prefixesToRewrite)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -476,7 +476,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+val data = sparkSession.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
 val vectorCol = data.getString(0)
 val prefixesToRewrite = data.getAs[Map[String, String]](1)
 val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
@@ -200,7 +200,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = Data(instance.std, instance.mean)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -211,7 +211,7 @@
 override def load(path: String): StandardScalerModel = {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString
-val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+val Row(std: Vector, mean: Vector) = sparkSession.read.parquet(dataPath)
 .select("std", "mean")
 .head()
 val model = new StandardScalerModel(metadata.uid, std, mean)
@@ -221,7 +221,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
 DefaultParamsWriter.saveMetadata(instance, path, sc)
 val data = Data(instance.labels)
 val dataPath = new Path(path, "data").toString
-sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
 }
 }

@@ -232,7 +232,7 @@
 override def load(path: String): StringIndexerModel = {
 val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 val dataPath = new Path(path, "data").toString
-val data = sqlContext.read.parquet(dataPath)
+val data = sparkSession.read.parquet(dataPath)
 .select("labels")
 .head()
 val labels = data.getAs[Seq[String]](0).toArray