
Commit b3eec71

Joseph Batchik authored and CodingCat committed
[SPARK-9486][SQL] Add data source aliasing for external packages
Users currently have to provide the full class name for external data sources, like:

`sqlContext.read.format("com.databricks.spark.avro").load(path)`

This change allows external data source packages to register themselves using a ServiceLoader, so that they can add a custom alias like:

`sqlContext.read.format("avro").load(path)`

External data source packages can then be used with the same format names as the internal data sources, like parquet, json, etc.

Author: Joseph Batchik <[email protected]>
Author: Joseph Batchik <[email protected]>

Closes apache#7802 from JDrit/service_loader and squashes the following commits:

49a01ec [Joseph Batchik] fixed a couple of format / error bugs
e5e93b2 [Joseph Batchik] modified rat file to only exclude added services
72b349a [Joseph Batchik] fixed error with orc data source actually
9f93ea7 [Joseph Batchik] fixed error with orc data source
87b7f1c [Joseph Batchik] fixed typo
101cd22 [Joseph Batchik] removing unneeded changes
8f3cf43 [Joseph Batchik] merged in changes
b63d337 [Joseph Batchik] merged in master
95ae030 [Joseph Batchik] changed the new trait to be used as a mixin for data sources to register themselves
74db85e [Joseph Batchik] reformatted class loader
ac2270d [Joseph Batchik] removing some added test
a6926db [Joseph Batchik] added test cases for data source loader
208a2a8 [Joseph Batchik] changes to do error catching if there are multiple data sources
946186e [Joseph Batchik] started working on service loader
1 parent ee2d3ac commit b3eec71
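In practice, an external package registers an alias by mixing the new `DataSourceRegister` trait into its provider class. A minimal sketch, assuming a hypothetical `com.example.avro` package (the trait and the lookup mechanism are from this commit; the package name, class, and relation body are illustrative):

```scala
package com.example.avro

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

// Hypothetical external data source: registers the short alias "avro" so
// users can write sqlContext.read.format("avro") instead of the class name.
class DefaultSource extends RelationProvider with DataSourceRegister {

  def format(): String = "avro"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Build the relation for this (hypothetical) format; elided here.
    ???
  }
}
```

The alias only becomes visible once the class is listed in a `META-INF/services/org.apache.spark.sql.sources.DataSourceRegister` file on the classpath, as the diffs below do for the built-in sources.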

11 files changed (+156, −30)

.rat-excludes

Lines changed: 1 addition & 0 deletions

@@ -93,3 +93,4 @@ INDEX
 .lintr
 gen-java.*
 .*avpr
+org.apache.spark.sql.sources.DataSourceRegister
sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+org.apache.spark.sql.jdbc.DefaultSource
+org.apache.spark.sql.json.DefaultSource
+org.apache.spark.sql.parquet.DefaultSource
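An external package ships its own copy of this services file; for the hypothetical `com.example.avro` source sketched above it would contain the single line:

```
com.example.avro.DefaultSource
```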

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala

Lines changed: 26 additions & 26 deletions

@@ -17,7 +17,12 @@

 package org.apache.spark.sql.execution.datasources

+import java.util.ServiceLoader
+
+import scala.collection.Iterator
+import scala.collection.JavaConversions._
 import scala.language.{existentials, implicitConversions}
+import scala.util.{Failure, Success, Try}
 import scala.util.matching.Regex

 import org.apache.hadoop.fs.Path

@@ -190,37 +195,32 @@ private[sql] class DDLParser(
   }
 }

-private[sql] object ResolvedDataSource {
-
-  private val builtinSources = Map(
-    "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource",
-    "json" -> "org.apache.spark.sql.json.DefaultSource",
-    "parquet" -> "org.apache.spark.sql.parquet.DefaultSource",
-    "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource"
-  )
+private[sql] object ResolvedDataSource extends Logging {

   /** Given a provider name, look up the data source class definition. */
   def lookupDataSource(provider: String): Class[_] = {
+    val provider2 = s"$provider.DefaultSource"
     val loader = Utils.getContextOrSparkClassLoader
-
-    if (builtinSources.contains(provider)) {
-      return loader.loadClass(builtinSources(provider))
-    }
-
-    try {
-      loader.loadClass(provider)
-    } catch {
-      case cnf: java.lang.ClassNotFoundException =>
-        try {
-          loader.loadClass(provider + ".DefaultSource")
-        } catch {
-          case cnf: java.lang.ClassNotFoundException =>
-            if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
-              sys.error("The ORC data source must be used with Hive support enabled.")
-            } else {
-              sys.error(s"Failed to load class for data source: $provider")
-            }
+    val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
+
+    serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match {
+      /** the provider format did not match any given registered aliases */
+      case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match {
+        case Success(dataSource) => dataSource
+        case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
+          throw new ClassNotFoundException(
+            "The ORC data source must be used with Hive support enabled.", error)
+        } else {
+          throw new ClassNotFoundException(
+            s"Failed to load class for data source: $provider", error)
         }
+      }
+      /** there is exactly one registered alias */
+      case head :: Nil => head.getClass
+      /** There are multiple registered aliases for the input */
+      case sources => sys.error(s"Multiple sources found for $provider, " +
+        s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
+        "please specify the fully qualified class name")
     }
   }
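The lookup order above is: registered alias first, then the provider string as a class name, then `provider.DefaultSource`, with ambiguous aliases failing fast. A minimal standalone sketch of the same `ServiceLoader` pattern, using a hypothetical `Format` trait in place of `DataSourceRegister` (and `JavaConverters` rather than the implicit `JavaConversions` used in the diff):

```scala
import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}

// Hypothetical service interface; implementations are discovered through a
// META-INF/services file named after the trait's fully qualified name.
trait Format {
  def name: String
}

object FormatResolver {
  def resolve(provider: String): Class[_] = {
    val loader = Thread.currentThread().getContextClassLoader
    val matches = ServiceLoader.load(classOf[Format], loader)
      .iterator().asScala.filter(_.name.equalsIgnoreCase(provider)).toList

    matches match {
      // No alias matched: fall back to treating the provider as a class name.
      case Nil =>
        Try(loader.loadClass(provider))
          .orElse(Try(loader.loadClass(s"$provider.DefaultSource"))) match {
          case Success(clazz) => clazz
          case Failure(error) =>
            throw new ClassNotFoundException(s"Failed to load class for: $provider", error)
        }
      // Exactly one alias matched: unambiguous winner.
      case single :: Nil => single.getClass
      // Several sources claim the same alias: force the user to disambiguate.
      case many => sys.error(
        s"Multiple sources found for $provider: ${many.map(_.getClass.getName).mkString(", ")}")
    }
  }
}
```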

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala

Lines changed: 4 additions & 1 deletion

@@ -77,7 +77,10 @@ private[sql] object JDBCRelation {
   }
 }

-private[sql] class DefaultSource extends RelationProvider {
+private[sql] class DefaultSource extends RelationProvider with DataSourceRegister {
+
+  def format(): String = "jdbc"
+
   /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,

sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala

Lines changed: 4 additions & 1 deletion

@@ -37,7 +37,10 @@ import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{AnalysisException, Row, SQLContext}

-private[sql] class DefaultSource extends HadoopFsRelationProvider {
+private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+
+  def format(): String = "json"
+
   override def createRelation(
       sqlContext: SQLContext,
       paths: Array[String],

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala

Lines changed: 4 additions & 1 deletion

@@ -49,7 +49,10 @@ import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}


-private[sql] class DefaultSource extends HadoopFsRelationProvider {
+private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+
+  def format(): String = "parquet"
+
   override def createRelation(
       sqlContext: SQLContext,
       paths: Array[String],
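With these mixins in place (plus the services entries above), the built-in short names resolve through the same path as any external alias. A usage sketch, assuming an existing `sqlContext` and an illustrative input path:

```scala
// Both calls resolve to org.apache.spark.sql.json.DefaultSource through
// lookupDataSource: the first via the registered alias, the second via the
// class-name fallback.
val byAlias = sqlContext.read.format("json").load("/tmp/events.json")
val byClass = sqlContext.read
  .format("org.apache.spark.sql.json.DefaultSource")
  .load("/tmp/events.json")
```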

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 21 additions & 0 deletions

@@ -37,6 +37,27 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql._
 import org.apache.spark.util.SerializableConfiguration

+/**
+ * ::DeveloperApi::
+ * Data sources should implement this trait so that they can register an alias for their data
+ * source. This allows users to give the data source alias as the format type instead of the
+ * fully qualified class name.
+ *
+ * ex: parquet.DefaultSource.format = "parquet".
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ */
+@DeveloperApi
+trait DataSourceRegister {
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source,
+   * ex: override def format(): String = "parquet"
+   */
+  def format(): String
+}
+
 /**
  * ::DeveloperApi::
  * Implemented by objects that produce relations for a specific kind of data source. When
sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+org.apache.spark.sql.sources.FakeSourceOne
+org.apache.spark.sql.sources.FakeSourceTwo
+org.apache.spark.sql.sources.FakeSourceThree
sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class FakeSourceOne extends RelationProvider with DataSourceRegister {
+
+  def format(): String = "Fluet da Bomb"
+
+  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+    new BaseRelation {
+      override def sqlContext: SQLContext = cont
+
+      override def schema: StructType =
+        StructType(Seq(StructField("stringType", StringType, nullable = false)))
+    }
+}
+
+class FakeSourceTwo extends RelationProvider with DataSourceRegister {
+
+  def format(): String = "Fluet da Bomb"
+
+  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+    new BaseRelation {
+      override def sqlContext: SQLContext = cont
+
+      override def schema: StructType =
+        StructType(Seq(StructField("stringType", StringType, nullable = false)))
+    }
+}
+
+class FakeSourceThree extends RelationProvider with DataSourceRegister {
+
+  def format(): String = "gathering quorum"
+
+  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+    new BaseRelation {
+      override def sqlContext: SQLContext = cont
+
+      override def schema: StructType =
+        StructType(Seq(StructField("stringType", StringType, nullable = false)))
+    }
+}
+// Note: META-INF/services in the test resources had to be updated for these fakes to be discovered.
+class DDLSourceLoadSuite extends DataSourceTest {
+
+  test("data sources with the same name") {
+    intercept[RuntimeException] {
+      caseInsensitiveContext.read.format("Fluet da Bomb").load()
+    }
+  }
+
+  test("load data source from format alias") {
+    assert(caseInsensitiveContext.read.format("gathering quorum").load().schema ==
+      StructType(Seq(StructField("stringType", StringType, nullable = false))))
+  }
+
+  test("specify full classname with duplicate formats") {
+    assert(caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne")
+      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))))
+  }
+
+  test("Loading Orc") {
+    intercept[ClassNotFoundException] {
+      caseInsensitiveContext.read.format("orc").load()
+    }
+  }
+}
sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+org.apache.spark.sql.hive.orc.DefaultSource
