diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index f0991f1927985..6769773cfec45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -108,10 +108,11 @@ object TableOutputResolver {
       case StoreAssignmentPolicy.LEGACY =>
         outputField
 
-      case StoreAssignmentPolicy.STRICT =>
+      case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
         // run the type check first to ensure type errors are present
         val canWrite = DataType.canWrite(
-          queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name, addError)
+          queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name,
+          storeAssignmentPolicy, addError)
         if (queryExpr.nullable && !tableAttr.nullable) {
           addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'")
           None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 32e2707948919..549d47a9226c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -158,6 +158,36 @@ object Cast {
     case _ => false
   }
 
+  def canANSIStoreAssign(from: DataType, to: DataType): Boolean = (from, to) match {
+    case _ if from == to => true
+    case (_: NumericType, _: NumericType) => true
+    case (_: AtomicType, StringType) => true
+    case (_: CalendarIntervalType, StringType) => true
+    case (DateType, TimestampType) => true
+    case (TimestampType, DateType) => true
+    // Spark supports casting between long and timestamp, please see `longToTimestamp` and
+    // `timestampToLong` for details.
+    case (TimestampType, LongType) => true
+    case (LongType, TimestampType) => true
+
+    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+      resolvableNullability(fn, tn) && canANSIStoreAssign(fromType, toType)
+
+    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+      resolvableNullability(fn, tn) && canANSIStoreAssign(fromKey, toKey) &&
+        canANSIStoreAssign(fromValue, toValue)
+
+    case (StructType(fromFields), StructType(toFields)) =>
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).forall {
+          case (f1, f2) =>
+            resolvableNullability(f1.nullable, f2.nullable) &&
+              canANSIStoreAssign(f1.dataType, f2.dataType)
+        }
+
+    case _ => false
+  }
+
   private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = {
     val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
     val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 80697cb76aec0..3d5e4597d8d9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1644,14 +1644,15 @@ object SQLConf {
       .createWithDefault(PartitionOverwriteMode.STATIC.toString)
 
   object StoreAssignmentPolicy extends Enumeration {
-    val LEGACY, STRICT = Value
+    val ANSI, LEGACY, STRICT = Value
   }
 
   val STORE_ASSIGNMENT_POLICY =
     buildConf("spark.sql.storeAssignmentPolicy")
       .doc("When inserting a value into a column with different data type, Spark will perform " +
-        "type coercion. Currently we support 2 policies for the type coercion rules: legacy and " +
-        "strict. With legacy policy, Spark allows casting any value to any data type. " +
+        "type coercion. Currently we support 3 policies for the type coercion rules: ansi, " +
+        "legacy and strict. With ansi policy, Spark performs the type coercion as per ANSI SQL. " +
+        "With legacy policy, Spark allows casting any value to any data type. " +
         "The legacy policy is the only behavior in Spark 2.x and it is compatible with Hive. " +
         "With strict policy, Spark doesn't allow any possible precision loss or data truncation " +
         "in type coercion, e.g. `int` to `long` and `float` to `double` are not allowed."
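
(Aside, illustrative rather than part of the patch.) A minimal sketch of how the three `spark.sql.storeAssignmentPolicy` values compare on a few atomic types, assuming a Catalyst classpath where the existing `Cast.canUpCast` and the `Cast.canANSIStoreAssign` added above are both visible; LEGACY performs no check at all:

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// STRICT is backed by Cast.canUpCast: only lossless widenings pass.
assert(Cast.canUpCast(IntegerType, LongType))
assert(!Cast.canUpCast(LongType, IntegerType))

// ANSI is backed by the new Cast.canANSIStoreAssign: numeric-to-numeric stores are
// accepted even when they may lose precision, and so are date/timestamp conversions,
// while unreasonable conversions such as string-to-int are rejected.
assert(Cast.canANSIStoreAssign(LongType, IntegerType))
assert(Cast.canANSIStoreAssign(TimestampType, DateType))
assert(!Cast.canANSIStoreAssign(StringType, IntegerType))
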
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index a35e971d08823..3a10a56f6937f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT}
 import org.apache.spark.util.Utils
 
 /**
@@ -371,12 +373,14 @@ object DataType {
       byName: Boolean,
       resolver: Resolver,
       context: String,
+      storeAssignmentPolicy: StoreAssignmentPolicy.Value,
       addError: String => Unit): Boolean = {
     (write, read) match {
       case (wArr: ArrayType, rArr: ArrayType) =>
         // run compatibility check first to produce all error messages
         val typesCompatible = canWrite(
-          wArr.elementType, rArr.elementType, byName, resolver, context + ".element", addError)
+          wArr.elementType, rArr.elementType, byName, resolver, context + ".element",
+          storeAssignmentPolicy, addError)
 
         if (wArr.containsNull && !rArr.containsNull) {
           addError(s"Cannot write nullable elements to array of non-nulls: '$context'")
@@ -391,9 +395,11 @@ object DataType {
 
         // run compatibility check first to produce all error messages
         val keyCompatible = canWrite(
-          wMap.keyType, rMap.keyType, byName, resolver, context + ".key", addError)
+          wMap.keyType, rMap.keyType, byName, resolver, context + ".key",
+          storeAssignmentPolicy, addError)
         val valueCompatible = canWrite(
-          wMap.valueType, rMap.valueType, byName, resolver, context + ".value", addError)
+          wMap.valueType, rMap.valueType, byName, resolver, context + ".value",
+          storeAssignmentPolicy, addError)
 
         if (wMap.valueContainsNull && !rMap.valueContainsNull) {
           addError(s"Cannot write nullable values to map of non-nulls: '$context'")
@@ -409,7 +415,8 @@ object DataType {
           val nameMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name)
           val fieldContext = s"$context.${rField.name}"
           val typesCompatible = canWrite(
-            wField.dataType, rField.dataType, byName, resolver, fieldContext, addError)
+            wField.dataType, rField.dataType, byName, resolver, fieldContext,
+            storeAssignmentPolicy, addError)
 
           if (byName && !nameMatch) {
             addError(s"Struct '$context' $i-th field name does not match " +
@@ -441,7 +448,7 @@ object DataType {
 
           fieldCompatible
 
-      case (w: AtomicType, r: AtomicType) =>
+      case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == STRICT =>
         if (!Cast.canUpCast(w, r)) {
           addError(s"Cannot safely cast '$context': $w to $r")
           false
@@ -449,6 +456,14 @@ object DataType {
           true
         }
 
+      case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI =>
+        if (!Cast.canANSIStoreAssign(w, r)) {
+          addError(s"Cannot safely cast '$context': $w to $r")
+          false
+        } else {
+          true
+        }
+
       case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] =>
         true
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala
index c757015c754b7..eade9b6112fe4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types._ -class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { AppendData.byName(table, query) } @@ -37,7 +37,17 @@ class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { } } -class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byPosition(table, query) + } +} + +class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwritePartitionsDynamic.byName(table, query) } @@ -47,7 +57,17 @@ class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuit } } -class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byPosition(table, query) + } +} + +class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwriteByExpression.byName(table, query, Literal(true)) } @@ -104,6 +124,12 @@ class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { } } +class V2OverwriteByExpressionStrictAnalysisSuite extends V2OverwriteByExpressionANSIAnalysisSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) + .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT) +} + case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "table-name" } @@ -114,12 +140,85 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) override def skipSchemaResolution: Boolean = true } -abstract class DataSourceV2AnalysisSuite extends AnalysisTest { +abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) + .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.ANSI) +} - override def getAnalyzer(caseSensitive: Boolean): Analyzer = { - val conf = new SQLConf() - .copy(SQLConf.CASE_SENSITIVE -> caseSensitive) +abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT) + + test("byName: fail canWrite check") { + val parsedPlan = byName(table, widerTable) + + assertNotResolved(parsedPlan) + 
assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("byName: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = byName(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot safely cast", "'x'", "DoubleType to FloatType", + "Cannot write nullable values to non-null column", "'x'", + "Cannot find data for output column", "'y'")) + } + + + test("byPosition: fail canWrite check") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val parsedPlan = byPosition(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("byPosition: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = byPosition(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", + "Cannot safely cast", "'x'", "DoubleType to FloatType")) + } +} + +abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { + + protected def getSQLConf(caseSensitive: Boolean): SQLConf = + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + + override def getAnalyzer(caseSensitive: Boolean): Analyzer = { + val conf = getSQLConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf) catalog.createDatabase( CatalogDatabase("default", "", new URI("loc"), Map.empty), @@ -254,15 +353,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("byName: fail canWrite check") { - val parsedPlan = byName(table, widerTable) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write", "'table-name'", - "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) - } - test("byName: insert safe cast") { val x = table.output.head val y = table.output.last @@ -294,25 +384,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } - test("byName: multiple field errors are reported") { - val xRequiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", DoubleType))).toAttributes) - - val query = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("b", FloatType))).toAttributes) - - val parsedPlan = byName(xRequiredTable, query) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write incompatible data to table", "'table-name'", - "Cannot safely cast", "'x'", "DoubleType to FloatType", - "Cannot write nullable values to 
non-null column", "'x'", - "Cannot find data for output column", "'y'")) - } - test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), @@ -396,19 +467,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("byPosition: fail canWrite check") { - val widerTable = TestRelation(StructType(Seq( - StructField("a", DoubleType), - StructField("b", DoubleType))).toAttributes) - - val parsedPlan = byPosition(table, widerTable) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write", "'table-name'", - "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) - } - test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), @@ -444,24 +502,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'a', 'b', 'c'")) } - test("byPosition: multiple field errors are reported") { - val xRequiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", DoubleType))).toAttributes) - - val query = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("b", FloatType))).toAttributes) - - val parsedPlan = byPosition(xRequiredTable, query) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write incompatible data to table", "'table-name'", - "Cannot write nullable values to non-null column", "'x'", - "Cannot safely cast", "'x'", "DoubleType to FloatType")) - } - test("bypass output column resolution") { val table = TestRelationAcceptAnySchema(StructType(Seq( StructField("a", FloatType, nullable = false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 6b5fc5f0d4434..784cc7a70489f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -22,20 +22,136 @@ import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy -class DataTypeWriteCompatibilitySuite extends SparkFunSuite { - private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, - DoubleType, DateType, TimestampType, StringType, BinaryType) +class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite { + override def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = + StoreAssignmentPolicy.STRICT - private val point2 = StructType(Seq( + override def canCast: (DataType, DataType) => Boolean = Cast.canUpCast + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(widerPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfLong = ArrayType(LongType) + 
val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfLong, arrayOfInt, "arr", + "Should not allow array of longs to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map value types: casting Long to Integer is not allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfLong, mapOfInt, "m", + "Should not allow map of longs to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyLong, mapKeyInt, "m", + "Should not allow map of long keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } +} + +class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite { + override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = + StoreAssignmentPolicy.ANSI + + override def canCast: (DataType, DataType) => Boolean = Cast.canANSIStoreAssign + + test("Check map value types: unsafe casts are not allowed") { + val mapOfString = MapType(StringType, StringType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfString, mapOfInt, "m", + "Should not allow map of strings to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + private val stringPoint2 = StructType(Seq( + StructField("x", StringType, nullable = false), + StructField("y", StringType, nullable = false))) + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(stringPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfString = ArrayType(StringType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfString, arrayOfInt, "arr", + "Should not allow array of strings to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyString = MapType(StringType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyString, mapKeyInt, "m", + "Should not allow map of string keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } +} + +abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { + protected def storeAssignmentPolicy: StoreAssignmentPolicy.Value + + protected def canCast: (DataType, DataType) => Boolean + + protected val 
atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, TimestampType, StringType, BinaryType) + + protected val point2 = StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", FloatType, nullable = false))) - private val widerPoint2 = StructType(Seq( + protected val widerPoint2 = StructType(Seq( StructField("x", DoubleType, nullable = false), StructField("y", DoubleType, nullable = false))) - private val point3 = StructType(Seq( + protected val point3 = StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", FloatType, nullable = false), StructField("z", FloatType))) @@ -67,7 +183,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { test("Check atomic types: write allowed only when casting is safe") { atomicTypes.foreach { w => atomicTypes.foreach { r => - if (Cast.canUpCast(w, r)) { + if (canCast(w, r)) { assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") } else { @@ -172,18 +288,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { } } - test("Check struct types: unsafe casts are not allowed") { - assertNumErrors(widerPoint2, point2, "t", - "Should fail because types require unsafe casts", 2) { errs => - - assert(errs(0).contains("'t.x'"), "Should include the nested field name context") - assert(errs(0).contains("Cannot safely cast")) - - assert(errs(1).contains("'t.y'"), "Should include the nested field name context") - assert(errs(1).contains("Cannot safely cast")) - } - } - test("Check struct types: type promotion is allowed") { assertAllowed(point2, widerPoint2, "t", "Should allow widening float fields x and y to double") @@ -203,18 +307,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow writing point (x,y) to point(x,y,z=null)") } - test("Check array types: unsafe casts are not allowed") { - val arrayOfLong = ArrayType(LongType) - val arrayOfInt = ArrayType(IntegerType) - - assertSingleError(arrayOfLong, arrayOfInt, "arr", - "Should not allow array of longs to array of ints") { err => - assert(err.contains("'arr.element'"), - "Should identify problem with named array's element type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check array types: type promotion is allowed") { val arrayOfLong = ArrayType(LongType) val arrayOfInt = ArrayType(IntegerType) @@ -241,17 +333,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow array of required elements to array of optional elements") } - test("Check map value types: unsafe casts are not allowed") { - val mapOfLong = MapType(StringType, LongType) - val mapOfInt = MapType(StringType, IntegerType) - - assertSingleError(mapOfLong, mapOfInt, "m", - "Should not allow map of longs to map of ints") { err => - assert(err.contains("'m.value'"), "Should identify problem with named map's value type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check map value types: type promotion is allowed") { val mapOfLong = MapType(StringType, LongType) val mapOfInt = MapType(StringType, IntegerType) @@ -278,17 +359,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow map of required elements to map of optional elements") } - test("Check map key types: unsafe casts are not allowed") { - val mapKeyLong = MapType(LongType, StringType) - val mapKeyInt = MapType(IntegerType, StringType) - - assertSingleError(mapKeyLong, mapKeyInt, "m", - "Should not allow map of long keys to map of int keys") 
{ err => - assert(err.contains("'m.key'"), "Should identify problem with named map's key type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check map key types: type promotion is allowed") { val mapKeyLong = MapType(LongType, StringType) val mapKeyInt = MapType(IntegerType, StringType) @@ -317,9 +387,9 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { StructField("a", ArrayType(StringType)), StructField("arr_of_structs", ArrayType(point3)), StructField("bad_nested_type", point3), - StructField("m", MapType(DoubleType, DoubleType)), + StructField("m", MapType(StringType, BooleanType)), StructField("map_of_structs", MapType(StringType, missingMiddleField)), - StructField("y", LongType) + StructField("y", StringType) )) assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => @@ -342,11 +412,11 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { assert(errs(5).contains("'top.m.key'"), "Should identify bad type") assert(errs(5).contains("Cannot safely cast")) - assert(errs(5).contains("DoubleType to LongType")) + assert(errs(5).contains("StringType to LongType")) assert(errs(6).contains("'top.m.value'"), "Should identify bad type") assert(errs(6).contains("Cannot safely cast")) - assert(errs(6).contains("DoubleType to FloatType")) + assert(errs(6).contains("BooleanType to FloatType")) assert(errs(7).contains("'top.m'"), "Should identify bad type") assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) @@ -364,7 +434,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { assert(errs(11).contains("'top.x'"), "Should identify bad type") assert(errs(11).contains("Cannot safely cast")) - assert(errs(11).contains("LongType to IntegerType")) + assert(errs(11).contains("StringType to IntegerType")) assert(errs(12).contains("'top'"), "Should identify bad type") assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") @@ -386,6 +456,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { byName: Boolean = true): Unit = { assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, + storeAssignmentPolicy, errMsg => fail(s"Should not produce errors but was called with: $errMsg")), desc) } @@ -411,7 +482,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { val errs = new mutable.ArrayBuffer[String]() assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, - errMsg => errs += errMsg) === false, desc) + storeAssignmentPolicy, errMsg => errs += errMsg) === false, desc) assert(errs.size === numErrs, s"Should produce $numErrs error messages") checkErrors(errs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index a55aa7b28ce23..09b99508c1182 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.io.File +import java.sql.Date import org.apache.spark.SparkException import org.apache.spark.sql._ @@ -582,6 +583,57 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } } + test("Throw exception on unsafe cast with ANSI casting policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + 
withTable("t") { + sql("create table t(i int, d double) using parquet") + var msg = intercept[AnalysisException] { + sql("insert into t values('a', 'b')") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") && + msg.contains("Cannot safely cast 'd': StringType to DoubleType")) + msg = intercept[AnalysisException] { + sql("insert into t values(now(), now())") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': TimestampType to IntegerType") && + msg.contains("Cannot safely cast 'd': TimestampType to DoubleType")) + msg = intercept[AnalysisException] { + sql("insert into t values(true, false)") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") && + msg.contains("Cannot safely cast 'd': BooleanType to DoubleType")) + } + } + } + + test("Allow on writing any numeric value to numeric type with ANSI policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + withTable("t") { + sql("create table t(i int, d float) using parquet") + sql("insert into t values(1L, 2.0)") + sql("insert into t values(3.0, 4)") + sql("insert into t values(5.0, 6L)") + checkAnswer(sql("select * from t"), Seq(Row(1, 2.0F), Row(3, 4.0F), Row(5, 6.0F))) + } + } + } + + test("Allow on writing timestamp value to date type with ANSI policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + withTable("t") { + sql("create table t(i date) using parquet") + sql("insert into t values(TIMESTAMP('2010-09-02 14:10:10'))") + checkAnswer(sql("select * from t"), Seq(Row(Date.valueOf("2010-09-02")))) + } + } + } + test("SPARK-24860: dynamic partition overwrite specified per source without catalog table") { withTempPath { path => Seq((1, 1), (2, 2)).toDF("i", "part") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index d37e53bc5ac08..922c0ee4525d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -327,6 +327,28 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with } } + test("Throw exception on unsafe cast with ANSI casting policy") { + withSQLConf( + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + withTable("t") { + sql("create table t(i int, d double) using parquet") + // Calling `saveAsTable` to an existing table with append mode results in table insertion. + var msg = intercept[AnalysisException] { + Seq(("a", "b")).toDF("i", "d").write.mode("append").saveAsTable("t") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") && + msg.contains("Cannot safely cast 'd': StringType to DoubleType")) + + msg = intercept[AnalysisException] { + Seq((true, false)).toDF("i", "d").write.mode("append").saveAsTable("t") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") && + msg.contains("Cannot safely cast 'd': BooleanType to DoubleType")) + } + } + } + test("test path option in load") { spark.read .format("org.apache.spark.sql.test")