MsSqlServerIntegrationSuite.scala
@@ -21,6 +21,7 @@ import java.math.BigDecimal
 import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
 
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -112,36 +113,58 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
   }
 
   test("Numeric types") {
-    val df = spark.read.jdbc(jdbcUrl, "numbers", new Properties)
-    val rows = df.collect()
-    assert(rows.length == 1)
-    val row = rows(0)
-    val types = row.toSeq.map(x => x.getClass.toString)
-    assert(types.length == 12)
-    assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
-    assert(types(2).equals("class java.lang.Short"))
-    assert(types(3).equals("class java.lang.Integer"))
-    assert(types(4).equals("class java.lang.Long"))
-    assert(types(5).equals("class java.lang.Double"))
-    assert(types(6).equals("class java.lang.Float"))
-    assert(types(7).equals("class java.lang.Float"))
-    assert(types(8).equals("class java.math.BigDecimal"))
-    assert(types(9).equals("class java.math.BigDecimal"))
-    assert(types(10).equals("class java.math.BigDecimal"))
-    assert(types(11).equals("class java.math.BigDecimal"))
-    assert(row.getBoolean(0) == false)
-    assert(row.getInt(1) == 255)
-    assert(row.getShort(2) == 32767)
-    assert(row.getInt(3) == 2147483647)
-    assert(row.getLong(4) == 9223372036854775807L)
-    assert(row.getDouble(5) == 1.2345678901234512E14) // float = float(53) has 15-digits precision
-    assert(row.getFloat(6) == 1.23456788103168E14) // float(24) has 7-digits precision
-    assert(row.getFloat(7) == 1.23456788103168E14) // real = float(24)
-    assert(row.getAs[BigDecimal](8).equals(new BigDecimal("123.00")))
-    assert(row.getAs[BigDecimal](9).equals(new BigDecimal("12345.12000")))
-    assert(row.getAs[BigDecimal](10).equals(new BigDecimal("922337203685477.5800")))
-    assert(row.getAs[BigDecimal](11).equals(new BigDecimal("214748.3647")))
+    Seq(true, false).foreach { flag =>
+      withSQLConf(SQLConf.LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED.key -> s"$flag") {
+        val df = spark.read.jdbc(jdbcUrl, "numbers", new Properties)
+        val rows = df.collect()
+        assert(rows.length == 1)
+        val row = rows(0)
+        val types = row.toSeq.map(x => x.getClass.toString)
+        assert(types.length == 12)
+        assert(types(0).equals("class java.lang.Boolean"))
+        assert(types(1).equals("class java.lang.Integer"))
+        if (flag) {
+          assert(types(2).equals("class java.lang.Integer"))
+        } else {
+          assert(types(2).equals("class java.lang.Short"))
+        }
+        assert(types(3).equals("class java.lang.Integer"))
+        assert(types(4).equals("class java.lang.Long"))
+        assert(types(5).equals("class java.lang.Double"))
+        if (flag) {
+          assert(types(6).equals("class java.lang.Double"))
+          assert(types(7).equals("class java.lang.Double"))
+        } else {
+          assert(types(6).equals("class java.lang.Float"))
+          assert(types(7).equals("class java.lang.Float"))
+        }
+        assert(types(8).equals("class java.math.BigDecimal"))
+        assert(types(9).equals("class java.math.BigDecimal"))
+        assert(types(10).equals("class java.math.BigDecimal"))
+        assert(types(11).equals("class java.math.BigDecimal"))
+        assert(row.getBoolean(0) == false)
+        assert(row.getInt(1) == 255)
+        if (flag) {
+          assert(row.getInt(2) == 32767)
+        } else {
+          assert(row.getShort(2) == 32767)
+        }
+        assert(row.getInt(3) == 2147483647)
+        assert(row.getLong(4) == 9223372036854775807L)
+        assert(row.getDouble(5) == 1.2345678901234512E14) // float(53) has 15-digits precision
+        if (flag) {
+          assert(row.getDouble(6) == 1.23456788103168E14) // float(24) has 7-digits precision
+          assert(row.getDouble(7) == 1.23456788103168E14) // real = float(24)
+        } else {
+          assert(row.getFloat(6) == 1.23456788103168E14) // float(24) has 7-digits precision
+          assert(row.getFloat(7) == 1.23456788103168E14) // real = float(24)
+        }
+        assert(row.getAs[BigDecimal](8).equals(new BigDecimal("123.00")))
+        assert(row.getAs[BigDecimal](9).equals(new BigDecimal("12345.12000")))
+        assert(row.getAs[BigDecimal](10).equals(new BigDecimal("922337203685477.5800")))
+        assert(row.getAs[BigDecimal](11).equals(new BigDecimal("214748.3647")))
+      }
+    }
   }
 
   test("Date types") {
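Note: the test above pins down the user-visible effect of the new flag. A minimal sketch of the difference when reading a SQL Server table; the URL, table name, and schema comments below are illustrative, not taken from this PR:

```scala
import java.util.Properties

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val url = "jdbc:sqlserver://host:1433;database=testdb" // placeholder URL

// Default (false): the MsSqlServer dialect maps SMALLINT -> ShortType
// and REAL -> FloatType.
spark.conf.set("spark.sql.legacy.mssqlserver.numericMapping.enabled", "false")
spark.read.jdbc(url, "numbers", new Properties).printSchema()
// smallint column: short, real column: float

// Legacy (true): the dialect returns None, so the generic JDBC mapping
// applies instead: SMALLINT -> IntegerType and REAL -> DoubleType.
spark.conf.set("spark.sql.legacy.mssqlserver.numericMapping.enabled", "true")
spark.read.jdbc(url, "numbers", new Properties).printSchema()
// smallint column: int, real column: double
```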
SQLConf.scala
@@ -2161,6 +2161,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED =
+    buildConf("spark.sql.legacy.mssqlserver.numericMapping.enabled")
> Member: For such a change, could we have an item in the migration guide?
> dongjoon-hyun (Member, Author), Jan 16, 2020: Sure. Can we handle that after releasing 2.4.5, since it's a documentation change?
> Member: Thank you!

+      .internal()
+      .doc("When true, use legacy MsSqlServer SMALLINT and REAL type mapping.")
> Member: It is not clear what the legacy MsSqlServer SMALLINT and REAL type mapping is for readers who have not seen this PR. Could we make it explicit: the behavior when the value is set to true, and the behavior when it is false?
> dongjoon-hyun (Member, Author): Of course, I'd like to cover it fully in the migration guide.

+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -2417,6 +2424,9 @@ class SQLConf extends Serializable with Logging {
 
   def addDirectoryRecursiveEnabled: Boolean = getConf(LEGACY_ADD_DIRECTORY_USING_RECURSIVE)
 
+  def legacyMsSqlServerNumericMappingEnabled: Boolean =
+    getConf(LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
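Note: the conf above follows the standard SQLConf pattern, a ConfigEntry built with the ConfigBuilder DSL plus a typed accessor. A sketch of the same DSL with a hypothetical key, to make the pieces above easier to read:

```scala
// Hypothetical example of the ConfigBuilder DSL; the key, doc text, and
// accessor name are made up for illustration.
val MY_LEGACY_FLAG = buildConf("spark.sql.legacy.myFeature.enabled")
  .internal()                  // excluded from the generated configuration docs
  .doc("When true, restore the pre-upgrade behavior of myFeature.")
  .booleanConf                 // typed as Boolean
  .createWithDefault(false)    // new behavior is the default

// A typed accessor on the SQLConf class then reads it for the current session:
// def myLegacyFlagEnabled: Boolean = getConf(MY_LEGACY_FLAG)
```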
MsSqlServerDialect.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
 
 import java.util.Locale
 
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 
@@ -33,10 +34,14 @@ private object MsSqlServerDialect extends JdbcDialect {
       // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients
       Option(StringType)
     } else {
-      sqlType match {
-        case java.sql.Types.SMALLINT => Some(ShortType)
-        case java.sql.Types.REAL => Some(FloatType)
-        case _ => None
+      if (SQLConf.get.legacyMsSqlServerNumericMappingEnabled) {
+        None
+      } else {
+        sqlType match {
+          case java.sql.Types.SMALLINT => Some(ShortType)
+          case java.sql.Types.REAL => Some(FloatType)
+          case _ => None
+        }
       }
     }
@@ -46,7 +51,8 @@ private object MsSqlServerDialect extends JdbcDialect {
     case StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR))
     case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT))
     case BinaryType => Some(JdbcType("VARBINARY(MAX)", java.sql.Types.VARBINARY))
-    case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
+    case ShortType if !SQLConf.get.legacyMsSqlServerNumericMappingEnabled =>
+      Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
     case _ => None
 
 
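Note: MsSqlServerDialect is one instance of Spark's pluggable JdbcDialect API, and the two hooks changed in this PR are the same ones a third-party dialect would override. A minimal sketch of a custom dialect; the object name and URL prefix are made up for illustration:

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._

object MyDialect extends JdbcDialect {
  // Claim the JDBC URLs this dialect should handle.
  override def canHandle(url: String): Boolean =
    url.toLowerCase(java.util.Locale.ROOT).startsWith("jdbc:mydb")

  // Map database types to Catalyst types on read; returning None falls
  // back to Spark's generic JDBC mapping.
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
    sqlType match {
      case java.sql.Types.SMALLINT => Some(ShortType)
      case _ => None
    }

  // Map Catalyst types to database column definitions on write.
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
    case _ => None
  }
}

// Register the dialect so spark.read.jdbc picks it up for matching URLs.
JdbcDialects.registerDialect(MyDialect)
```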
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala (26 additions, 6 deletions)
@@ -884,17 +884,37 @@ class JDBCSuite extends QueryTest
"BIT")
assert(msSqlServerDialect.getJDBCType(BinaryType).map(_.databaseTypeDefinition).get ==
"VARBINARY(MAX)")
assert(msSqlServerDialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get ==
"SMALLINT")
+    Seq(true, false).foreach { flag =>
+      withSQLConf(SQLConf.LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED.key -> s"$flag") {
+        if (SQLConf.get.legacyMsSqlServerNumericMappingEnabled) {
+          assert(msSqlServerDialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).isEmpty)
+        } else {
+          assert(msSqlServerDialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get ==
+            "SMALLINT")
+        }
+      }
+    }
   }
 
   test("SPARK-28152 MsSqlServerDialect catalyst type mapping") {
     val msSqlServerDialect = JdbcDialects.get("jdbc:sqlserver")
     val metadata = new MetadataBuilder().putLong("scale", 1)
-    assert(msSqlServerDialect.getCatalystType(java.sql.Types.SMALLINT, "SMALLINT", 1,
-      metadata).get == ShortType)
-    assert(msSqlServerDialect.getCatalystType(java.sql.Types.REAL, "REAL", 1,
-      metadata).get == FloatType)
+
+    Seq(true, false).foreach { flag =>
+      withSQLConf(SQLConf.LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED.key -> s"$flag") {
+        if (SQLConf.get.legacyMsSqlServerNumericMappingEnabled) {
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.SMALLINT, "SMALLINT", 1,
+            metadata).isEmpty)
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.REAL, "REAL", 1,
+            metadata).isEmpty)
+        } else {
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.SMALLINT, "SMALLINT", 1,
+            metadata).get == ShortType)
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.REAL, "REAL", 1,
+            metadata).get == FloatType)
+        }
+      }
+    }
   }
 
   test("table exists query by jdbc dialect") {
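Note: the Seq(true, false).foreach plus withSQLConf pattern used in both test files is the standard way to exercise a flag in each state. The real helper lives in Spark's test utilities; this simplified sketch of its shape is an assumption, shown only to explain what the pattern guarantees:

```scala
import org.apache.spark.sql.internal.SQLConf

// Simplified sketch: set the given conf pairs, run the body, then restore
// the previous values even if the body throws an exception.
def withSQLConf(pairs: (String, String)*)(body: => Unit): Unit = {
  val conf = SQLConf.get
  // Remember each key's prior value (None if it was unset).
  val previous = pairs.map { case (key, _) =>
    key -> (if (conf.contains(key)) Some(conf.getConfString(key)) else None)
  }
  pairs.foreach { case (key, value) => conf.setConfString(key, value) }
  try body finally previous.foreach {
    case (key, Some(old)) => conf.setConfString(key, old)
    case (key, None) => conf.unsetConf(key)
  }
}
```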