Skip to content

Commit 49a01ec

Browse files
committed
Fixed a couple of format/error bugs
1 parent e5e93b2 commit 49a01ec

File tree

3 files changed

+23
-15
lines changed

3 files changed

+23
-15
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.ServiceLoader
2222
import scala.collection.Iterator
2323
import scala.collection.JavaConversions._
2424
import scala.language.{existentials, implicitConversions}
25+
import scala.util.{Failure, Success, Try}
2526
import scala.util.matching.Regex
2627

2728
import org.apache.hadoop.fs.Path
@@ -196,26 +197,27 @@ private[sql] class DDLParser(
196197

197198
private[sql] object ResolvedDataSource extends Logging {
198199

199-
/** Tries to load the particular class */
200-
private def tryLoad(loader: ClassLoader, provider: String): Option[Class[_]] = try {
201-
Some(loader.loadClass(provider))
202-
} catch {
203-
case cnf: ClassNotFoundException => None
204-
}
205-
206200
/** Given a provider name, look up the data source class definition. */
207201
def lookupDataSource(provider: String): Class[_] = {
202+
val provider2 = s"$provider.DefaultSource"
208203
val loader = Utils.getContextOrSparkClassLoader
209204
val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
210205

211-
serviceLoader.iterator().filter(_.format() == provider).toList match {
212-
case Nil => tryLoad(loader, provider).orElse(tryLoad(loader, s"$provider.DefaultSource"))
213-
.getOrElse(if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
214-
sys.error("The ORC data source must be used with Hive support enabled.")
215-
} else {
216-
sys.error(s"Failed to load class for data source: $provider")
217-
})
206+
serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match {
207+
/** the provider format did not match any given registered aliases */
208+
case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match {
209+
case Success(dataSource) => dataSource
210+
case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
211+
throw new ClassNotFoundException(
212+
"The ORC data source must be used with Hive support enabled.", error)
213+
} else {
214+
throw new ClassNotFoundException(
215+
s"Failed to load class for data source: $provider", error)
216+
}
217+
}
218+
/** there is exactly one registered alias */
218219
case head :: Nil => head.getClass
220+
/** There are multiple registered aliases for the input */
219221
case sources => sys.error(s"Multiple sources found for $provider, " +
220222
s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
221223
"please specify the fully qualified class name")

sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,10 @@ class DDLSourceLoadSuite extends DataSourceTest {
7676
caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne")
7777
.load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
7878
}
79+
80+
test("Loading Orc") {
81+
intercept[ClassNotFoundException] {
82+
caseInsensitiveContext.read.format("orc").load()
83+
}
84+
}
7985
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
org.apache.spark.sql.hive.orc.DefaultSource
1+
org.apache.spark.sql.hive.orc.DefaultSource

0 commit comments

Comments
 (0)