[SPARK-17351] Refactor JDBCRDD to expose ResultSet -> Seq[Row] utility methods #14907
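As context for the refactor, here is a minimal sketch of how one of the utilities this change exposes could be invoked outside of JDBCRDD. The signature is inferred from the call site in the diff below (JdbcUtils.getSchema(rs, dialect)); JdbcUtils lives in an internal Spark package, so treat this as illustrative rather than a supported public API, and the URL and table name are placeholders.

    import java.sql.DriverManager

    import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
    import org.apache.spark.sql.jdbc.JdbcDialects
    import org.apache.spark.sql.types.StructType

    object SchemaProbe {
      // Fetch a table's Catalyst schema by running a zero-row query and handing the
      // ResultSet to the utility that resolveTable() now delegates to in this PR.
      def probeSchema(url: String, table: String): StructType = {
        val conn = DriverManager.getConnection(url)
        try {
          val stmt = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
          try {
            val rs = stmt.executeQuery()
            try {
              JdbcUtils.getSchema(rs, JdbcDialects.get(url)) // signature inferred from this diff
            } finally {
              rs.close()
            }
          } finally {
            stmt.close()
          }
        } finally {
          conn.close()
        }
      }
    }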
Changes from all commits: 17d770a, 682b591, ec49acc, 025c9d0, 05dfe52, fca548a, 43cbef6, 1d725ad, f09174b
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
import java.util.Properties

import scala.util.control.NonFatal
@@ -28,12 +28,10 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.CompletionIterator

/**
 * Data corresponding to one partition of a JDBCRDD.
@@ -44,68 +42,6 @@ case class JDBCPartition(whereClause: String, idx: Int) extends Partition {

object JDBCRDD extends Logging {

  /**
   * Maps a JDBC type to a Catalyst type. This function is called only when
   * the JdbcDialect class corresponding to your database driver returns null.
   *
   * @param sqlType - A field of java.sql.Types
   * @return The Catalyst type corresponding to sqlType.
   */
  private def getCatalystType(
      sqlType: Int,
      precision: Int,
      scale: Int,
      signed: Boolean): DataType = {
    val answer = sqlType match {
      // scalastyle:off
      case java.sql.Types.ARRAY => null
      case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
      case java.sql.Types.BINARY => BinaryType
      case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
      case java.sql.Types.BLOB => BinaryType
      case java.sql.Types.BOOLEAN => BooleanType
      case java.sql.Types.CHAR => StringType
      case java.sql.Types.CLOB => StringType
      case java.sql.Types.DATALINK => null
      case java.sql.Types.DATE => DateType
      case java.sql.Types.DECIMAL
        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
      case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
      case java.sql.Types.DISTINCT => null
      case java.sql.Types.DOUBLE => DoubleType
      case java.sql.Types.FLOAT => FloatType
      case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
      case java.sql.Types.JAVA_OBJECT => null
      case java.sql.Types.LONGNVARCHAR => StringType
      case java.sql.Types.LONGVARBINARY => BinaryType
      case java.sql.Types.LONGVARCHAR => StringType
      case java.sql.Types.NCHAR => StringType
      case java.sql.Types.NCLOB => StringType
      case java.sql.Types.NULL => null
      case java.sql.Types.NUMERIC
        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
      case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
      case java.sql.Types.NVARCHAR => StringType
      case java.sql.Types.OTHER => null
      case java.sql.Types.REAL => DoubleType
      case java.sql.Types.REF => StringType
      case java.sql.Types.ROWID => LongType
      case java.sql.Types.SMALLINT => IntegerType
      case java.sql.Types.SQLXML => StringType
      case java.sql.Types.STRUCT => StringType
      case java.sql.Types.TIME => TimestampType
      case java.sql.Types.TIMESTAMP => TimestampType
      case java.sql.Types.TINYINT => IntegerType
      case java.sql.Types.VARBINARY => BinaryType
      case java.sql.Types.VARCHAR => StringType
      case _ => null
      // scalastyle:on
    }

    if (answer == null) throw new SQLException("Unsupported type " + sqlType)
    answer
  }

  /**
   * Takes a (schema, table) specification and returns the table's Catalyst
   * schema.
@@ -126,37 +62,7 @@ object JDBCRDD extends Logging {
    try {
      val rs = statement.executeQuery()
      try {
        val rsmd = rs.getMetaData
        val ncols = rsmd.getColumnCount
        val fields = new Array[StructField](ncols)
        var i = 0
        while (i < ncols) {
          val columnName = rsmd.getColumnLabel(i + 1)
          val dataType = rsmd.getColumnType(i + 1)
          val typeName = rsmd.getColumnTypeName(i + 1)
          val fieldSize = rsmd.getPrecision(i + 1)
          val fieldScale = rsmd.getScale(i + 1)
          val isSigned = {
            try {
              rsmd.isSigned(i + 1)
            } catch {
              // Workaround for HIVE-14684:
              case e: SQLException if
                e.getMessage == "Method not supported" &&
                  rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
            }
          }
          val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
          val metadata = new MetadataBuilder()
            .putString("name", columnName)
            .putLong("scale", fieldScale)
          val columnType =
            dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
              getCatalystType(dataType, fieldSize, fieldScale, isSigned))
          fields(i) = StructField(columnName, columnType, nullable, metadata.build())
          i = i + 1
        }
        return new StructType(fields)
        return JdbcUtils.getSchema(rs, dialect)
      } finally {
        rs.close()
      }
@@ -331,195 +237,15 @@ private[jdbc] class JDBCRDD(
    }
  }

  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
  // the row and also used for the value in `ResultSet`.
  private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit

  /**
   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
   * each value from `ResultSet` to each field of [[MutableRow]] correctly.
   */
  def makeGetters(schema: StructType): Array[JDBCValueGetter] =
    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))

  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
    case BooleanType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        row.setBoolean(pos, rs.getBoolean(pos + 1))

    case DateType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
        val dateVal = rs.getDate(pos + 1)
        if (dateVal != null) {
          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
        } else {
          row.update(pos, null)
        }

    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
    // retrieve it, you will get wrong result 199.99.
    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
    case DecimalType.Fixed(p, s) =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        val decimal =
          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
        row.update(pos, decimal)

    case DoubleType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        row.setDouble(pos, rs.getDouble(pos + 1))

    case FloatType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        row.setFloat(pos, rs.getFloat(pos + 1))

    case IntegerType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        row.setInt(pos, rs.getInt(pos + 1))

    case LongType if metadata.contains("binarylong") =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        val bytes = rs.getBytes(pos + 1)
        var ans = 0L
        var j = 0
        while (j < bytes.size) {
          ans = 256 * ans + (255 & bytes(j))
          j = j + 1
        }
        row.setLong(pos, ans)

    case LongType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        row.setLong(pos, rs.getLong(pos + 1))

    case ShortType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        row.setShort(pos, rs.getShort(pos + 1))

    case StringType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

    case TimestampType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        val t = rs.getTimestamp(pos + 1)
        if (t != null) {
          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
        } else {
          row.update(pos, null)
        }

    case BinaryType =>
      (rs: ResultSet, row: MutableRow, pos: Int) =>
        row.update(pos, rs.getBytes(pos + 1))

    case ArrayType(et, _) =>
      val elementConversion = et match {
        case TimestampType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
            }

        case StringType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.lang.String]]
              .map(UTF8String.fromString)

        case DateType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.sql.Date]].map { date =>
              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
            }

        case dt: DecimalType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
              nullSafeConvert[java.math.BigDecimal](
                decimal, d => Decimal(d, dt.precision, dt.scale))
            }

        case LongType if metadata.contains("binarylong") =>
          throw new IllegalArgumentException(s"Unsupported array element " +
            s"type ${dt.simpleString} based on binary")

        case ArrayType(_, _) =>
          throw new IllegalArgumentException("Nested arrays unsupported")

        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
      }

      (rs: ResultSet, row: MutableRow, pos: Int) =>
        val array = nullSafeConvert[Object](
          rs.getArray(pos + 1).getArray,
          array => new GenericArrayData(elementConversion.apply(array)))
        row.update(pos, array)

    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
  }

  /**
   * Runs the SQL query against the JDBC driver.
   *
   */
  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[InternalRow] {
  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
    var closed = false
    var finished = false
    var gotNext = false
    var nextValue: InternalRow = null

    context.addTaskCompletionListener{ context => close() }
    val inputMetrics = context.taskMetrics().inputMetrics
    val part = thePart.asInstanceOf[JDBCPartition]
    val conn = getConnection()
    val dialect = JdbcDialects.get(url)
    import scala.collection.JavaConverters._
    dialect.beforeFetch(conn, properties.asScala.toMap)

    // H2's JDBC driver does not support the setSchema() method. We pass a
    // fully-qualified table name in the SELECT statement. I don't know how to
    // talk about a table in a completely portable way.

    val myWhereClause = getWhereClause(part)

    val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
    val stmt = conn.prepareStatement(sqlText,
      ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
    require(fetchSize >= 0,
      s"Invalid value `${fetchSize.toString}` for parameter " +
      s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
      "the JDBC driver ignores the value and does the estimates.")
    stmt.setFetchSize(fetchSize)
    val rs = stmt.executeQuery()

    val getters: Array[JDBCValueGetter] = makeGetters(schema)
    val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))

    def getNext(): InternalRow = {
      if (rs.next()) {
        inputMetrics.incRecordsRead(1)
        var i = 0
        while (i < getters.length) {
          getters(i).apply(rs, mutableRow, i)
          if (rs.wasNull) mutableRow.setNullAt(i)
          i = i + 1
        }
        mutableRow
      } else {
        finished = true
        null.asInstanceOf[InternalRow]
      }
    }
    var rs: ResultSet = null
    var stmt: PreparedStatement = null
    var conn: Connection = null

    def close() {
Contributor (Author)
Here,
      if (closed) return
@@ -555,33 +281,33 @@ private[jdbc] class JDBCRDD(
      closed = true
    }

    override def hasNext: Boolean = {
      if (!finished) {
        if (!gotNext) {
          nextValue = getNext()
          if (finished) {
            close()
          }
          gotNext = true
        }
      }
      !finished
    }
    context.addTaskCompletionListener{ context => close() }
Contributor (Author)
Note that

Contributor
Nit: space before the bracket.

Contributor (Author)
Yeah, this is an unfortunate style carryover from the old code :(
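For reference, the spacing the nit refers to would look like the fragment below: the same listener registration, just with a space before the brace (shown in isolation, not as a change in this PR).

    context.addTaskCompletionListener { context => close() }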
    override def next(): InternalRow = {
      if (!hasNext) {
        throw new NoSuchElementException("End of stream")
      }
      gotNext = false
      nextValue
    }
  }
    val inputMetrics = context.taskMetrics().inputMetrics
    val part = thePart.asInstanceOf[JDBCPartition]
    conn = getConnection()
    val dialect = JdbcDialects.get(url)
    import scala.collection.JavaConverters._
    dialect.beforeFetch(conn, properties.asScala.toMap)

  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
    if (input == null) {
      null
    } else {
      f(input)
    }
    // H2's JDBC driver does not support the setSchema() method. We pass a
    // fully-qualified table name in the SELECT statement. I don't know how to
    // talk about a table in a completely portable way.

    val myWhereClause = getWhereClause(part)

    val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
    stmt = conn.prepareStatement(sqlText,
      ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
    require(fetchSize >= 0,
      s"Invalid value `${fetchSize.toString}` for parameter " +
      s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
      "the JDBC driver ignores the value and does the estimates.")
    stmt.setFetchSize(fetchSize)
    rs = stmt.executeQuery()
    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)

    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
  }
}
While most of the changes in this block stem from moving the inner loop into a JdbcUtils method, there are one or two non-trivial changes that may impact cleanup in error situations. I'll comment on these changes below in order to help walk through them.
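To illustrate the cleanup semantics discussed above, here is a minimal sketch (not part of this PR) of how CompletionIterator runs its completion callback once the wrapped iterator is exhausted. CompletionIterator is a private[spark] utility, so this would only compile inside the Spark source tree; the object and variable names are made up for the example.

    import org.apache.spark.util.CompletionIterator

    object CompletionIteratorSketch {
      def main(args: Array[String]): Unit = {
        var cleanedUp = false
        val rows = Iterator(1, 2, 3)
        // close() in the new compute() is wired up the same way: the callback runs
        // once the wrapped iterator reports it has no more elements.
        val iter = CompletionIterator[Int, Iterator[Int]](rows, { cleanedUp = true })
        println(iter.sum)    // 6
        println(cleanedUp)   // true - the completion function fired after the last element
      }
    }

On an error path the wrapped iterator may never be drained, so the completion function may not run, which is presumably why compute() also keeps the addTaskCompletionListener registration that calls close() when the task finishes.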