diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index efd7ca74c796b..5738307095933 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -21,6 +21,7 @@ import java.math.BigDecimal
 import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
 
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -112,36 +113,58 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
   }
 
   test("Numeric types") {
-    val df = spark.read.jdbc(jdbcUrl, "numbers", new Properties)
-    val rows = df.collect()
-    assert(rows.length == 1)
-    val row = rows(0)
-    val types = row.toSeq.map(x => x.getClass.toString)
-    assert(types.length == 12)
-    assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
-    assert(types(2).equals("class java.lang.Short"))
-    assert(types(3).equals("class java.lang.Integer"))
-    assert(types(4).equals("class java.lang.Long"))
-    assert(types(5).equals("class java.lang.Double"))
-    assert(types(6).equals("class java.lang.Float"))
-    assert(types(7).equals("class java.lang.Float"))
-    assert(types(8).equals("class java.math.BigDecimal"))
-    assert(types(9).equals("class java.math.BigDecimal"))
-    assert(types(10).equals("class java.math.BigDecimal"))
-    assert(types(11).equals("class java.math.BigDecimal"))
-    assert(row.getBoolean(0) == false)
-    assert(row.getInt(1) == 255)
-    assert(row.getShort(2) == 32767)
-    assert(row.getInt(3) == 2147483647)
-    assert(row.getLong(4) == 9223372036854775807L)
-    assert(row.getDouble(5) == 1.2345678901234512E14) // float = float(53) has 15-digits precision
-    assert(row.getFloat(6) == 1.23456788103168E14) // float(24) has 7-digits precision
-    assert(row.getFloat(7) == 1.23456788103168E14) // real = float(24)
-    assert(row.getAs[BigDecimal](8).equals(new BigDecimal("123.00")))
-    assert(row.getAs[BigDecimal](9).equals(new BigDecimal("12345.12000")))
-    assert(row.getAs[BigDecimal](10).equals(new BigDecimal("922337203685477.5800")))
-    assert(row.getAs[BigDecimal](11).equals(new BigDecimal("214748.3647")))
+    Seq(true, false).foreach { flag =>
+      withSQLConf(SQLConf.LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED.key -> s"$flag") {
+        val df = spark.read.jdbc(jdbcUrl, "numbers", new Properties)
+        val rows = df.collect()
+        assert(rows.length == 1)
+        val row = rows(0)
+        val types = row.toSeq.map(x => x.getClass.toString)
+        assert(types.length == 12)
+        assert(types(0).equals("class java.lang.Boolean"))
+        assert(types(1).equals("class java.lang.Integer"))
+        if (flag) {
+          assert(types(2).equals("class java.lang.Integer"))
+        } else {
+          assert(types(2).equals("class java.lang.Short"))
+        }
+        assert(types(3).equals("class java.lang.Integer"))
+        assert(types(4).equals("class java.lang.Long"))
+        assert(types(5).equals("class java.lang.Double"))
+        if (flag) {
+          assert(types(6).equals("class java.lang.Double"))
+          assert(types(7).equals("class java.lang.Double"))
+        } else {
+          assert(types(6).equals("class java.lang.Float"))
+          assert(types(7).equals("class java.lang.Float"))
+        }
+        assert(types(8).equals("class java.math.BigDecimal"))
+        assert(types(9).equals("class java.math.BigDecimal"))
+        assert(types(10).equals("class java.math.BigDecimal"))
+        assert(types(11).equals("class java.math.BigDecimal"))
+        assert(row.getBoolean(0) == false)
+        assert(row.getInt(1) == 255)
+        if (flag) {
+          assert(row.getInt(2) == 32767)
+        } else {
+          assert(row.getShort(2) == 32767)
+        }
+        assert(row.getInt(3) == 2147483647)
+        assert(row.getLong(4) == 9223372036854775807L)
+        assert(row.getDouble(5) == 1.2345678901234512E14) // float(53) has 15-digits precision
+        if (flag) {
+          assert(row.getDouble(6) == 1.23456788103168E14) // float(24) has 7-digits precision
+          assert(row.getDouble(7) == 1.23456788103168E14) // real = float(24)
+        } else {
+          assert(row.getFloat(6) == 1.23456788103168E14) // float(24) has 7-digits precision
+          assert(row.getFloat(7) == 1.23456788103168E14) // real = float(24)
+        }
+        assert(row.getAs[BigDecimal](8).equals(new BigDecimal("123.00")))
+        assert(row.getAs[BigDecimal](9).equals(new BigDecimal("12345.12000")))
+        assert(row.getAs[BigDecimal](10).equals(new BigDecimal("922337203685477.5800")))
+        assert(row.getAs[BigDecimal](11).equals(new BigDecimal("214748.3647")))
+      }
+    }
   }
 
   test("Date types") {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 41a429303efb4..06fd7a79c4c93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2161,6 +2161,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED =
+    buildConf("spark.sql.legacy.mssqlserver.numericMapping.enabled")
+      .internal()
+      .doc("When true, use legacy MsSqlServer SMALLINT and REAL type mapping.")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -2417,6 +2424,9 @@ class SQLConf extends Serializable with Logging {
 
   def addDirectoryRecursiveEnabled: Boolean = getConf(LEGACY_ADD_DIRECTORY_USING_RECURSIVE)
 
+  def legacyMsSqlServerNumericMappingEnabled: Boolean =
+    getConf(LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
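
Reviewer note (not part of the patch): the new flag only changes the schema that spark.read.jdbc infers for SQL Server SMALLINT and REAL columns. A minimal usage sketch, assuming a reachable SQL Server with the mssql-jdbc driver on the classpath and a hypothetical table "tbl" holding such columns:

// Sketch only; the URL and table name "tbl" are hypothetical.
import java.util.Properties
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("mssql-mapping-sketch").master("local[*]").getOrCreate()
val url = "jdbc:sqlserver://localhost:1433;databaseName=test"

// Default (flag off): the dialect maps SMALLINT -> ShortType, REAL -> FloatType.
spark.conf.set("spark.sql.legacy.mssqlserver.numericMapping.enabled", "false")
spark.read.jdbc(url, "tbl", new Properties).printSchema()

// Legacy (flag on): the dialect returns None, so the generic JdbcUtils mapping
// applies instead: SMALLINT -> IntegerType, REAL -> DoubleType.
spark.conf.set("spark.sql.legacy.mssqlserver.numericMapping.enabled", "true")
spark.read.jdbc(url, "tbl", new Properties).printSchema()

The Integer/Double results under the flag are the generic JDBC fallback, which is what the dialect change below opts back into when the legacy flag is set.
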
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 2511067abc3fd..72284b5996201 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
 
 import java.util.Locale
 
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 
@@ -33,10 +34,14 @@ private object MsSqlServerDialect extends JdbcDialect {
       // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients
       Option(StringType)
     } else {
-      sqlType match {
-        case java.sql.Types.SMALLINT => Some(ShortType)
-        case java.sql.Types.REAL => Some(FloatType)
-        case _ => None
+      if (SQLConf.get.legacyMsSqlServerNumericMappingEnabled) {
+        None
+      } else {
+        sqlType match {
+          case java.sql.Types.SMALLINT => Some(ShortType)
+          case java.sql.Types.REAL => Some(FloatType)
+          case _ => None
+        }
       }
     }
   }
@@ -46,7 +51,8 @@ private object MsSqlServerDialect extends JdbcDialect {
     case StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR))
     case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT))
    case BinaryType => Some(JdbcType("VARBINARY(MAX)", java.sql.Types.VARBINARY))
-    case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
+    case ShortType if !SQLConf.get.legacyMsSqlServerNumericMappingEnabled =>
+      Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
     case _ => None
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 66ddc6ee83d06..9cba95f7d7df2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -884,17 +884,37 @@ class JDBCSuite extends QueryTest
       "BIT")
     assert(msSqlServerDialect.getJDBCType(BinaryType).map(_.databaseTypeDefinition).get ==
       "VARBINARY(MAX)")
-    assert(msSqlServerDialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get ==
-      "SMALLINT")
+    Seq(true, false).foreach { flag =>
+      withSQLConf(SQLConf.LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED.key -> s"$flag") {
+        if (SQLConf.get.legacyMsSqlServerNumericMappingEnabled) {
+          assert(msSqlServerDialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).isEmpty)
+        } else {
+          assert(msSqlServerDialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get ==
+            "SMALLINT")
+        }
+      }
+    }
   }
 
   test("SPARK-28152 MsSqlServerDialect catalyst type mapping") {
     val msSqlServerDialect = JdbcDialects.get("jdbc:sqlserver")
     val metadata = new MetadataBuilder().putLong("scale", 1)
-    assert(msSqlServerDialect.getCatalystType(java.sql.Types.SMALLINT, "SMALLINT", 1,
-      metadata).get == ShortType)
-    assert(msSqlServerDialect.getCatalystType(java.sql.Types.REAL, "REAL", 1,
-      metadata).get == FloatType)
+
+    Seq(true, false).foreach { flag =>
+      withSQLConf(SQLConf.LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED.key -> s"$flag") {
+        if (SQLConf.get.legacyMsSqlServerNumericMappingEnabled) {
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.SMALLINT, "SMALLINT", 1,
+            metadata).isEmpty)
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.REAL, "REAL", 1,
+            metadata).isEmpty)
+        } else {
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.SMALLINT, "SMALLINT", 1,
+            metadata).get == ShortType)
+          assert(msSqlServerDialect.getCatalystType(java.sql.Types.REAL, "REAL", 1,
+            metadata).get == FloatType)
+        }
+      }
+    }
   }
 
   test("table exists query by jdbc dialect") {