
Commit a133057

JoshRosen authored and hvanhovell committed
[SPARK-17229][SQL] PostgresDialect shouldn't widen float and short types during reads
## What changes were proposed in this pull request?

When reading float4 and smallint columns from PostgreSQL, Spark's `PostgresDialect` widens these types to Decimal and Integer rather than using the narrower Float and Short types. According to https://www.postgresql.org/docs/7.1/static/datatype.html#DATATYPE-TABLE, Postgres maps the `smallint` type to a signed two-byte integer and the `real` / `float4` types to single-precision floating point numbers. This patch fixes the read path by adding more special cases to `getCatalystType`, similar to what was done for the Derby JDBC dialect. I also fixed a similar problem in the write path, which caused Spark to create integer columns in Postgres for what should have been ShortType columns.

## How was this patch tested?

New test cases in `PostgresIntegrationSuite` (which I ran manually, because Jenkins can't run it right now).

Author: Josh Rosen <[email protected]>

Closes #14796 from JoshRosen/postgres-jdbc-type-fixes.
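As an illustration of the user-visible effect on the read path, here is a minimal sketch assuming a local PostgreSQL instance and a hypothetical table `widths_test` created as `CREATE TABLE widths_test (f float4, s smallint)`; the connection URL, database, and table name are examples for illustration, not part of this patch:

```scala
import java.util.Properties

import org.apache.spark.sql.SparkSession

object ReadNarrowTypes {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("postgres-narrow-types").getOrCreate()
    // Hypothetical connection details; any Postgres table with float4/smallint columns will do.
    val df = spark.read.jdbc("jdbc:postgresql://localhost:5432/test", "widths_test", new Properties)
    // Per the commit message, these columns previously came back widened to Decimal and
    // Integer; with the dialect fix the schema reports FloatType and ShortType.
    df.printSchema()
    spark.stop()
  }
}
```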
1 parent 9958ac0 commit a133057

File tree

3 files changed: +28 -5 lines changed


external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala

Lines changed: 18 additions & 4 deletions
@@ -22,7 +22,7 @@ import java.util.Properties
 
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.types.{ArrayType, DecimalType}
+import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType}
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -45,18 +45,20 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate()
     conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
       + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
-      + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type)").executeUpdate()
+      + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type, "
+      + "c15 float4, c16 smallint)").executeUpdate()
     conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
       + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
-      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1')""").executeUpdate()
+      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1', 1.01, 1)"""
+    ).executeUpdate()
   }
 
   test("Type mapping for various types") {
     val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
     val types = rows(0).toSeq.map(x => x.getClass)
-    assert(types.length == 15)
+    assert(types.length == 17)
     assert(classOf[String].isAssignableFrom(types(0)))
     assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
     assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
@@ -72,6 +74,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(classOf[Seq[Double]].isAssignableFrom(types(12)))
     assert(classOf[Seq[BigDecimal]].isAssignableFrom(types(13)))
     assert(classOf[String].isAssignableFrom(types(14)))
+    assert(classOf[java.lang.Float].isAssignableFrom(types(15)))
+    assert(classOf[java.lang.Short].isAssignableFrom(types(16)))
     assert(rows(0).getString(0).equals("hello"))
     assert(rows(0).getInt(1) == 42)
     assert(rows(0).getDouble(2) == 1.25)
@@ -90,6 +94,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f))
     assert(rows(0).getSeq(13) == Seq("0.11", "0.22").map(BigDecimal(_).bigDecimal))
     assert(rows(0).getString(14) == "d1")
+    assert(rows(0).getFloat(15) == 1.01f)
+    assert(rows(0).getShort(16) == 1)
   }
 
   test("Basic write test") {
@@ -104,4 +110,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
       Column(Literal.create(null, a.dataType)).as(a.name)
     }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
   }
+
+  test("Creating a table with shorts and floats") {
+    sqlContext.createDataFrame(Seq((1.0f, 1.toShort)))
+      .write.jdbc(jdbcUrl, "shortfloat", new Properties)
+    val schema = sqlContext.read.jdbc(jdbcUrl, "shortfloat", new Properties).schema
+    assert(schema(0).dataType == FloatType)
+    assert(schema(1).dataType == ShortType)
+  }
 }
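The new write test relies on Scala tuple schema inference: `Seq((1.0f, 1.toShort))` becomes a two-column DataFrame with FloatType and ShortType, and the round trip only preserves those types once both the write mapping (FLOAT4/SMALLINT columns) and the read mapping are in place. Below is a minimal sketch of just the inference step, runnable without a database; the use of SparkSession and a local master is an assumption for illustration, while the actual test goes through `sqlContext` inside the Docker harness:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{FloatType, ShortType}

object TupleSchemaCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("tuple-schema").getOrCreate()
    // The tuple encoder keeps the narrow types; the columns are named _1 and _2.
    val df = spark.createDataFrame(Seq((1.0f, 1.toShort)))
    assert(df.schema(0).dataType == FloatType)
    assert(df.schema(1).dataType == ShortType)
    spark.stop()
  }
}
```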

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 4 additions & 0 deletions
@@ -390,6 +390,10 @@ private[jdbc] class JDBCRDD(
       (rs: ResultSet, row: MutableRow, pos: Int) =>
         row.setLong(pos, rs.getLong(pos + 1))
 
+    case ShortType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setShort(pos, rs.getShort(pos + 1))
+
     case StringType =>
       (rs: ResultSet, row: MutableRow, pos: Int) =>
         // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
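For each Catalyst type, `JDBCRDD` installs a small conversion closure that copies a column out of the JDBC `ResultSet` (1-based positions) into the row being built (0-based positions); the added `ShortType` case follows that pattern with `getShort`/`setShort`. A rough sketch of the pattern with a stand-in row buffer, since `MutableRow` is Spark-internal; the alias and names below are illustrative, not Spark's own:

```scala
import java.sql.ResultSet

object GetterPatternSketch {
  // Stand-in for Spark's per-type conversion functions; the alias name and the
  // Array[Any] row buffer are illustrative, because MutableRow is a private Spark class.
  type ValueGetter = (ResultSet, Array[Any], Int) => Unit

  // ResultSet columns are 1-based while row positions are 0-based, hence pos + 1,
  // mirroring the getShort/setShort pairing added in this patch.
  val shortGetter: ValueGetter =
    (rs: ResultSet, row: Array[Any], pos: Int) => row(pos) = rs.getShort(pos + 1)

  val longGetter: ValueGetter =
    (rs: ResultSet, row: Array[Any], pos: Int) => row(pos) = rs.getLong(pos + 1)
}
```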

sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala

Lines changed: 6 additions & 1 deletion
@@ -29,7 +29,11 @@ private object PostgresDialect extends JdbcDialect {
 
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
-    if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
+    if (sqlType == Types.REAL) {
+      Some(FloatType)
+    } else if (sqlType == Types.SMALLINT) {
+      Some(ShortType)
+    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
       Some(BinaryType)
     } else if (sqlType == Types.OTHER) {
       Some(StringType)
@@ -66,6 +70,7 @@ private object PostgresDialect extends JdbcDialect {
     case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
     case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
     case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
+    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
     case t: DecimalType => Some(
       JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
     case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
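`getCatalystType` (read path) and `getJDBCType` (write path) are the same two hooks exposed by the public `JdbcDialect` developer API, so the pattern in this patch is also the one a user-defined dialect would follow. A hedged sketch of a hypothetical dialect; the `jdbc:mydb` URL prefix and the chosen mappings are examples, not an existing Spark dialect:

```scala
import java.sql.Types

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._

// A hypothetical user-defined dialect illustrating the same two hooks this
// patch touches in PostgresDialect; URL prefix and mappings are examples only.
object NarrowTypesDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Read path: keep REAL and SMALLINT narrow instead of letting the default
  // JDBC mapping widen them.
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) Some(FloatType)
    else if (sqlType == Types.SMALLINT) Some(ShortType)
    else None
  }

  // Write path: emit narrow column types when Spark creates the table.
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case FloatType => Some(JdbcType("REAL", Types.REAL))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case _ => None
  }
}

// Registering the dialect makes it apply to any URL that canHandle accepts:
// JdbcDialects.registerDialect(NarrowTypesDialect)
```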
