
Commit 946186e
Author: Joseph Batchik (committed)
started working on service loader
1 parent 103d8cc commit 946186e

File tree: 13 files changed (+81, −14 lines)

Lines changed: 3 additions & 0 deletions (new file)
@@ -0,0 +1,3 @@
+org.apache.spark.sql.json.DefaultSource
+org.apache.spark.sql.parquet.DefaultSource
+org.apache.spark.sql.jdbc.DefaultSource
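The new resource file above lists one concrete class per line, which is the java.util.ServiceLoader registration format (the file's path is not shown in this diff; by convention it would sit at META-INF/services/<fully qualified trait name> — an assumption, since the path is omitted). A minimal sketch of the discovery mechanism, using a stand-in trait rather than the real Spark one:

import java.util.ServiceLoader

// Stand-in for org.apache.spark.sql.sources.DataSourceProvider introduced below.
trait DataSourceProvider {
  def format(): String = getClass.getName
}

object ServiceLoaderDemo {
  def main(args: Array[String]): Unit = {
    // ServiceLoader finds every META-INF/services file on the classpath named
    // after the trait and instantiates each listed class via its no-arg constructor.
    val itr = ServiceLoader.load(classOf[DataSourceProvider]).iterator()
    while (itr.hasNext) {
      println(itr.next().format())
    }
  }
}

With no registrations on the classpath this prints nothing; with the three DefaultSource classes above registered, it would print each one's format() alias.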

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala

Lines changed: 13 additions & 10 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import java.util.ServiceLoader
+
 import scala.language.{existentials, implicitConversions}
 import scala.util.matching.Regex
 
@@ -190,23 +192,24 @@ private[sql] class DDLParser(
   }
 }
 
-private[sql] object ResolvedDataSource {
-
-  private val builtinSources = Map(
-    "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource",
-    "json" -> "org.apache.spark.sql.json.DefaultSource",
-    "parquet" -> "org.apache.spark.sql.parquet.DefaultSource",
-    "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource"
-  )
+private[sql] object ResolvedDataSource extends Logging {
 
   /** Given a provider name, look up the data source class definition. */
   def lookupDataSource(provider: String): Class[_] = {
     val loader = Utils.getContextOrSparkClassLoader
 
-    if (builtinSources.contains(provider)) {
-      return loader.loadClass(builtinSources(provider))
+    val sl = ServiceLoader.load(classOf[DataSourceProvider], loader)
+
+    val itr = sl.iterator()
+    while (itr.hasNext) {
+      val service = itr.next()
+      if (service.format() == provider) {
+        return service.getClass
+      }
     }
 
+    logInfo(s"could not find registered data source for $provider")
+
     try {
       loader.loadClass(provider)
     } catch {
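The replaced code swaps the hard-coded builtinSources map for service discovery: registered providers are scanned first, and only if no alias matches does the lookup fall back to loading the provider string as a class name. A condensed, runnable sketch of that flow (the trait and class-loader choice are simplified stand-ins for the Spark internals):

import java.util.ServiceLoader

trait DataSourceProvider { def format(): String = getClass.getName }

object LookupSketch {
  def lookupDataSource(provider: String): Class[_] = {
    val loader = Thread.currentThread().getContextClassLoader
    val itr = ServiceLoader.load(classOf[DataSourceProvider], loader).iterator()
    // First pass: any registered provider whose alias matches the request wins.
    while (itr.hasNext) {
      val service = itr.next()
      if (service.format() == provider) {
        return service.getClass
      }
    }
    // Fallback: treat the provider string as a fully qualified class name;
    // this throws ClassNotFoundException when nothing matches.
    loader.loadClass(provider)
  }
}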

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ private[sql] object JDBCRelation {
 }
 
 private[sql] class DefaultSource extends RelationProvider {
+
+  override def format(): String = "jdbc"
+
   /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,

sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ private[sql] class DefaultSource
     with SchemaRelationProvider
     with CreatableRelationProvider {
 
+  override def format(): String = "json"
+
   private def checkPath(parameters: Map[String, String]): String = {
     parameters.getOrElse("path", sys.error("'path' must be specified for json data."))
   }

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 
 private[sql] class DefaultSource extends HadoopFsRelationProvider {
+
+  override def format(): String = "parquet"
+
   override def createRelation(
       sqlContext: SQLContext,
       paths: Array[String],

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 25 additions & 4 deletions
@@ -37,6 +37,27 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql._
 import org.apache.spark.util.SerializableConfiguration
 
+/**
+ * ::DeveloperApi::
+ * Base trait for all relation providers. All relation providers need to provide a string
+ * representing the format they load. ex: parquet.DefaultSource.format = "parquet".
+ * This allows users to specify that string as the format to read / write instead of providing
+ * the entire class name.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ */
+@DeveloperApi
+trait DataSourceProvider {
+
+  /**
+   * The string that represents the format that this data source provider uses. By default, it is
+   * the name of the class, ex: "org.apache.spark.sql.parquet.DefaultSource". This should be
+   * overridden by children to provide a nice alias for the data source,
+   * ex: override def format(): String = "parquet"
+   */
+  def format(): String = getClass.getName
+}
+
 /**
  * ::DeveloperApi::
  * Implemented by objects that produce relations for a specific kind of data source. When
@@ -53,7 +74,7 @@ import org.apache.spark.util.SerializableConfiguration
  * @since 1.3.0
  */
 @DeveloperApi
-trait RelationProvider {
+trait RelationProvider extends DataSourceProvider {
   /**
    * Returns a new base relation with the given parameters.
    * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
@@ -84,7 +105,7 @@ trait RelationProvider {
  * @since 1.3.0
  */
 @DeveloperApi
-trait SchemaRelationProvider {
+trait SchemaRelationProvider extends DataSourceProvider {
   /**
    * Returns a new base relation with the given parameters and user defined schema.
    * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
@@ -120,7 +141,7 @@ trait SchemaRelationProvider {
  * @since 1.4.0
  */
 @Experimental
-trait HadoopFsRelationProvider {
+trait HadoopFsRelationProvider extends DataSourceProvider {
   /**
    * Returns a new base relation with the given parameters, a user defined schema, and a list of
    * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity
@@ -140,7 +161,7 @@ trait HadoopFsRelationProvider {
  * @since 1.3.0
  */
 @DeveloperApi
-trait CreatableRelationProvider {
+trait CreatableRelationProvider extends DataSourceProvider {
   /**
    * Creates a relation with the given parameters based on the contents of the given
    * DataFrame. The mode specifies the expected behavior of createRelation when
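Because every provider trait now extends DataSourceProvider, a third-party data source can choose its own alias and be registered the same way as the built-ins. A sketch of what that might look like (the class name, alias, and schema are invented for illustration):

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical third-party source; listing this class in a META-INF/services
// registration file would let users write .format("mycsv") instead of the
// fully qualified class name.
class MyCsvSource extends RelationProvider {

  override def format(): String = "mycsv"

  override def createRelation(
      ctx: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // A trivial single-column relation, just to make the sketch compile.
    new BaseRelation {
      override val sqlContext: SQLContext = ctx
      override val schema: StructType =
        StructType(StructField("value", StringType) :: Nil)
    }
  }
}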

sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ import org.apache.spark.sql.types._
 
 
 class FilteredScanSource extends RelationProvider {
+
+  override def format(): String = "test format"
+
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {

sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.types._
 
 class PrunedScanSource extends RelationProvider {
+
+  override def format(): String = "test format"
+
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {

sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,9 @@ import org.apache.spark.unsafe.types.UTF8String
 class DefaultSource extends SimpleScanSource
 
 class SimpleScanSource extends RelationProvider {
+
+  override def format(): String = "test format"
+
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
@@ -46,6 +49,9 @@ case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
 }
 
 class AllDataTypesScanSource extends SchemaRelationProvider {
+
+  override def format(): String = "test format"
+
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
Lines changed: 1 addition & 0 deletions (new file)
@@ -0,0 +1 @@
+org.apache.spark.sql.hive.orc.DefaultSource
