@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
import java.util.Properties

import scala.util.control.NonFatal
@@ -28,12 +28,10 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.CompletionIterator

/**
* Data corresponding to one partition of a JDBCRDD.
@@ -44,68 +44,6 @@ case class JDBCPartition(whereClause: String, idx: Int) extends Partition {

object JDBCRDD extends Logging {

/**
* Maps a JDBC type to a Catalyst type. This function is called only when
* the JdbcDialect class corresponding to your database driver returns null.
*
* @param sqlType - A field of java.sql.Types
* @return The Catalyst type corresponding to sqlType.
*/
private def getCatalystType(
sqlType: Int,
precision: Int,
scale: Int,
signed: Boolean): DataType = {
val answer = sqlType match {
// scalastyle:off
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
case java.sql.Types.BINARY => BinaryType
case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
case java.sql.Types.BLOB => BinaryType
case java.sql.Types.BOOLEAN => BooleanType
case java.sql.Types.CHAR => StringType
case java.sql.Types.CLOB => StringType
case java.sql.Types.DATALINK => null
case java.sql.Types.DATE => DateType
case java.sql.Types.DECIMAL
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.DISTINCT => null
case java.sql.Types.DOUBLE => DoubleType
case java.sql.Types.FLOAT => FloatType
case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
case java.sql.Types.JAVA_OBJECT => null
case java.sql.Types.LONGNVARCHAR => StringType
case java.sql.Types.LONGVARBINARY => BinaryType
case java.sql.Types.LONGVARCHAR => StringType
case java.sql.Types.NCHAR => StringType
case java.sql.Types.NCLOB => StringType
case java.sql.Types.NULL => null
case java.sql.Types.NUMERIC
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.NVARCHAR => StringType
case java.sql.Types.OTHER => null
case java.sql.Types.REAL => DoubleType
case java.sql.Types.REF => StringType
case java.sql.Types.ROWID => LongType
case java.sql.Types.SMALLINT => IntegerType
case java.sql.Types.SQLXML => StringType
case java.sql.Types.STRUCT => StringType
case java.sql.Types.TIME => TimestampType
case java.sql.Types.TIMESTAMP => TimestampType
case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
case _ => null
// scalastyle:on
}

if (answer == null) throw new SQLException("Unsupported type " + sqlType)
answer
}
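
As the doc comment above notes, this table is only a fallback: the JdbcDialect for the connection gets the first chance to map a column, and this function runs only when the dialect declines. Below is a minimal sketch of such a dialect, assuming a hypothetical `jdbc:mydb` URL prefix; `canHandle`, `getCatalystType`, and `JdbcDialects.registerDialect` are the real extension points.

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType}

// Hypothetical dialect: surface this database's TIME columns as strings
// instead of the TimestampType the fallback table above would pick.
object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == java.sql.Types.TIME) Some(StringType) else None // None = defer to the fallback
  }
}

JdbcDialects.registerDialect(MyDialect)
```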

/**
* Takes a (schema, table) specification and returns the table's Catalyst
* schema.
@@ -126,37 +62,7 @@ object JDBCRDD extends Logging {
try {
val rs = statement.executeQuery()
try {
val rsmd = rs.getMetaData
val ncols = rsmd.getColumnCount
val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
val fieldScale = rsmd.getScale(i + 1)
val isSigned = {
try {
rsmd.isSigned(i + 1)
} catch {
// Workaround for HIVE-14684:
case e: SQLException if
e.getMessage == "Method not supported" &&
rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
}
}
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder()
.putString("name", columnName)
.putLong("scale", fieldScale)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
}
return new StructType(fields)
return JdbcUtils.getSchema(rs, dialect)
} finally {
rs.close()
}
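
With the metadata-walking loop gone, resolveTable just runs its probe query and delegates to JdbcUtils.getSchema. A rough usage sketch follows, with a hypothetical in-memory H2 URL and table; the (url, table, properties) signature is assumed from context, since the method header is elided in this diff.

```scala
import java.util.Properties

// Hypothetical connection details; resolveTable executes a zero-row probe
// query and builds the Catalyst schema from the result set's metadata.
val props = new Properties()
props.setProperty("user", "sa")
val schema = JDBCRDD.resolveTable("jdbc:h2:mem:testdb", "TEST.PEOPLE", props)
schema.fields.foreach(f => println(s"${f.name}: ${f.dataType} (nullable=${f.nullable})"))
```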
@@ -331,195 +237,15 @@ private[jdbc] class JDBCRDD(
}
}

// A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
// for `MutableRow`. The last argument `Int` means the index for the value to be set in
// the row and also used for the value in `ResultSet`.
private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit

/**
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[MutableRow]] correctly.
*/
def makeGetters(schema: StructType): Array[JDBCValueGetter] =
schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))

private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))

case DateType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos + 1)
if (dateVal != null) {
row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
} else {
row.update(pos, null)
}

// When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
// object returned by ResultSet.getBigDecimal is not correctly matched to the table
// schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
// If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
// a BigDecimal object with scale as 0. But the dataframe schema has correct type as
// DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
// retrieve it, you will get wrong result 199.99.
// So it is needed to set precision and scale for Decimal based on JDBC metadata.
case DecimalType.Fixed(p, s) =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val decimal =
nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
row.update(pos, decimal)

case DoubleType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setDouble(pos, rs.getDouble(pos + 1))

case FloatType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setFloat(pos, rs.getFloat(pos + 1))

case IntegerType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setInt(pos, rs.getInt(pos + 1))

case LongType if metadata.contains("binarylong") =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val bytes = rs.getBytes(pos + 1)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1
}
row.setLong(pos, ans)

case LongType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setLong(pos, rs.getLong(pos + 1))

case ShortType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setShort(pos, rs.getShort(pos + 1))

case StringType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

case TimestampType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val t = rs.getTimestamp(pos + 1)
if (t != null) {
row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
} else {
row.update(pos, null)
}

case BinaryType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.update(pos, rs.getBytes(pos + 1))

case ArrayType(et, _) =>
val elementConversion = et match {
case TimestampType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}

case StringType =>
(array: Object) =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)

case DateType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}

case dt: DecimalType =>
(array: Object) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](
decimal, d => Decimal(d, dt.precision, dt.scale))
}

case LongType if metadata.contains("binarylong") =>
throw new IllegalArgumentException(s"Unsupported array element " +
s"type ${dt.simpleString} based on binary")

case ArrayType(_, _) =>
throw new IllegalArgumentException("Nested arrays unsupported")

case _ => (array: Object) => array.asInstanceOf[Array[Any]]
}

(rs: ResultSet, row: MutableRow, pos: Int) =>
val array = nullSafeConvert[Object](
rs.getArray(pos + 1).getArray,
array => new GenericArrayData(elementConversion.apply(array)))
row.update(pos, array)

case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}

/**
* Runs the SQL query against the JDBC driver.
*
*/
override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
new Iterator[InternalRow] {
override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
Contributor Author:

While most of the changes in this block stem from moving the inner loop into a JdbcUtils method, there are one or two non-trivial changes that may affect cleanup in error situations. I'll comment on those changes below to walk through them.
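
For orientation, here is a sketch of the shape of the extracted logic — not the actual JdbcUtils code: a lazy Iterator[InternalRow] that advances the ResultSet on demand and applies one getter per column, just like the getNext() loop deleted below.

```scala
import java.sql.ResultSet

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.MutableRow

// Sketch of a ResultSet-backed iterator: one getter per column writes into a
// reused mutable row, and nulls are patched up via rs.wasNull.
def resultSetIterator(
    rs: ResultSet,
    getters: Array[(ResultSet, MutableRow, Int) => Unit],
    row: MutableRow): Iterator[InternalRow] = new Iterator[InternalRow] {
  private var advanced = false
  private var hasRow = false

  override def hasNext: Boolean = {
    if (!advanced) { hasRow = rs.next(); advanced = true }
    hasRow
  }

  override def next(): InternalRow = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    var i = 0
    while (i < getters.length) {
      getters(i)(rs, row, i)
      if (rs.wasNull) row.setNullAt(i)
      i += 1
    }
    advanced = false
    row
  }
}
```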

var closed = false
var finished = false
var gotNext = false
var nextValue: InternalRow = null

context.addTaskCompletionListener{ context => close() }
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
val conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, properties.asScala.toMap)

// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
// talk about a table in a completely portable way.

val myWhereClause = getWhereClause(part)

val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
val stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
require(fetchSize >= 0,
s"Invalid value `${fetchSize.toString}` for parameter " +
s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
"the JDBC driver ignores the value and does the estimates.")
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()

val getters: Array[JDBCValueGetter] = makeGetters(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))

def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
while (i < getters.length) {
getters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
mutableRow
} else {
finished = true
null.asInstanceOf[InternalRow]
}
}
var rs: ResultSet = null
var stmt: PreparedStatement = null
var conn: Connection = null

def close() {
Contributor Author:

Here, close() tears down the result set, statement, and connection. In principle, closing the connection should be enough to close the statement (which in turn should close the result set), but not all JDBC drivers honor this contract, so we can't simplify this code.
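
The elided body below applies that observation: each resource is closed in its own try/catch, innermost first, so one misbehaving driver can't leak the rest. A condensed sketch of the pattern, reusing the rs/stmt/conn vars declared above and the enclosing class's Logging.logWarning (the real body does slightly more work before closing the connection):

```scala
// Sketch: close innermost-first, logging but swallowing per-resource
// failures so a bad driver can't prevent the connection from closing.
def closeQuietly(): Unit = {
  if (rs != null) {
    try rs.close() catch { case e: Exception => logWarning("Exception closing resultset", e) }
  }
  if (stmt != null) {
    try { if (!stmt.isClosed) stmt.close() }
    catch { case e: Exception => logWarning("Exception closing statement", e) }
  }
  if (conn != null) {
    try conn.close() catch { case e: Exception => logWarning("Exception closing connection", e) }
  }
}
```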

if (closed) return
@@ -555,33 +281,33 @@
closed = true
}

override def hasNext: Boolean = {
if (!finished) {
if (!gotNext) {
nextValue = getNext()
if (finished) {
close()
}
gotNext = true
}
}
!finished
}
context.addTaskCompletionListener{ context => close() }
Contributor Author:

Note that close() must be idempotent because it may be called twice: once at the end of the task and once at the end of iteration. In the old code the end-of-iteration case was handled in hasNext; here, it's handled by a CompletionIterator, defined below.
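
CompletionIterator is the Spark-internal utility (it is private[spark], so the following only compiles inside Spark's own namespace) that covers the end-of-iteration case: it delegates to a wrapped iterator and fires a callback once that iterator is drained. A toy sketch of the semantics:

```scala
import org.apache.spark.util.CompletionIterator

var done = false
def finish(): Unit = {
  if (done) return // idempotent, like close() here
  println("completed")
  done = true
}

val it = CompletionIterator[Int, Iterator[Int]](Iterator(1, 2, 3), finish())
it.foreach(println) // 1, 2, 3, then "completed" once the iterator is drained
finish()            // second call is a no-op
```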

Contributor:

Nit: space before the bracket.

Contributor Author:

Yeah, this is an unfortunate style carryover from the old code :(


override def next(): InternalRow = {
if (!hasNext) {
throw new NoSuchElementException("End of stream")
}
gotNext = false
nextValue
}
}
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, properties.asScala.toMap)

private def nullSafeConvert[T](input: T, f: T => Any): Any = {
if (input == null) {
null
} else {
f(input)
}
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
// talk about a table in a completely portable way.

val myWhereClause = getWhereClause(part)

val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
require(fetchSize >= 0,
s"Invalid value `${fetchSize.toString}` for parameter " +
s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
"the JDBC driver ignores the value and does the estimates.")
stmt.setFetchSize(fetchSize)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)

CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
}
}
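
For reference, the fetch size validated above enters through the JDBC data source options. A hedged usage sketch, with made-up connection details and assuming the option key is "fetchsize" (the value of JdbcUtils.JDBC_BATCH_FETCH_SIZE):

```scala
// spark is an existing SparkSession; URL and table names are hypothetical.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com/mydb")
  .option("dbtable", "public.events")
  .option("fetchsize", "1000") // 0 (the default) lets the driver estimate
  .load()
```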