Skip to content

Commit 74db85e

Browse files
author
Joseph Batchik
committed
reformatted class loader
1 parent ac2270d commit 74db85e

File tree

2 files changed

+42
-41
lines changed

2 files changed

+42
-41
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -196,34 +196,31 @@ private[sql] class DDLParser(
196196

197197
private[sql] object ResolvedDataSource extends Logging {
198198

199+
private lazy val loader = Utils.getContextOrSparkClassLoader
200+
private lazy val serviceLoader = ServiceLoader.load(classOf[DataSourceProvider], loader)
201+
202+
/** Tries to load the particular class */
203+
private def tryLoad(provider: String): Option[Class[_]] = try {
204+
Some(loader.loadClass(provider))
205+
} catch {
206+
case cnf: ClassNotFoundException => if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
207+
sys.error("The ORC data source must be used with Hive support enabled.")
208+
} else {
209+
None
210+
}
211+
}
212+
199213
/** Given a provider name, look up the data source class definition. */
200214
def lookupDataSource(provider: String): Class[_] = {
201-
val loader = Utils.getContextOrSparkClassLoader
202-
val sl = ServiceLoader.load(classOf[DataSourceProvider], loader)
203-
204-
sl.iterator().filter(_.format() == provider).toList match {
205-
case Nil => logDebug(s"provider: $provider is not registered in the service loader")
206-
case head :: Nil => return head.getClass
215+
serviceLoader.iterator().filter(_.format() == provider).toList match {
216+
case Nil => tryLoad(provider).orElse(tryLoad(s"$provider.DefaultSource")).getOrElse {
217+
sys.error(s"Failed to load class for data source: $provider")
218+
}
219+
case head :: Nil => head.getClass
207220
case sources => sys.error(s"Multiple sources found for $provider, " +
208-
s"(${sources.map(_.getClass.getName).mkString(", ")}, " +
221+
s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
209222
"please specify the fully qualified class name")
210223
}
211-
212-
try {
213-
loader.loadClass(provider)
214-
} catch {
215-
case cnf: java.lang.ClassNotFoundException =>
216-
try {
217-
loader.loadClass(provider + ".DefaultSource")
218-
} catch {
219-
case cnf: java.lang.ClassNotFoundException =>
220-
if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
221-
sys.error("The ORC data source must be used with Hive support enabled.")
222-
} else {
223-
sys.error(s"Failed to load class for data source: $provider")
224-
}
225-
}
226-
}
227224
}
228225

229226
/** Create a [[ResolvedDataSource]] for reading data in. */

sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,41 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}
2222

2323
class FakeSourceOne extends RelationProvider {
2424

25-
override def format() = "Fluet da Bomb"
25+
override def format(): String = "Fluet da Bomb"
2626

27-
override def createRelation(cont: SQLContext, param: Map[String, String]) = new BaseRelation {
28-
override def sqlContext: SQLContext = cont
27+
override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
28+
new BaseRelation {
29+
override def sqlContext: SQLContext = cont
2930

30-
override def schema: StructType =
31-
StructType(Seq(StructField("stringType", StringType, nullable = false)))
32-
}
31+
override def schema: StructType =
32+
StructType(Seq(StructField("stringType", StringType, nullable = false)))
33+
}
3334
}
3435

3536
class FakeSourceTwo extends RelationProvider {
3637

37-
override def format() = "Fluet da Bomb"
38+
override def format(): String = "Fluet da Bomb"
3839

39-
override def createRelation(cont: SQLContext, param: Map[String, String]) = new BaseRelation {
40-
override def sqlContext: SQLContext = cont
40+
override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
41+
new BaseRelation {
42+
override def sqlContext: SQLContext = cont
4143

42-
override def schema: StructType =
43-
StructType(Seq(StructField("stringType", StringType, nullable = false)))
44-
}
44+
override def schema: StructType =
45+
StructType(Seq(StructField("stringType", StringType, nullable = false)))
46+
}
4547
}
4648

4749
class FakeSourceThree extends RelationProvider {
48-
override def format() = "gathering quorum"
4950

50-
override def createRelation(cont: SQLContext, param: Map[String, String]) = new BaseRelation {
51-
override def sqlContext: SQLContext = cont
51+
override def format(): String = "gathering quorum"
5252

53-
override def schema: StructType =
54-
StructType(Seq(StructField("stringType", StringType, nullable = false)))
55-
}
53+
override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
54+
new BaseRelation {
55+
override def sqlContext: SQLContext = cont
56+
57+
override def schema: StructType =
58+
StructType(Seq(StructField("stringType", StringType, nullable = false)))
59+
}
5660
}
5761
// please note that the META-INF/services had to be modified for the test directory for this to work
5862
class DDLSourceLoadSuite extends DataSourceTest {

0 commit comments

Comments
 (0)