Commit 459fd34

dongjoon-hyun authored and Andrew Or committed
[SPARK-15618][SQL][MLLIB] Use SparkSession.builder.sparkContext if applicable.
This PR changes the function `SparkSession.builder.sparkContext(..)` from **private[sql]** to **private[spark]** and uses it where applicable, as in the following:

```
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
```

Passes the existing Jenkins tests.

Author: Dongjoon Hyun <[email protected]>

Closes #13365 from dongjoon-hyun/SPARK-15618.

(cherry picked from commit 85d6b0d)
Signed-off-by: Andrew Or <[email protected]>
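As a rough sketch (not part of the diff) of the pattern this commit standardizes on: `sparkContext(..)` is `private[spark]`, so it is only callable from code compiled under the `org.apache.spark` package. The package, object, and method names below are hypothetical and only illustrate the before/after shape of the change.

```scala
// Hypothetical illustration; placed under org.apache.spark so the
// private[spark] builder method is visible from this file.
package org.apache.spark.examples

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

object SessionFromContextSketch {
  /** Builds a SparkSession on top of an already-running SparkContext. */
  def makeSession(sc: SparkContext): SparkSession = {
    // Old pattern replaced by this PR: copy the context's configuration into the builder.
    //   SparkSession.builder().config(sc.getConf).getOrCreate()
    // New pattern: hand the existing SparkContext to the builder directly.
    SparkSession.builder().sparkContext(sc).getOrCreate()
  }
}
```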
1 parent ac4cb17 · commit 459fd34

File tree

28 files changed: +78 -89 lines


examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala

Lines changed: 2 additions & 5 deletions
@@ -29,13 +29,10 @@ object BroadcastTest {
 
     val blockSize = if (args.length > 2) args(2) else "4096"
 
-    val sparkConf = new SparkConf()
-      .set("spark.broadcast.blockSize", blockSize)
-
     val spark = SparkSession
-      .builder
-      .config(sparkConf)
+      .builder()
       .appName("Broadcast Test")
+      .config("spark.broadcast.blockSize", blockSize)
       .getOrCreate()
 
     val sc = spark.sparkContext

examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ object LDAExample {
 
     val spark = SparkSession
       .builder
+      .sparkContext(sc)
       .getOrCreate()
     import spark.implicits._
 
examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala

Lines changed: 1 addition & 4 deletions
@@ -22,7 +22,6 @@ import java.io.File
 
 import com.google.common.io.{ByteStreams, Files}
 
-import org.apache.spark.SparkConf
 import org.apache.spark.sql._
 
 object HiveFromSpark {
@@ -35,8 +34,6 @@ object HiveFromSpark {
   ByteStreams.copy(kv1Stream, Files.newOutputStreamSupplier(kv1File))
 
   def main(args: Array[String]) {
-    val sparkConf = new SparkConf().setAppName("HiveFromSpark")
-
     // When working with Hive, one must instantiate `SparkSession` with Hive support, including
     // connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined
     // functions. Users who do not have an existing Hive deployment can still enable Hive support.
@@ -45,7 +42,7 @@ object HiveFromSpark {
     // which defaults to the directory `spark-warehouse` in the current directory that the spark
     // application is started.
     val spark = SparkSession.builder
-      .config(sparkConf)
+      .appName("HiveFromSpark")
       .enableHiveSupport()
       .getOrCreate()
     val sc = spark.sparkContext

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 3 additions & 3 deletions
@@ -1177,7 +1177,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     // We use DataFrames for serialization of IndexedRows to Python,
     // so return a DataFrame.
     val sc = indexedRowMatrix.rows.sparkContext
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     spark.createDataFrame(indexedRowMatrix.rows)
   }
 
@@ -1188,7 +1188,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     // We use DataFrames for serialization of MatrixEntry entries to
     // Python, so return a DataFrame.
     val sc = coordinateMatrix.entries.sparkContext
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     spark.createDataFrame(coordinateMatrix.entries)
   }
 
@@ -1199,7 +1199,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     // We use DataFrames for serialization of sub-matrix blocks to
     // Python, so return a DataFrame.
     val sc = blockMatrix.blocks.sparkContext
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     spark.createDataFrame(blockMatrix.blocks)
   }
 }

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
@@ -437,7 +437,7 @@ class LogisticRegressionWithLBFGS
         lr.setMaxIter(optimizer.getNumIterations())
         lr.setTol(optimizer.getConvergenceTol())
         // Convert our input into a DataFrame
-        val spark = SparkSession.builder().config(input.context.getConf).getOrCreate()
+        val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
         val df = spark.createDataFrame(input.map(_.asML))
         // Determine if we should cache the DF
         val handlePersistence = input.getStorageLevel == StorageLevel.NONE

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 4 additions & 4 deletions
@@ -193,7 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         modelType: String)
 
     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render(
@@ -207,7 +207,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
 
     @Since("1.3.0")
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       // Load Parquet data.
       val dataRDD = spark.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
@@ -238,7 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         theta: Array[Array[Double]])
 
     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render(
@@ -251,7 +251,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     }
 
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       // Load Parquet data.
       val dataRDD = spark.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.

mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
         weights: Vector,
         intercept: Double,
         threshold: Option[Double]): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render(
@@ -73,7 +73,7 @@ private[classification] object GLMClassificationModel {
      */
     def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
       val dataPath = Loader.dataPath(path)
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataRDD = spark.read.parquet(dataPath)
       val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
       assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")

mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
     val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
 
     def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
           ~ ("rootId" -> model.root.index)))
@@ -165,7 +165,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
     }
 
     def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val rows = spark.read.parquet(Loader.dataPath(path))
       Loader.checkSchema[Data](rows.schema)
       val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 2 additions & 2 deletions
@@ -143,7 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
         path: String,
         weights: Array[Double],
         gaussians: Array[MultivariateGaussian]): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render
@@ -159,7 +159,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
 
     def load(sc: SparkContext, path: String): GaussianMixtureModel = {
       val dataPath = Loader.dataPath(path)
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataFrame = spark.read.parquet(dataPath)
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       Loader.checkSchema[Data](dataFrame.schema)

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 2 additions & 2 deletions
@@ -123,7 +123,7 @@ object KMeansModel extends Loader[KMeansModel] {
     val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
 
     def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
@@ -135,7 +135,7 @@ object KMeansModel extends Loader[KMeansModel] {
 
     def load(sc: SparkContext, path: String): KMeansModel = {
       implicit val formats = DefaultFormats
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
