From 5ee7d5e7079e79222239484d05da903d1d2eade0 Mon Sep 17 00:00:00 2001
From: ShivSood
Date: Tue, 7 Jan 2020 12:31:56 -0800
Subject: [PATCH] Fix ByteType so that writing tables over JDBC works. As
 part of the fix, ByteType is now mapped to ShortType.

---
 .../jdbc/MsSqlServerIntegrationSuite.scala    | 25 +++++++++++++++++--
 .../datasources/jdbc/JdbcUtils.scala          |  8 +++---
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala |  8 +++---
 .../spark/sql/jdbc/JDBCWriteSuite.scala       | 21 ++++++++++++++++
 4 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index efd7ca74c796b..7c98bd04167b9 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -119,7 +119,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     val types = row.toSeq.map(x => x.getClass.toString)
     assert(types.length == 12)
     assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(1).equals("class java.lang.Short"))
     assert(types(2).equals("class java.lang.Short"))
     assert(types(3).equals("class java.lang.Integer"))
     assert(types(4).equals("class java.lang.Long"))
@@ -131,7 +131,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types(10).equals("class java.math.BigDecimal"))
     assert(types(11).equals("class java.math.BigDecimal"))
     assert(row.getBoolean(0) == false)
-    assert(row.getInt(1) == 255)
+    assert(row.getShort(1) == 255)
     assert(row.getShort(2) == 32767)
     assert(row.getInt(3) == 2147483647)
     assert(row.getLong(4) == 9223372036854775807L)
@@ -202,4 +202,25 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
     df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
   }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 127.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index c1e1aed83bae5..52fd9554ae864 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -171,7 +171,7 @@ object JdbcUtils extends Logging {
     case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
     case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
    case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
-    case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+    case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT))
     case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
     case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
     case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
@@ -244,7 +244,7 @@
       case java.sql.Types.TIMESTAMP => TimestampType
       case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
                                     => null
-      case java.sql.Types.TINYINT => IntegerType
+      case java.sql.Types.TINYINT => ShortType
       case java.sql.Types.VARBINARY => BinaryType
       case java.sql.Types.VARCHAR => StringType
       case _ =>
@@ -445,7 +445,7 @@ object JdbcUtils extends Logging {
 
     case ByteType =>
       (rs: ResultSet, row: InternalRow, pos: Int) =>
-        row.setByte(pos, rs.getByte(pos + 1))
+        row.setShort(pos, rs.getByte(pos + 1))
 
     case StringType =>
       (rs: ResultSet, row: InternalRow, pos: Int) =>
@@ -550,7 +550,7 @@ object JdbcUtils extends Logging {
 
     case ByteType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        stmt.setInt(pos + 1, row.getByte(pos))
+        stmt.setByte(pos + 1, row.getByte(pos))
 
     case BooleanType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 66ddc6ee83d06..87fbe32ae604b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -65,7 +65,7 @@ class JDBCSuite extends QueryTest
         size: Int,
         md: MetadataBuilder): Option[DataType] = {
       sqlType match {
-        case java.sql.Types.TINYINT => Some(ByteType)
+        case java.sql.Types.TINYINT => Some(ShortType)
         case _ => None
       }
     }
@@ -578,7 +578,7 @@ class JDBCSuite extends QueryTest
     assert(rows.length === 1)
     assert(rows(0).getInt(0) === 1)
     assert(rows(0).getBoolean(1) === false)
-    assert(rows(0).getInt(2) === 3)
+    assert(rows(0).getShort(2) === 3)
     assert(rows(0).getInt(3) === 4)
     assert(rows(0).getLong(4) === 1234567890123L)
   }
@@ -704,8 +704,8 @@ class JDBCSuite extends QueryTest
     val df = spark.read.jdbc(urlWithUserAndPass, "test.inttypes", new Properties())
     val rows = df.collect()
     assert(rows.length === 2)
-    assert(rows(0).get(2).isInstanceOf[Byte])
-    assert(rows(0).getByte(2) === 3)
+    assert(rows(0).get(2).isInstanceOf[Short])
+    assert(rows(0).getShort(2) === 3)
     assert(rows(1).isNullAt(2))
     JdbcDialects.unregisterDialect(testH2DialectTinyInt)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 8021ef1a17a18..61a85b0031c99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -598,4 +598,25 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
     sparkContext.removeSparkListener(listener)
     taskMetrics.sum
   }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte, 255.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
 }
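
Note: for reference, a minimal spark-shell sketch of the round trip this
patch enables. The JDBC URL and table name are placeholders assumed for
illustration, not taken from the patch, and spark-shell is assumed to
provide the usual "spark" session with its implicits pre-imported:

  import java.util.Properties

  // A single-column DataFrame with a ByteType column.
  val df = Seq(1.toByte, 38.toByte, 127.toByte).toDF("a")

  // Before this patch, ByteType columns were created with the
  // non-standard type name "BYTE", which databases such as SQL Server
  // reject; with the patch the column is emitted as TINYINT.
  df.write.mode("overwrite")
    .jdbc("jdbc:sqlserver://host:1433;database=test", "bytetable", new Properties)  // placeholder URL

  // On read, java.sql.Types.TINYINT now maps to ShortType instead of
  // IntegerType (SQL Server's TINYINT is unsigned, 0 to 255, so it fits
  // in a Short but not in a signed Byte).
  val df2 = spark.read
    .jdbc("jdbc:sqlserver://host:1433;database=test", "bytetable", new Properties)
  df2.printSchema()  // root |-- a: short (nullable = true)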