
Commit 3e7e05f

jkbradley authored and mengxr committed
[SPARK-12160][MLLIB] Use SQLContext.getOrCreate in MLlib
Switched from using the SQLContext constructor to SQLContext.getOrCreate, mainly in model save/load methods. This covers all instances in spark.mllib; there were no uses of the constructor in spark.ml.

CC: mengxr yhuai

Author: Joseph K. Bradley <[email protected]>

Closes #10161 from jkbradley/mllib-sqlcontext-fix.
1 parent 36282f7 commit 3e7e05f
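
For context on the change itself: new SQLContext(sc) constructs a fresh context on every call, while SQLContext.getOrCreate(sc) returns the singleton SQLContext associated with the given SparkContext, creating it only if none exists yet. A minimal sketch of the difference (not part of this patch; the object name and local master are illustrative assumptions):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // Hypothetical demo object, not from the patch.
    object GetOrCreateSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("getOrCreate-sketch").setMaster("local[*]"))

        // Before this patch: each save/load call built its own SQLContext.
        val fresh1 = new SQLContext(sc)
        val fresh2 = new SQLContext(sc)
        assert(fresh1 ne fresh2) // two distinct contexts over one SparkContext

        // After this patch: getOrCreate reuses a shared instance, so the
        // caller's configuration and registered state are not bypassed.
        val shared1 = SQLContext.getOrCreate(sc)
        val shared2 = SQLContext.getOrCreate(sc)
        assert(shared1 eq shared2) // same instance on both calls

        sc.stop()
      }
    }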

File tree

13 files changed (+29, -29 lines)


mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 3 additions & 3 deletions
@@ -1191,7 +1191,7 @@ private[python] class PythonMLLibAPI extends Serializable {
   def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = {
     // We use DataFrames for serialization of IndexedRows to Python,
     // so return a DataFrame.
-    val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext)
     sqlContext.createDataFrame(indexedRowMatrix.rows)
   }

@@ -1201,7 +1201,7 @@ private[python] class PythonMLLibAPI extends Serializable {
   def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = {
     // We use DataFrames for serialization of MatrixEntry entries to
     // Python, so return a DataFrame.
-    val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext)
     sqlContext.createDataFrame(coordinateMatrix.entries)
   }

@@ -1211,7 +1211,7 @@ private[python] class PythonMLLibAPI extends Serializable {
   def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = {
     // We use DataFrames for serialization of sub-matrix blocks to
     // Python, so return a DataFrame.
-    val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext)
     sqlContext.createDataFrame(blockMatrix.blocks)
   }
 }

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 4 additions & 4 deletions
@@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         modelType: String)

     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.

@@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {

     @Since("1.3.0")
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       // Load Parquet data.
       val dataRDD = sqlContext.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.

@@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         theta: Array[Array[Double]])

     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.

@@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     }

     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       // Load Parquet data.
       val dataRDD = sqlContext.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.

mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
         weights: Vector,
         intercept: Double,
         threshold: Option[Double]): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.

@@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel {
      */
     def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
       val datapath = Loader.dataPath(path)
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataRDD = sqlContext.read.parquet(datapath)
       val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
       assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 2 additions & 2 deletions
@@ -145,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
         weights: Array[Double],
         gaussians: Array[MultivariateGaussian]): Unit = {

-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       // Create JSON metadata.

@@ -162,7 +162,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {

     def load(sc: SparkContext, path: String): GaussianMixtureModel = {
       val dataPath = Loader.dataPath(path)
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       Loader.checkSchema[Data](dataFrame.schema)

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 2 additions & 2 deletions
@@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] {
     val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"

     def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))

@@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] {

     def load(sc: SparkContext, path: String): KMeansModel = {
       implicit val formats = DefaultFormats
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)

mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {

     @Since("1.4.0")
     def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       val metadata = compact(render(

@@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
     @Since("1.4.0")
     def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
       implicit val formats = DefaultFormats
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)

       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)

mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala

Lines changed: 2 additions & 2 deletions
@@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
     val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"

     def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))

@@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {

     def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
       implicit val formats = DefaultFormats
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 2 additions & 2 deletions
@@ -587,7 +587,7 @@ object Word2VecModel extends Loader[Word2VecModel] {

     def load(sc: SparkContext, path: String): Word2VecModel = {
       val dataPath = Loader.dataPath(path)
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       Loader.checkSchema[Data](dataFrame.schema)

@@ -599,7 +599,7 @@ object Word2VecModel extends Loader[Word2VecModel] {

     def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {

-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._

       val vectorSize = model.values.head.size

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 2 additions & 2 deletions
@@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
      */
     def save(model: MatrixFactorizationModel, path: String): Unit = {
       val sc = model.userFeatures.sparkContext
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))

@@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {

     def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
       implicit val formats = DefaultFormats
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val (className, formatVersion, metadata) = loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 2 additions & 2 deletions
@@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
         boundaries: Array[Double],
         predictions: Array[Double],
         isotonic: Boolean): Unit = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)

       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~

@@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
     }

     def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
-      val sqlContext = new SQLContext(sc)
+      val sqlContext = SQLContext.getOrCreate(sc)
       val dataRDD = sqlContext.read.parquet(dataPath(path))

       checkSchema[Data](dataRDD.schema)
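
From the caller's side nothing changes; the patch only affects which SQLContext the save/load internals obtain. As a usage-level sketch, one of the patched paths exercised end to end (the local master and scratch path are illustrative assumptions, not from the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
    import org.apache.spark.mllib.linalg.Vectors

    object SaveLoadRoundTrip {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("round-trip").setMaster("local[*]"))
        val points = sc.parallelize(Seq(
          Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
          Vectors.dense(9.0, 9.0), Vectors.dense(10.0, 10.0)))

        // k = 2 clusters, at most 10 iterations.
        val model = KMeans.train(points, 2, 10)

        // Both calls below now reach SQLContext.getOrCreate internally and
        // therefore reuse the application's existing SQLContext, if any.
        val path = "/tmp/kmeans-model" // hypothetical scratch path
        model.save(sc, path)
        val restored = KMeansModel.load(sc, path)
        assert(restored.k == model.k)

        sc.stop()
      }
    }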
