Skip to content

Commit febdf69

Browse files
committed
Support an option in OracleDialect to control whether NUMBER(1) is converted to BooleanType
1 parent 24db358 commit febdf69

File tree

16 files changed

+184
-135
lines changed

16 files changed

+184
-135
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc
2020
import java.sql.{Connection, DriverManager}
2121
import java.util.{Locale, Properties}
2222

23+
import org.apache.spark.annotation.DeveloperApi
2324
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
2425

2526
/**
@@ -33,6 +34,14 @@ class JDBCOptions(
3334

3435
def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
3536

37+
@DeveloperApi
38+
def this(url: String) = {
39+
this(CaseInsensitiveMap(Map(
40+
JDBCOptions.JDBC_URL -> url,
41+
JDBCOptions.JDBC_DRIVER_CLASS -> "org.h2.Driver",
42+
JDBCOptions.JDBC_TABLE_NAME -> "")))
43+
}
44+
3645
def this(url: String, table: String, parameters: Map[String, String]) = {
3746
this(CaseInsensitiveMap(parameters ++ Map(
3847
JDBCOptions.JDBC_URL -> url,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ object JDBCRDD extends Logging {
5252
* @throws SQLException if the table contains an unsupported type.
5353
*/
5454
def resolveTable(options: JDBCOptions): StructType = {
55-
val url = options.url
5655
val table = options.table
57-
val dialect = JdbcDialects.get(url)
56+
val dialect = JdbcDialects.get(options)
5857
val conn: Connection = JdbcUtils.createConnectionFactory(options)()
5958
try {
6059
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
@@ -167,8 +166,7 @@ object JDBCRDD extends Logging {
167166
filters: Array[Filter],
168167
parts: Array[Partition],
169168
options: JDBCOptions): RDD[InternalRow] = {
170-
val url = options.url
171-
val dialect = JdbcDialects.get(url)
169+
val dialect = JdbcDialects.get(options)
172170
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
173171
new JDBCRDD(
174172
sc,
@@ -177,7 +175,7 @@ object JDBCRDD extends Logging {
177175
quotedColumns,
178176
filters,
179177
parts,
180-
url,
178+
options.url,
181179
options)
182180
}
183181
}
@@ -217,7 +215,7 @@ private[jdbc] class JDBCRDD(
217215
*/
218216
private val filterWhereClause: String =
219217
filters
220-
.flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
218+
.flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(options)))
221219
.map(p => s"($p)").mkString(" AND ")
222220

223221
/**
@@ -284,7 +282,7 @@ private[jdbc] class JDBCRDD(
284282
val inputMetrics = context.taskMetrics().inputMetrics
285283
val part = thePart.asInstanceOf[JDBCPartition]
286284
conn = getConnection()
287-
val dialect = JdbcDialects.get(url)
285+
val dialect = JdbcDialects.get(options)
288286
import scala.collection.JavaConverters._
289287
dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap)
290288

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ private[sql] case class JDBCRelation(
114114

115115
// Check if JDBCRDD.compileFilter can accept input filters
116116
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
117-
filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
117+
filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions)).isEmpty)
118118
}
119119

120120
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
6161
if (tableExists) {
6262
mode match {
6363
case SaveMode.Overwrite =>
64-
if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
64+
if (options.isTruncate && isCascadingTruncateTable(options) == Some(false)) {
6565
// In this case, we should truncate table and then load.
6666
truncateTable(conn, options.table)
6767
val tableSchema = JdbcUtils.getSchemaOption(conn, options)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ object JdbcUtils extends Logging {
6666
* Returns true if the table already exists in the JDBC database.
6767
*/
6868
def tableExists(conn: Connection, options: JDBCOptions): Boolean = {
69-
val dialect = JdbcDialects.get(options.url)
69+
val dialect = JdbcDialects.get(options)
7070

7171
// Somewhat hacky, but there isn't a good way to identify whether a table exists for all
7272
// SQL database systems using JDBC meta data calls, considering "table" could also include
@@ -105,8 +105,8 @@ object JdbcUtils extends Logging {
105105
}
106106
}
107107

108-
def isCascadingTruncateTable(url: String): Option[Boolean] = {
109-
JdbcDialects.get(url).isCascadingTruncateTable()
108+
def isCascadingTruncateTable(options: JDBCOptions): Option[Boolean] = {
109+
JdbcDialects.get(options).isCascadingTruncateTable()
110110
}
111111

112112
/**
@@ -247,7 +247,7 @@ object JdbcUtils extends Logging {
247247
* Returns the schema if the table already exists in the JDBC database.
248248
*/
249249
def getSchemaOption(conn: Connection, options: JDBCOptions): Option[StructType] = {
250-
val dialect = JdbcDialects.get(options.url)
250+
val dialect = JdbcDialects.get(options)
251251

252252
try {
253253
val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
@@ -702,10 +702,10 @@ object JdbcUtils extends Logging {
702702
*/
703703
def schemaString(
704704
df: DataFrame,
705-
url: String,
705+
options: JDBCOptions,
706706
createTableColumnTypes: Option[String] = None): String = {
707707
val sb = new StringBuilder()
708-
val dialect = JdbcDialects.get(url)
708+
val dialect = JdbcDialects.get(options)
709709
val userSpecifiedColTypesMap = createTableColumnTypes
710710
.map(parseUserSpecifiedCreateTableColumnTypes(df, _))
711711
.getOrElse(Map.empty[String, String])
@@ -772,9 +772,8 @@ object JdbcUtils extends Logging {
772772
tableSchema: Option[StructType],
773773
isCaseSensitive: Boolean,
774774
options: JDBCOptions): Unit = {
775-
val url = options.url
776775
val table = options.table
777-
val dialect = JdbcDialects.get(url)
776+
val dialect = JdbcDialects.get(options)
778777
val rddSchema = df.schema
779778
val getConnection: () => Connection = createConnectionFactory(options)
780779
val batchSize = options.batchSize
@@ -801,7 +800,7 @@ object JdbcUtils extends Logging {
801800
df: DataFrame,
802801
options: JDBCOptions): Unit = {
803802
val strSchema = schemaString(
804-
df, options.url, options.createTableColumnTypes)
803+
df, options, options.createTableColumnTypes)
805804
val table = options.table
806805
val createTableOptions = options.createTableOptions
807806
// Create the table if the table does not exist.

sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.jdbc
1919

20+
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
2021
import org.apache.spark.sql.types.{DataType, MetadataBuilder}
2122

2223
/**
@@ -30,8 +31,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
3031

3132
require(dialects.nonEmpty)
3233

33-
override def canHandle(url : String): Boolean =
34-
dialects.map(_.canHandle(url)).reduce(_ && _)
34+
override def canHandle(options: JDBCOptions): Boolean =
35+
dialects.map(_.canHandle(options)).reduce(_ && _)
3536

3637
override def getCatalystType(
3738
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {

sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
package org.apache.spark.sql.jdbc
1919

20+
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
2021
import org.apache.spark.sql.types.{BooleanType, DataType, StringType}
2122

2223
private object DB2Dialect extends JdbcDialect {
2324

24-
override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2")
25+
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:db2")
2526

2627
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
2728
case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB))

sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ package org.apache.spark.sql.jdbc
1919

2020
import java.sql.Types
2121

22+
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
2223
import org.apache.spark.sql.types._
2324

2425

2526
private object DerbyDialect extends JdbcDialect {
2627

27-
override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")
28+
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:derby")
2829

2930
override def getCatalystType(
3031
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
2020
import java.sql.Connection
2121

2222
import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
23+
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
2324
import org.apache.spark.sql.types._
2425

2526
/**
@@ -58,11 +59,11 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
5859
abstract class JdbcDialect extends Serializable {
5960
/**
6061
* Check if this dialect instance can handle a certain jdbc url.
61-
* @param url the jdbc url.
62+
* @param options the jdbc options.
6263
* @return True if the dialect can be applied on the given jdbc url.
6364
* @throws NullPointerException if the url is null.
6465
*/
65-
def canHandle(url : String): Boolean
66+
def canHandle(options: JDBCOptions): Boolean
6667

6768
/**
6869
* Get the custom datatype mapping for the given jdbc meta information.
@@ -179,8 +180,8 @@ object JdbcDialects {
179180
/**
180181
* Fetch the JdbcDialect class corresponding to a given database url.
181182
*/
182-
def get(url: String): JdbcDialect = {
183-
val matchingDialects = dialects.filter(_.canHandle(url))
183+
def get(options: JDBCOptions): JdbcDialect = {
184+
val matchingDialects = dialects.filter(_.canHandle(options))
184185
matchingDialects.length match {
185186
case 0 => NoopDialect
186187
case 1 => matchingDialects.head
@@ -193,5 +194,5 @@ object JdbcDialects {
193194
* NOOP dialect object, always returning the neutral element.
194195
*/
195196
private object NoopDialect extends JdbcDialect {
196-
override def canHandle(url : String): Boolean = true
197+
override def canHandle(options: JDBCOptions): Boolean = true
197198
}

sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
package org.apache.spark.sql.jdbc
1919

20+
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
2021
import org.apache.spark.sql.types._
2122

2223

2324
private object MsSqlServerDialect extends JdbcDialect {
2425

25-
override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver")
26+
override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:sqlserver")
2627

2728
override def getCatalystType(
2829
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {

0 commit comments

Comments (0)