From b38a21ef6146784e4b93ef4ce8c899f1eee14572 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 16 Nov 2015 18:30:26 -0800 Subject: [PATCH 1/9] SPARK-11633 --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 ++- .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdb..5a5b71e52dd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -425,7 +425,8 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val attributeRewrites = + AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3427152b2da0..5e00546a74c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,6 +51,8 @@ case class Order( state: String, month: Int) +case class Individual(F1: Integer, F2: Integer) + case class WindowData( month: Int, area: String, @@ -1479,4 +1481,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } + + test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { + val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) + val df = hiveContext.createDataFrame(rdd1) + df.registerTempTable("foo") + val df2 = sql("select f1, F2 as F2 from foo") + df2.registerTempTable("foo2") + df2.registerTempTable("foo3") + + checkAnswer(sql( + """ + SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 + """.stripMargin), Row(2) :: Row(1) :: Nil) + } } From 0546772f151f83d6d3cf4d000cbe341f52545007 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:56:45 -0800 Subject: [PATCH 2/9] converge --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- .../spark/sql/hive/execution/SQLQuerySuite.scala | 15 --------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7c9512fbd00a..47962ebe6ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -417,8 +417,7 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = - AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5e00546a74c0..61d9dcd37572 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,8 +51,6 @@ case class Order( state: String, month: Int) -case class Individual(F1: Integer, F2: Integer) - case class WindowData( month: Int, area: String, @@ -1481,18 +1479,5 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - - test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { - val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) - val df = hiveContext.createDataFrame(rdd1) - df.registerTempTable("foo") - val df2 = sql("select f1, F2 as F2 from foo") - df2.registerTempTable("foo2") - df2.registerTempTable("foo3") - - checkAnswer(sql( - """ - SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 - """.stripMargin), Row(2) :: Row(1) :: Nil) } } From b37a64f13956b6ddd0e38ddfd9fe1caee611f1a8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:58:37 -0800 Subject: [PATCH 3/9] converge --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 61d9dcd37572..3427152b2da0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1479,5 +1479,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - } } From acc4a1c85635d5329167f5a488dd1411451af071 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 28 May 2016 07:52:24 -0700 Subject: [PATCH 4/9] fix --- .../DecisionTreeClassifier.scala | 4 +-- .../ml/classification/GBTClassifier.scala | 4 +-- .../classification/LogisticRegression.scala | 4 +-- .../MultilayerPerceptronClassifier.scala | 4 +-- .../spark/ml/classification/NaiveBayes.scala | 4 +-- .../RandomForestClassifier.scala | 4 +-- .../spark/ml/clustering/GaussianMixture.scala | 4 +-- .../apache/spark/ml/clustering/KMeans.scala | 10 +++---- .../spark/ml/feature/ChiSqSelector.scala | 4 +-- .../spark/ml/feature/CountVectorizer.scala | 4 +-- .../org/apache/spark/ml/feature/IDF.scala | 4 +-- .../spark/ml/feature/MaxAbsScaler.scala | 4 +-- .../spark/ml/feature/MinMaxScaler.scala | 4 +-- .../org/apache/spark/ml/feature/PCA.scala | 6 ++--- .../apache/spark/ml/feature/RFormula.scala | 12 ++++----- .../spark/ml/feature/StandardScaler.scala | 4 +-- .../spark/ml/feature/StringIndexer.scala | 4 +-- .../spark/ml/feature/VectorIndexer.scala | 4 +-- .../apache/spark/ml/feature/Word2Vec.scala | 4 +-- .../apache/spark/ml/recommendation/ALS.scala | 4 +-- .../ml/regression/AFTSurvivalRegression.scala | 4 +-- .../ml/regression/DecisionTreeRegressor.scala | 4 +-- .../spark/ml/regression/GBTRegressor.scala | 4 +-- .../GeneralizedLinearRegression.scala | 4 +-- .../ml/regression/IsotonicRegression.scala | 4 +-- 
.../ml/regression/LinearRegression.scala | 4 +-- .../ml/regression/RandomForestRegressor.scala | 4 +-- .../org/apache/spark/ml/tree/treeModels.scala | 12 ++++----- .../org/apache/spark/ml/util/ReadWrite.scala | 26 +++++++++---------- .../ml/util/JavaDefaultReadWriteSuite.java | 2 +- project/MimaExcludes.scala | 6 +++++ .../execution/joins/BroadcastJoinSuite.scala | 2 +- 32 files changed, 88 insertions(+), 84 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 881dcefb79be..c65d3d5b5442 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -243,7 +243,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + sparkSession.createDataFrame(nodeData).write.parquet(dataPath) } } @@ -258,7 +258,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sqlContext) + val root = loadTreeNodes(path, metadata, sparkSession) val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) DefaultParamsReader.getAndSetParams(model, metadata) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f843df449c61..4e534baddc63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -270,7 +270,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -283,7 +283,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 61e355ab9fba..2dff241e23d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -653,7 +653,7 @@ object LogisticRegressionModel 
extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -667,7 +667,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) + val data = sparkSession.read.format("parquet").load(dataPath) .select("numClasses", "numFeatures", "intercept", "coefficients").head() // We will need numClasses, numFeatures in the future for multinomial logreg support. // val numClasses = data.getInt(0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index c4e882240ffd..700542117ee7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -356,7 +356,7 @@ object MultilayerPerceptronClassificationModel // Save model data: layers, weights val data = Data(instance.layers, instance.weights) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -370,7 +370,7 @@ object MultilayerPerceptronClassificationModel val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head() + val data = sparkSession.read.parquet(dataPath).select("layers", "weights").head() val layers = data.getAs[Seq[Int]](0).toArray val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index a98bdeca6b72..a9d493032b28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -262,7 +262,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { // Save model data: pi, theta val data = Data(instance.pi, instance.theta) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -275,7 +275,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() val pi = data.getAs[Vector](0) val theta = data.getAs[Matrix](1) val model = new NaiveBayesModel(metadata.uid, pi, theta) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index b3c074f83925..9a26a5c5b143 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -282,7 +282,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica "numFeatures" -> instance.numFeatures, "numClasses" -> instance.numClasses, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -296,7 +296,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 773e50e24549..1cc3ff4c96e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -195,7 +195,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov)) val data = Data(weights, mus, sigmas) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -208,7 +208,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head() + val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head() val weights = row.getSeq[Double](0).toArray val mus = row.getSeq[OldVector](1).toArray val sigmas = row.getSeq[OldMatrix](2).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 790ef1fe8dc9..6f63d0481896 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -211,7 +211,7 @@ object KMeansModel extends MLReadable[KMeansModel] { Data(idx, center) } val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) } } @@ -222,8 +222,8 @@ object KMeansModel extends MLReadable[KMeansModel] { override def load(path: String): KMeansModel = { // Import implicits for Dataset Encoder - val sqlContext = super.sqlContext - import sqlContext.implicits._ + val sparkSession = super.sparkSession + import sparkSession.implicits._ val metadata = 
DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString @@ -232,11 +232,11 @@ object KMeansModel extends MLReadable[KMeansModel] { val versionRegex(major, _) = metadata.sparkVersion val clusterCenters = if (major.toInt >= 2) { - val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data] + val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. - sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters + sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index e73a8f5d6608..7a4ba18c806c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -186,7 +186,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.selectedFeatures.toSeq) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -197,7 +197,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { override def load(path: String): ChiSqSelectorModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head() val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index fc4885bf4ba8..d08dad15ec61 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -281,7 +281,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.vocabulary) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -292,7 +292,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("vocabulary") .head() val vocabulary = data.getAs[Seq[String]](0).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 
08beda6d7515..d259659c8cf0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -156,7 +156,7 @@ object IDFModel extends MLReadable[IDFModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.idf) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -167,7 +167,7 @@ object IDFModel extends MLReadable[IDFModel] { override def load(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("idf") .head() val idf = data.getAs[Vector](0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 0dffba93ac57..aa24a2d3ba20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -153,7 +153,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = new Data(instance.maxAbs) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -164,7 +164,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { override def load(path: String): MaxAbsScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(maxAbs: Vector) = sqlContext.read.parquet(dataPath) + val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath) .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index c6ff639f2962..a66aa9c744fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -208,7 +208,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = new Data(instance.originalMin, instance.originalMax) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -219,7 +219,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + val Row(originalMin: Vector, originalMax: Vector) = sparkSession.read.parquet(dataPath) .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 
dbbaa5aa46f4..ec61a4929007 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -177,7 +177,7 @@ object PCAModel extends MLReadable[PCAModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.pc, instance.explainedVariance) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -208,12 +208,12 @@ object PCAModel extends MLReadable[PCAModel] { val dataPath = new Path(path, "data").toString val model = if (hasExplainedVariance) { val Row(pc: DenseMatrix, explainedVariance: DenseVector) = - sqlContext.read.parquet(dataPath) + sparkSession.read.parquet(dataPath) .select("pc", "explainedVariance") .head() new PCAModel(metadata.uid, pc, explainedVariance) } else { - val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head() + val Row(pc: DenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head() new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2916b6d9df3b..ee6079bec0f9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -283,7 +283,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: resolvedFormula val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(instance.resolvedFormula)) + sparkSession.createDataFrame(Seq(instance.resolvedFormula)) .repartition(1).write.parquet(dataPath) // Save pipeline model val pmPath = new Path(path, "pipelineModel").toString @@ -300,7 +300,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head() + val data = sparkSession.read.parquet(dataPath).select("label", "terms", "hasIntercept").head() val label = data.getString(0) val terms = data.getAs[Seq[Seq[String]]](1) val hasIntercept = data.getBoolean(2) @@ -358,7 +358,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { // Save model data: columnsToPrune val data = Data(instance.columnsToPrune.toSeq) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -371,7 +371,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head() + val data = sparkSession.read.parquet(dataPath).select("columnsToPrune").head() val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) @@ -449,7 +449,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite // Save model data: vectorCol, prefixesToRewrite val data = 
Data(instance.vectorCol, instance.prefixesToRewrite) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -462,7 +462,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head() + val data = sparkSession.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head() val vectorCol = data.getString(0) val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 9d084b520c48..36e9ec472994 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -189,7 +189,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -200,7 +200,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + val Row(std: Vector, mean: Vector) = sparkSession.read.parquet(dataPath) .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index cc0571fd7e39..9fc78ef0869f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -207,7 +207,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.labels) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -218,7 +218,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("labels") .head() val labels = data.getAs[Seq[String]](0).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index d814528ec48d..615854cb78df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -436,7 +436,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.numFeatures, instance.categoryMaps) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -447,7 +447,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { override def load(path: String): VectorIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("numFeatures", "categoryMaps") .head() val numFeatures = data.getAs[Int](0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 1b929cdfffe3..2dce21d279a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -267,7 +267,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -278,7 +278,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { override def load(path: String): Word2VecModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("wordIndex", "wordVectors") .head() val wordIndex = data.getAs[Map[String, Int]](0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 8dc7437d4747..53f18904fdf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -320,9 +320,9 @@ object ALSModel extends MLReadable[ALSModel] { implicit val format = DefaultFormats val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString - val userFactors = sqlContext.read.format("parquet").load(userPath) + val userFactors = sparkSession.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString - val itemFactors = sqlContext.read.format("parquet").load(itemPath) + val itemFactors = sparkSession.read.format("parquet").load(itemPath) val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 00ef6ccc74d2..ce51ccaca779 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -368,7 +368,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] // Save model 
data: coefficients, intercept, scale val data = Data(instance.coefficients, instance.intercept, instance.scale) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -381,7 +381,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("coefficients", "intercept", "scale").head() val coefficients = data.getAs[Vector](0) val intercept = data.getDouble(1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index c4df9d11127f..7ff6d0afd55c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -249,7 +249,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + sparkSession.createDataFrame(nodeData).write.parquet(dataPath) } } @@ -263,7 +263,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sqlContext) + val root = loadTreeNodes(path, metadata, sparkSession) val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 81f2139f0b42..6223555504d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -252,7 +252,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -265,7 +265,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 
adbdd345e92e..a23e90d9e125 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -813,7 +813,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -827,7 +827,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("intercept", "coefficients").head() val intercept = data.getDouble(0) val coefficients = data.getAs[Vector](1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index d16e8e3f6b25..f05b47eda7b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -284,7 +284,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val data = Data( instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -297,7 +297,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("boundaries", "predictions", "isotonic").head() val boundaries = data.getAs[Seq[Double]](0).toArray val predictions = data.getAs[Seq[Double]](1).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 4d66b0eb37ab..c302972e8bb4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -479,7 +479,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -492,7 +492,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) + val data = sparkSession.read.format("parquet").load(dataPath) .select("intercept", "coefficients").head() val intercept = data.getDouble(0) val coefficients = data.getAs[Vector](1) 
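
Editor's note: every per-model change in [PATCH 4/9] follows the same shape — the writer builds its data DataFrame from the active SparkSession instead of a SQLContext, and the reader loads the Parquet data back through sparkSession.read. The following standalone Scala sketch only illustrates that pattern under simplified assumptions; ModelData and PersistenceSketch are hypothetical names, not the actual MLlib writer/reader internals (which also persist metadata via DefaultParamsWriter).

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession

// Illustrative payload; real MLlib writers persist model-specific fields plus metadata.
case class ModelData(intercept: Double, coefficients: Seq[Double])

object PersistenceSketch {
  def save(spark: SparkSession, path: String, data: ModelData): Unit = {
    val dataPath = new Path(path, "data").toString
    // After this patch the DataFrame is built from the SparkSession, not a SQLContext.
    spark.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
  }

  def load(spark: SparkSession, path: String): ModelData = {
    val dataPath = new Path(path, "data").toString
    val row = spark.read.parquet(dataPath).select("intercept", "coefficients").head()
    ModelData(row.getDouble(0), row.getAs[Seq[Double]](1))
  }
}
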
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a6dbf21d55e2..4f4d3d27841d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -244,7 +244,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -257,7 +257,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 56c85c9b53e1..5b6fcc53c2dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, SQLContext} +import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.util.collection.OpenHashMap /** @@ -332,8 +332,8 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sqlContext: SQLContext): Node = { - import sqlContext.implicits._ + sparkSession: SparkSession): Node = { + import sparkSession.implicits._ implicit val format = DefaultFormats // Get impurity to construct ImpurityCalculator for each node @@ -343,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite { } val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).as[NodeData] + val data = sparkSession.read.parquet(dataPath).as[NodeData] buildTreeFromNodes(data.collect(), impurityType) } @@ -393,7 +393,7 @@ private[ml] object EnsembleModelReadWrite { def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]]( instance: M, path: String, - sql: SQLContext, + sql: SparkSession, extraMetadata: JObject): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { @@ -424,7 +424,7 @@ private[ml] object EnsembleModelReadWrite { */ def loadImpl( path: String, - sql: SQLContext, + sql: SparkSession, className: String, treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 94d1b83ec253..cfde2dd1837b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -33,36 +33,34 @@ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.SparkSession import org.apache.spark.util.Utils /** * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { - private var optionSQLContext: Option[SQLContext] = None + private var optionSparkSession: Option[SparkSession] = None /** - * Sets the SQL context to use for saving/loading. + * Sets the Spark Session to use for saving/loading. */ @Since("1.6.0") - def context(sqlContext: SQLContext): this.type = { - optionSQLContext = Option(sqlContext) + def context(sparkSession: SparkSession): this.type = { + optionSparkSession = Option(sparkSession) this } /** - * Returns the user-specified SQL context or the default. + * Returns the user-specified Spark Session or the default. */ - protected final def sqlContext: SQLContext = { - if (optionSQLContext.isEmpty) { - optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate())) + protected final def sparkSession: SparkSession = { + if (optionSparkSession.isEmpty) { + optionSparkSession = Some(SparkSession.builder().getOrCreate()) } - optionSQLContext.get + optionSparkSession.get } - protected final def sparkSession: SparkSession = sqlContext.sparkSession - /** Returns the underlying [[SparkContext]]. */ protected final def sc: SparkContext = sparkSession.sparkContext } @@ -116,7 +114,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { } // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) + override def context(sparkSession: SparkSession): this.type = super.context(sparkSession) } /** @@ -160,7 +158,7 @@ abstract class MLReader[T] extends BaseReadWrite { def load(path: String): T // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) + override def context(sparkSession: SparkSession): this.type = super.context(sparkSession) } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index 7bda219243bf..5598ff3f10ec 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -56,7 +56,7 @@ public void testDefaultReadWrite() throws IOException { } catch (IOException e) { // expected } - instance.write().context(spark.sqlContext()).overwrite().save(outputPath); + instance.write().context(spark).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 73debe9da427..a235fb3055f1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -775,6 +775,12 @@ object MimaExcludes { 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.AlphaComponent"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Experimental"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.DeveloperApi") + ) ++ Seq( + // [] Replace SQLContext by SparkSession in MLLIB + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.sqlContext"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.util.MLWriter.context"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.sqlContext"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.util.MLReader.context") ) case v if v.startsWith("1.6") => Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index db32b6b6befb..bc89ff86ab82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -48,7 +48,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { .setMaster("local-cluster[2,1,1024]") .setAppName("testing") val sc = new SparkContext(conf) - spark = SparkSession.builder.getOrCreate() + spark = SparkSession.builder().config(sc.getConf).getOrCreate() } override def afterAll(): Unit = { From 5bdc4470f4b2531aba94b1c67ac9a479e34c945c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 28 May 2016 20:47:39 -0700 Subject: [PATCH 5/9] revert it back --- .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index bc89ff86ab82..db32b6b6befb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -48,7 +48,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { .setMaster("local-cluster[2,1,1024]") .setAppName("testing") val sc = new SparkContext(conf) - spark = SparkSession.builder().config(sc.getConf).getOrCreate() + spark = SparkSession.builder.getOrCreate() } override def afterAll(): Unit = { From 7dcaaa4539324d596c290ca189adc2e65d77bb0c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 1 Jun 2016 18:52:05 -0700 Subject: [PATCH 6/9] address comments --- .../org/apache/spark/ml/util/ReadWrite.scala | 24 +++++++++++++++++-- project/MimaExcludes.scala | 6 ----- .../org/apache/spark/sql/SparkSession.scala | 2 +- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index cfde2dd1837b..66e251272a0f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, SQLContext} import 
org.apache.spark.util.Utils /** @@ -43,9 +43,18 @@ private[util] sealed trait BaseReadWrite { private var optionSparkSession: Option[SparkSession] = None /** - * Sets the Spark Session to use for saving/loading. + * Sets the Spark SQLContext to use for saving/loading. */ @Since("1.6.0") + def context(sqlContext: SQLContext): this.type = { + optionSparkSession = Option(sqlContext.sparkSession) + this + } + + /** + * Sets the Spark Session to use for saving/loading. + */ + @Since("2.0.0") def context(sparkSession: SparkSession): this.type = { optionSparkSession = Option(sparkSession) this @@ -61,6 +70,11 @@ private[util] sealed trait BaseReadWrite { optionSparkSession.get } + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = sparkSession.sqlContext + /** Returns the underlying [[SparkContext]]. */ protected final def sc: SparkContext = sparkSession.sparkContext } @@ -115,6 +129,9 @@ abstract class MLWriter extends BaseReadWrite with Logging { // override for Java compatibility override def context(sparkSession: SparkSession): this.type = super.context(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) } /** @@ -159,6 +176,9 @@ abstract class MLReader[T] extends BaseReadWrite { // override for Java compatibility override def context(sparkSession: SparkSession): this.type = super.context(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 574c23225bf8..9d0d9b1be077 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -778,12 +778,6 @@ object MimaExcludes { ) ++ Seq( ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asBreeze"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asBreeze") - ) ++ Seq( - // [SPARK-15644] Replace SQLContext by SparkSession in MLLIB - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.sqlContext"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.util.MLWriter.context"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.sqlContext"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.util.MLReader.context") ) case v if v.startsWith("1.6") => Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 20e22baa351a..d117dc3c26e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -109,7 +109,7 @@ class SparkSession private( * A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility. */ @transient - private[sql] val sqlContext: SQLContext = new SQLContext(this) + private[spark] val sqlContext: SQLContext = new SQLContext(this) /** * Runtime configuration interface for Spark. 
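
Editor's note: the net effect of [PATCH 6/9] is easier to see with the trait collapsed into one place — both setters stay on the API, only a SparkSession is stored, and the SQLContext overload unwraps its session at the boundary, which is why the MiMa exclusions can be dropped again. A reduced, self-contained sketch of that shape (SessionHolderSketch is a hypothetical name; the real code is BaseReadWrite in ReadWrite.scala above):

import org.apache.spark.sql.{SparkSession, SQLContext}

// Reduced sketch of the compatibility pattern: keep the legacy SQLContext setter,
// but store only a SparkSession and unwrap the session at the boundary.
trait SessionHolderSketch {
  private var optionSparkSession: Option[SparkSession] = None

  // Legacy entry point, kept for source/binary compatibility.
  def context(sqlContext: SQLContext): this.type = {
    optionSparkSession = Option(sqlContext.sparkSession)
    this
  }

  // New entry point.
  def context(sparkSession: SparkSession): this.type = {
    optionSparkSession = Option(sparkSession)
    this
  }

  // Fall back to the default session when the caller did not set one.
  protected final def sparkSession: SparkSession = {
    if (optionSparkSession.isEmpty) {
      optionSparkSession = Some(SparkSession.builder().getOrCreate())
    }
    optionSparkSession.get
  }
}
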
From 42a1ed933b88645bee3523b198ab66afb5907a04 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 17 Jun 2016 11:31:42 -0700
Subject: [PATCH 7/9] name change

---
 .../src/main/scala/org/apache/spark/ml/util/ReadWrite.scala  | 6 +++---
 .../org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java  | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 66e251272a0f..8bb99579eb2c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -55,7 +55,7 @@ private[util] sealed trait BaseReadWrite {
    * Sets the Spark Session to use for saving/loading.
    */
   @Since("2.0.0")
-  def context(sparkSession: SparkSession): this.type = {
+  def session(sparkSession: SparkSession): this.type = {
     optionSparkSession = Option(sparkSession)
     this
   }
@@ -128,7 +128,7 @@ abstract class MLWriter extends BaseReadWrite with Logging {
   }

   // override for Java compatibility
-  override def context(sparkSession: SparkSession): this.type = super.context(sparkSession)
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

   // override for Java compatibility
   override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
@@ -175,7 +175,7 @@ abstract class MLReader[T] extends BaseReadWrite {
   def load(path: String): T

   // override for Java compatibility
-  override def context(sparkSession: SparkSession): this.type = super.context(sparkSession)
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

   // override for Java compatibility
   override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index 5598ff3f10ec..e4f678fef1d1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -56,7 +56,7 @@ public void testDefaultReadWrite() throws IOException {
     } catch (IOException e) {
       // expected
     }
-    instance.write().context(spark).overwrite().save(outputPath);
+    instance.write().session(spark).overwrite().save(outputPath);
     MyParams newInstance = MyParams.load(outputPath);
     Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
     Assert.assertEquals("Params should be preserved.",

From 0f316f9c3e8459a4356e8371afb5fb0d15216788 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 17 Jun 2016 11:38:33 -0700
Subject: [PATCH 8/9] add deprecate

---
 mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 8bb99579eb2c..ce9127a8d341 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,6 +46,7 @@ private[util] sealed trait BaseReadWrite {
    * Sets the Spark SQLContext to use for saving/loading.
    */
   @Since("1.6.0")
+  @deprecated("Use session instead", "2.0.0")
   def context(sqlContext: SQLContext): this.type = {
     optionSparkSession = Option(sqlContext.sparkSession)
     this

From 65534a04bd4fc68347ee9aff71d4de186e9656a0 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 21 Jun 2016 19:23:16 -0700
Subject: [PATCH 9/9] address comments

---
 mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 1541d2e99cc3..1582a73ea047 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -134,7 +134,7 @@ abstract class MLWriter extends BaseReadWrite with Logging {
   override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

   // override for Java compatibility
-  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
 }

 /**
@@ -199,7 +199,7 @@ abstract class MLReader[T] extends BaseReadWrite {
   override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

   // override for Java compatibility
-  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
 }

 /**
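
Editor's note: after the last three commits the public surface settles — session(SparkSession) is the supported setter, while context(SQLContext) remains only as a deprecated shim that delegates to session(). A hedged usage sketch against a released Spark 2.x build (Binarizer and the /tmp path are placeholders; any MLWritable stage would do):

import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.sql.SparkSession

object ReadWriteUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("rw-sketch").getOrCreate()
    val path = "/tmp/binarizer-demo"  // placeholder output path
    val binarizer = new Binarizer().setInputCol("raw").setOutputCol("flag").setThreshold(0.5)

    // Preferred after this series: hand the writer a SparkSession.
    binarizer.write.session(spark).overwrite().save(path)

    // Still compiles, but deprecated: the SQLContext overload just delegates to session().
    binarizer.write.context(spark.sqlContext).overwrite().save(path)

    // Readers accept the same setter.
    val restored = Binarizer.read.session(spark).load(path)
    println(restored.getThreshold)
    spark.stop()
  }
}
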