From 39c331e5951ddef1829f7bf061e2e141ff4c19b1 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Mon, 26 Aug 2019 19:27:13 +0800
Subject: [PATCH 1/6] ANSI mode

---
 .../analysis/TableOutputResolver.scala        |   5 +-
 .../spark/sql/catalyst/expressions/Cast.scala |  29 +++
 .../apache/spark/sql/internal/SQLConf.scala   |   7 +-
 .../org/apache/spark/sql/types/DataType.scala |  25 ++-
 .../analysis/DataSourceV2AnalysisSuite.scala  | 172 +++++++++++-------
 .../DataTypeWriteCompatibilitySuite.scala     |   4 +-
 .../spark/sql/sources/InsertSuite.scala       |  32 ++++
 .../sql/test/DataFrameReaderWriterSuite.scala |  15 ++
 8 files changed, 212 insertions(+), 77 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index f0991f1927985..6769773cfec45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -108,10 +108,11 @@ object TableOutputResolver {
       case StoreAssignmentPolicy.LEGACY =>
         outputField

-      case StoreAssignmentPolicy.STRICT =>
+      case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
         // run the type check first to ensure type errors are present
         val canWrite = DataType.canWrite(
-          queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name, addError)
+          queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, tableAttr.name,
+          storeAssignmentPolicy, addError)
         if (queryExpr.nullable && !tableAttr.nullable) {
           addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'")
           None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 32e2707948919..cb5cb72dc5a6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -158,6 +158,35 @@ object Cast {
     case _ => false
   }

+  def canANSIStoreAssign(from: DataType, to: DataType): Boolean = (from, to) match {
+    case _ if from == to => true
+    case (_: NumericType, _: NumericType) => true
+    case (_, StringType) => true
+    case (DateType, TimestampType) => true
+    case (TimestampType, DateType) => true
+    // Spark supports casting between long and timestamp, please see `longToTimestamp` and
+    // `timestampToLong` for details.
+ case (TimestampType, LongType) => true + case (LongType, TimestampType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + resolvableNullability(fn, tn) && canANSIStoreAssign(fromType, toType) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + resolvableNullability(fn, tn) && canANSIStoreAssign(fromKey, toKey) && + canANSIStoreAssign(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (f1, f2) => + resolvableNullability(f1.nullable, f2.nullable) && + canANSIStoreAssign(f1.dataType, f2.dataType) + } + + case _ => false + } + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80697cb76aec0..3d5e4597d8d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1644,14 +1644,15 @@ object SQLConf { .createWithDefault(PartitionOverwriteMode.STATIC.toString) object StoreAssignmentPolicy extends Enumeration { - val LEGACY, STRICT = Value + val ANSI, LEGACY, STRICT = Value } val STORE_ASSIGNMENT_POLICY = buildConf("spark.sql.storeAssignmentPolicy") .doc("When inserting a value into a column with different data type, Spark will perform " + - "type coercion. Currently we support 2 policies for the type coercion rules: legacy and " + - "strict. With legacy policy, Spark allows casting any value to any data type. " + + "type coercion. Currently we support 3 policies for the type coercion rules: ansi, " + + "legacy and strict. With ansi policy, Spark performs the type coercion as per ANSI SQL. " + + "With legacy policy, Spark allows casting any value to any data type. " + "The legacy policy is the only behavior in Spark 2.x and it is compatible with Hive. " + "With strict policy, Spark doesn't allow any possible precision loss or data truncation " + "in type coercion, e.g. `int` to `long` and `float` to `double` are not allowed." 
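
To make the three policies concrete, here is a short spark-shell sketch (illustrative only, not part of the patch series; it assumes a build that includes this change, and the error messages in the comments are abbreviated forms of those asserted in the tests below):

```scala
// How the same inserts behave under each spark.sql.storeAssignmentPolicy value.
// Assumes a SparkSession `spark`; the policy is a runtime SQLConf, so it can be
// switched between statements.
spark.sql("CREATE TABLE target(i INT) USING parquet")

// LEGACY: any cast is allowed, matching Spark 2.x behavior; the unparsable
// string silently becomes NULL on write.
spark.conf.set("spark.sql.storeAssignmentPolicy", "LEGACY")
spark.sql("INSERT INTO target VALUES ('a')")    // succeeds, stores NULL

// ANSI: string-to-int is not a valid ANSI store assignment, so analysis fails,
// but numeric-to-numeric writes (e.g. bigint to int) are still allowed.
spark.conf.set("spark.sql.storeAssignmentPolicy", "ANSI")
// spark.sql("INSERT INTO target VALUES ('a')") // AnalysisException: cannot cast
spark.sql("INSERT INTO target VALUES (1L)")     // succeeds

// STRICT: only up-casts pass (Cast.canUpCast), so even the bigint-to-int write
// that ANSI permits is rejected as potential truncation.
spark.conf.set("spark.sql.storeAssignmentPolicy", "STRICT")
// spark.sql("INSERT INTO target VALUES (1L)")  // AnalysisException: cannot safely cast
```
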
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a35e971d08823..259b047a3f178 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT} import org.apache.spark.util.Utils /** @@ -371,12 +373,14 @@ object DataType { byName: Boolean, resolver: Resolver, context: String, + storeAssignmentPolicy: StoreAssignmentPolicy.Value, addError: String => Unit): Boolean = { (write, read) match { case (wArr: ArrayType, rArr: ArrayType) => // run compatibility check first to produce all error messages val typesCompatible = canWrite( - wArr.elementType, rArr.elementType, byName, resolver, context + ".element", addError) + wArr.elementType, rArr.elementType, byName, resolver, context + ".element", + storeAssignmentPolicy, addError) if (wArr.containsNull && !rArr.containsNull) { addError(s"Cannot write nullable elements to array of non-nulls: '$context'") @@ -391,9 +395,11 @@ object DataType { // run compatibility check first to produce all error messages val keyCompatible = canWrite( - wMap.keyType, rMap.keyType, byName, resolver, context + ".key", addError) + wMap.keyType, rMap.keyType, byName, resolver, context + ".key", + storeAssignmentPolicy, addError) val valueCompatible = canWrite( - wMap.valueType, rMap.valueType, byName, resolver, context + ".value", addError) + wMap.valueType, rMap.valueType, byName, resolver, context + ".value", + storeAssignmentPolicy, addError) if (wMap.valueContainsNull && !rMap.valueContainsNull) { addError(s"Cannot write nullable values to map of non-nulls: '$context'") @@ -409,7 +415,8 @@ object DataType { val nameMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name) val fieldContext = s"$context.${rField.name}" val typesCompatible = canWrite( - wField.dataType, rField.dataType, byName, resolver, fieldContext, addError) + wField.dataType, rField.dataType, byName, resolver, fieldContext, + storeAssignmentPolicy, addError) if (byName && !nameMatch) { addError(s"Struct '$context' $i-th field name does not match " + @@ -441,7 +448,7 @@ object DataType { fieldCompatible - case (w: AtomicType, r: AtomicType) => + case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == STRICT => if (!Cast.canUpCast(w, r)) { addError(s"Cannot safely cast '$context': $w to $r") false @@ -449,6 +456,14 @@ object DataType { true } + case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI => + if (!Cast.canANSIStoreAssign(w, r)) { + addError(s"Cannot cast '$context': $w to $r") + false + } else { + true + } + case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index c757015c754b7..eade9b6112fe4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types._ -class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { AppendData.byName(table, query) } @@ -37,7 +37,17 @@ class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { } } -class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byPosition(table, query) + } +} + +class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwritePartitionsDynamic.byName(table, query) } @@ -47,7 +57,17 @@ class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuit } } -class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { +class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byPosition(table, query) + } +} + +class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwriteByExpression.byName(table, query, Literal(true)) } @@ -104,6 +124,12 @@ class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { } } +class V2OverwriteByExpressionStrictAnalysisSuite extends V2OverwriteByExpressionANSIAnalysisSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) + .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT) +} + case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "table-name" } @@ -114,12 +140,85 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) override def skipSchemaResolution: Boolean = true } -abstract class DataSourceV2AnalysisSuite extends AnalysisTest { +abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) + .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.ANSI) +} - override def getAnalyzer(caseSensitive: Boolean): Analyzer = { - val conf = new SQLConf() - .copy(SQLConf.CASE_SENSITIVE -> caseSensitive) +abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite { + override def getSQLConf(caseSensitive: Boolean): SQLConf = + super.getSQLConf(caseSensitive) .copy(SQLConf.STORE_ASSIGNMENT_POLICY -> StoreAssignmentPolicy.STRICT) + + test("byName: fail canWrite check") { + val parsedPlan = byName(table, widerTable) + + assertNotResolved(parsedPlan) + 
assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("byName: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = byName(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot safely cast", "'x'", "DoubleType to FloatType", + "Cannot write nullable values to non-null column", "'x'", + "Cannot find data for output column", "'y'")) + } + + + test("byPosition: fail canWrite check") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val parsedPlan = byPosition(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("byPosition: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = byPosition(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", + "Cannot safely cast", "'x'", "DoubleType to FloatType")) + } +} + +abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { + + protected def getSQLConf(caseSensitive: Boolean): SQLConf = + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + + override def getAnalyzer(caseSensitive: Boolean): Analyzer = { + val conf = getSQLConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf) catalog.createDatabase( CatalogDatabase("default", "", new URI("loc"), Map.empty), @@ -254,15 +353,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("byName: fail canWrite check") { - val parsedPlan = byName(table, widerTable) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write", "'table-name'", - "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) - } - test("byName: insert safe cast") { val x = table.output.head val y = table.output.last @@ -294,25 +384,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } - test("byName: multiple field errors are reported") { - val xRequiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", DoubleType))).toAttributes) - - val query = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("b", FloatType))).toAttributes) - - val parsedPlan = byName(xRequiredTable, query) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write incompatible data to table", "'table-name'", - "Cannot safely cast", "'x'", "DoubleType to FloatType", - "Cannot write nullable values to 
non-null column", "'x'", - "Cannot find data for output column", "'y'")) - } - test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), @@ -396,19 +467,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("byPosition: fail canWrite check") { - val widerTable = TestRelation(StructType(Seq( - StructField("a", DoubleType), - StructField("b", DoubleType))).toAttributes) - - val parsedPlan = byPosition(table, widerTable) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write", "'table-name'", - "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) - } - test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), @@ -444,24 +502,6 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'a', 'b', 'c'")) } - test("byPosition: multiple field errors are reported") { - val xRequiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", DoubleType))).toAttributes) - - val query = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("b", FloatType))).toAttributes) - - val parsedPlan = byPosition(xRequiredTable, query) - - assertNotResolved(parsedPlan) - assertAnalysisError(parsedPlan, Seq( - "Cannot write incompatible data to table", "'table-name'", - "Cannot write nullable values to non-null column", "'x'", - "Cannot safely cast", "'x'", "DoubleType to FloatType")) - } - test("bypass output column resolution") { val table = TestRelationAcceptAnySchema(StructType(Seq( StructField("a", FloatType, nullable = false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 6b5fc5f0d4434..7a8cd4d18cf59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy class DataTypeWriteCompatibilitySuite extends SparkFunSuite { private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, @@ -386,6 +387,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { byName: Boolean = true): Unit = { assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, + StoreAssignmentPolicy.STRICT, errMsg => fail(s"Should not produce errors but was called with: $errMsg")), desc) } @@ -411,7 +413,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { val errs = new mutable.ArrayBuffer[String]() assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, - errMsg => errs += errMsg) === false, desc) + StoreAssignmentPolicy.STRICT, errMsg => errs += errMsg) === false, desc) assert(errs.size === numErrs, s"Should produce $numErrs error messages") checkErrors(errs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index a55aa7b28ce23..ac666b61cca81 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -582,6 +582,38 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
     }
   }

+  test("Throw exception on unsafe cast with ANSI casting policy") {
+    withSQLConf(
+      SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet",
+      SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) {
+      withTable("t") {
+        sql("create table t(i int, d double) using parquet")
+        var msg = intercept[AnalysisException] {
+          sql("insert into t values('a', 'b')")
+        }.getMessage
+        assert(msg.contains("Cannot cast 'i': StringType to IntegerType") &&
+          msg.contains("Cannot cast 'd': StringType to DoubleType"))
+        msg = intercept[AnalysisException] {
+          sql("insert into t values(now(), now())")
+        }.getMessage
+        assert(msg.contains("Cannot cast 'i': TimestampType to IntegerType") &&
+          msg.contains("Cannot cast 'd': TimestampType to DoubleType"))
+      }
+    }
+  }
+
+  test("Allow writing any numeric value to numeric type with ANSI policy") {
+    withSQLConf(
+      SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet",
+      SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) {
+      withTable("t") {
+        sql("create table t(i int, d float) using parquet")
+        sql("insert into t values(1L, 2.0)")
+        checkAnswer(sql("select * from t"), Row(1, 2.0F))
+      }
+    }
+  }
+
   test("SPARK-24860: dynamic partition overwrite specified per source without catalog table") {
     withTempPath { path =>
       Seq((1, 1), (2, 2)).toDF("i", "part")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index d37e53bc5ac08..d0fb93896dffc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -327,6 +327,21 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
     }
   }

+  test("Throw exception on unsafe cast with ANSI casting policy") {
+    withSQLConf(
+      SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet",
+      SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.STRICT.toString) {
+      withTable("t") {
+        sql("create table t(i int, d double) using parquet")
+        // Calling `saveAsTable` to an existing table with append mode results in table insertion.
+ var msg = intercept[AnalysisException] { + Seq(("a", "b")).toDF("i", "d").write.mode("append").saveAsTable("t") + }.getMessage + assert(msg.contains("Cannot cast 'i': StringType to IntegerType") && + msg.contains("Cannot cast 'd': StringType to DoubleType")) + } + } + test("test path option in load") { spark.read .format("org.apache.spark.sql.test") From e2b37544b4888008ab92c554bb2dfe51ecae4b35 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 26 Aug 2019 22:21:21 +0800 Subject: [PATCH 2/6] revise tests --- .../DataTypeWriteCompatibilitySuite.scala | 418 +++++++++++------- .../sql/test/DataFrameReaderWriterSuite.scala | 3 +- 2 files changed, 268 insertions(+), 153 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 7a8cd4d18cf59..812fb48d1f941 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -22,21 +22,277 @@ import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy -class DataTypeWriteCompatibilitySuite extends SparkFunSuite { - private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, - DoubleType, DateType, TimestampType, StringType, BinaryType) +class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite { + override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = + StoreAssignmentPolicy.STRICT - private val point2 = StructType(Seq( + test("Check atomic types: write allowed only when casting is safe") { + atomicTypes.foreach { w => + atomicTypes.foreach { r => + if (Cast.canUpCast(w, r)) { + assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") + + } else { + assertSingleError(w, r, "t", + s"Should not allow writing $w to $r because cast is not safe") { err => + assert(err.contains("'t'"), "Should include the field name context") + assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") + assert(err.contains(s"$w"), "Should include write type") + assert(err.contains(s"$r"), "Should include read type") + } + } + } + } + } + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(widerPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfLong, arrayOfInt, "arr", + "Should not allow array of longs to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map value types: casting Long to Integer is not allowed") { + val mapOfLong = MapType(StringType, LongType) + val 
mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfLong, mapOfInt, "m", + "Should not allow map of longs to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyLong, mapKeyInt, "m", + "Should not allow map of long keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check types with multiple errors") { + val readType = StructType(Seq( + StructField("a", ArrayType(DoubleType, containsNull = false)), + StructField("arr_of_structs", ArrayType(point2, containsNull = false)), + StructField("bad_nested_type", ArrayType(StringType)), + StructField("m", MapType(LongType, FloatType, valueContainsNull = false)), + StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)), + StructField("x", IntegerType, nullable = false), + StructField("missing1", StringType, nullable = false), + StructField("missing2", StringType) + )) + + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val writeType = StructType(Seq( + StructField("a", ArrayType(StringType)), + StructField("arr_of_structs", ArrayType(point3)), + StructField("bad_nested_type", point3), + StructField("m", MapType(DoubleType, DoubleType)), + StructField("map_of_structs", MapType(StringType, missingMiddleField)), + StructField("y", LongType) + )) + + assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => + assert(errs(0).contains("'top.a.element'"), "Should identify bad type") + assert(errs(0).contains("Cannot safely cast")) + assert(errs(0).contains("StringType to DoubleType")) + + assert(errs(1).contains("'top.a'"), "Should identify bad type") + assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type") + assert(errs(2).contains("'z'"), "Should identify bad field") + assert(errs(2).contains("Cannot write extra fields to struct")) + + assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type") + assert(errs(3).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type") + assert(errs(4).contains("is incompatible with")) + + assert(errs(5).contains("'top.m.key'"), "Should identify bad type") + assert(errs(5).contains("Cannot safely cast")) + assert(errs(5).contains("DoubleType to LongType")) + + assert(errs(6).contains("'top.m.value'"), "Should identify bad type") + assert(errs(6).contains("Cannot safely cast")) + assert(errs(6).contains("DoubleType to FloatType")) + + assert(errs(7).contains("'top.m'"), "Should identify bad type") + assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(8).contains("field name does not match"), "Should identify name problem") + + assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify 
bad type") + assert(errs(9).contains("'z'"), "Should identify missing field") + assert(errs(9).contains("missing fields"), "Should detect missing field") + + assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type") + assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(11).contains("'top.x'"), "Should identify bad type") + assert(errs(11).contains("Cannot safely cast")) + assert(errs(11).contains("LongType to IntegerType")) + + assert(errs(12).contains("'top'"), "Should identify bad type") + assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(12).contains("field name does not match"), "Should identify name problem") + + assert(errs(13).contains("'top'"), "Should identify bad type") + assert(errs(13).contains("'missing1'"), "Should identify missing field") + assert(errs(13).contains("missing fields"), "Should detect missing field") + } + } +} + +class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite { + override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = + StoreAssignmentPolicy.ANSI + + test("Check map value types: unsafe casts are not allowed") { + val mapOfString = MapType(StringType, StringType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfString, mapOfInt, "m", + "Should not allow map of strings to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot cast")) + } + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyString = MapType(StringType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyString, mapKeyInt, "m", + "Should not allow map of string keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot cast")) + } + } + + test("Check types with multiple errors") { + val readType = StructType(Seq( + StructField("a", ArrayType(DoubleType, containsNull = false)), + StructField("arr_of_structs", ArrayType(point2, containsNull = false)), + StructField("bad_nested_type", ArrayType(StringType)), + StructField("m", MapType(LongType, FloatType, valueContainsNull = false)), + StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)), + StructField("x", IntegerType, nullable = false), + StructField("missing1", StringType, nullable = false), + StructField("missing2", StringType) + )) + + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val writeType = StructType(Seq( + StructField("a", ArrayType(StringType)), + StructField("arr_of_structs", ArrayType(point3)), + StructField("bad_nested_type", point3), + StructField("m", MapType(StringType, BooleanType)), + StructField("map_of_structs", MapType(StringType, missingMiddleField)), + StructField("y", StringType) + )) + + assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => + assert(errs(0).contains("'top.a.element'"), "Should identify bad type") + assert(errs(0).contains("Cannot cast")) + assert(errs(0).contains("StringType to DoubleType")) + + assert(errs(1).contains("'top.a'"), "Should identify bad type") + assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(2).contains("'top.arr_of_structs.element'"), 
"Should identify bad type") + assert(errs(2).contains("'z'"), "Should identify bad field") + assert(errs(2).contains("Cannot write extra fields to struct")) + + assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type") + assert(errs(3).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type") + assert(errs(4).contains("is incompatible with")) + + assert(errs(5).contains("'top.m.key'"), "Should identify bad type") + assert(errs(5).contains("Cannot cast")) + assert(errs(5).contains("StringType to LongType")) + + assert(errs(6).contains("'top.m.value'"), "Should identify bad type") + assert(errs(6).contains("Cannot cast")) + assert(errs(6).contains("BooleanType to FloatType")) + + assert(errs(7).contains("'top.m'"), "Should identify bad type") + assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(8).contains("field name does not match"), "Should identify name problem") + + assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(9).contains("'z'"), "Should identify missing field") + assert(errs(9).contains("missing fields"), "Should detect missing field") + + assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type") + assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(11).contains("'top.x'"), "Should identify bad type") + assert(errs(11).contains("Cannot cast")) + assert(errs(11).contains("StringType to IntegerType")) + + assert(errs(12).contains("'top'"), "Should identify bad type") + assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(12).contains("field name does not match"), "Should identify name problem") + + assert(errs(13).contains("'top'"), "Should identify bad type") + assert(errs(13).contains("'missing1'"), "Should identify missing field") + assert(errs(13).contains("missing fields"), "Should detect missing field") + } + } +} + +abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { + protected val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, TimestampType, StringType, BinaryType) + + protected val point2 = StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", FloatType, nullable = false))) - private val widerPoint2 = StructType(Seq( + protected val widerPoint2 = StructType(Seq( StructField("x", DoubleType, nullable = false), StructField("y", DoubleType, nullable = false))) - private val point3 = StructType(Seq( + protected val point3 = StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", FloatType, nullable = false), StructField("z", FloatType))) @@ -65,25 +321,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { } } - test("Check atomic types: write allowed only when casting is safe") { - atomicTypes.foreach { w => - atomicTypes.foreach { r => - if (Cast.canUpCast(w, r)) { - assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") - - } else { - assertSingleError(w, r, "t", - s"Should not allow writing $w to $r because cast is not safe") { err => - assert(err.contains("'t'"), "Should include the field name context") - assert(err.contains("Cannot 
safely cast"), "Should identify unsafe cast") - assert(err.contains(s"$w"), "Should include write type") - assert(err.contains(s"$r"), "Should include read type") - } - } - } - } - } - test("Check struct types: missing required field") { val missingRequiredField = StructType(Seq(StructField("x", FloatType, nullable = false))) assertSingleError(missingRequiredField, point2, "t", @@ -173,18 +410,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { } } - test("Check struct types: unsafe casts are not allowed") { - assertNumErrors(widerPoint2, point2, "t", - "Should fail because types require unsafe casts", 2) { errs => - - assert(errs(0).contains("'t.x'"), "Should include the nested field name context") - assert(errs(0).contains("Cannot safely cast")) - - assert(errs(1).contains("'t.y'"), "Should include the nested field name context") - assert(errs(1).contains("Cannot safely cast")) - } - } - test("Check struct types: type promotion is allowed") { assertAllowed(point2, widerPoint2, "t", "Should allow widening float fields x and y to double") @@ -204,18 +429,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow writing point (x,y) to point(x,y,z=null)") } - test("Check array types: unsafe casts are not allowed") { - val arrayOfLong = ArrayType(LongType) - val arrayOfInt = ArrayType(IntegerType) - - assertSingleError(arrayOfLong, arrayOfInt, "arr", - "Should not allow array of longs to array of ints") { err => - assert(err.contains("'arr.element'"), - "Should identify problem with named array's element type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check array types: type promotion is allowed") { val arrayOfLong = ArrayType(LongType) val arrayOfInt = ArrayType(IntegerType) @@ -242,17 +455,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow array of required elements to array of optional elements") } - test("Check map value types: unsafe casts are not allowed") { - val mapOfLong = MapType(StringType, LongType) - val mapOfInt = MapType(StringType, IntegerType) - - assertSingleError(mapOfLong, mapOfInt, "m", - "Should not allow map of longs to map of ints") { err => - assert(err.contains("'m.value'"), "Should identify problem with named map's value type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check map value types: type promotion is allowed") { val mapOfLong = MapType(StringType, LongType) val mapOfInt = MapType(StringType, IntegerType) @@ -279,17 +481,6 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow map of required elements to map of optional elements") } - test("Check map key types: unsafe casts are not allowed") { - val mapKeyLong = MapType(LongType, StringType) - val mapKeyInt = MapType(IntegerType, StringType) - - assertSingleError(mapKeyLong, mapKeyInt, "m", - "Should not allow map of long keys to map of int keys") { err => - assert(err.contains("'m.key'"), "Should identify problem with named map's key type") - assert(err.contains("Cannot safely cast")) - } - } - test("Check map key types: type promotion is allowed") { val mapKeyLong = MapType(LongType, StringType) val mapKeyInt = MapType(IntegerType, StringType) @@ -298,87 +489,10 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { "Should allow map of int written to map of long column") } - test("Check types with multiple errors") { - val readType = StructType(Seq( - StructField("a", ArrayType(DoubleType, containsNull = false)), - StructField("arr_of_structs", ArrayType(point2, 
containsNull = false)), - StructField("bad_nested_type", ArrayType(StringType)), - StructField("m", MapType(LongType, FloatType, valueContainsNull = false)), - StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)), - StructField("x", IntegerType, nullable = false), - StructField("missing1", StringType, nullable = false), - StructField("missing2", StringType) - )) - - val missingMiddleField = StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("z", FloatType, nullable = false))) - - val writeType = StructType(Seq( - StructField("a", ArrayType(StringType)), - StructField("arr_of_structs", ArrayType(point3)), - StructField("bad_nested_type", point3), - StructField("m", MapType(DoubleType, DoubleType)), - StructField("map_of_structs", MapType(StringType, missingMiddleField)), - StructField("y", LongType) - )) - - assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => - assert(errs(0).contains("'top.a.element'"), "Should identify bad type") - assert(errs(0).contains("Cannot safely cast")) - assert(errs(0).contains("StringType to DoubleType")) - - assert(errs(1).contains("'top.a'"), "Should identify bad type") - assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) - - assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type") - assert(errs(2).contains("'z'"), "Should identify bad field") - assert(errs(2).contains("Cannot write extra fields to struct")) - - assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type") - assert(errs(3).contains("Cannot write nullable elements to array of non-nulls")) - - assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type") - assert(errs(4).contains("is incompatible with")) - - assert(errs(5).contains("'top.m.key'"), "Should identify bad type") - assert(errs(5).contains("Cannot safely cast")) - assert(errs(5).contains("DoubleType to LongType")) - - assert(errs(6).contains("'top.m.value'"), "Should identify bad type") - assert(errs(6).contains("Cannot safely cast")) - assert(errs(6).contains("DoubleType to FloatType")) - - assert(errs(7).contains("'top.m'"), "Should identify bad type") - assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) - - assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type") - assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch") - assert(errs(8).contains("field name does not match"), "Should identify name problem") - - assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type") - assert(errs(9).contains("'z'"), "Should identify missing field") - assert(errs(9).contains("missing fields"), "Should detect missing field") - - assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type") - assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) - - assert(errs(11).contains("'top.x'"), "Should identify bad type") - assert(errs(11).contains("Cannot safely cast")) - assert(errs(11).contains("LongType to IntegerType")) - - assert(errs(12).contains("'top'"), "Should identify bad type") - assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") - assert(errs(12).contains("field name does not match"), "Should identify name problem") - - assert(errs(13).contains("'top'"), "Should identify bad type") - assert(errs(13).contains("'missing1'"), "Should identify missing field") - assert(errs(13).contains("missing 
fields"), "Should detect missing field") - } - } - // Helper functions + protected def storeAssignmentPolicy: StoreAssignmentPolicy.Value + def assertAllowed( writeType: DataType, readType: DataType, @@ -387,7 +501,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { byName: Boolean = true): Unit = { assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, - StoreAssignmentPolicy.STRICT, + storeAssignmentPolicy, errMsg => fail(s"Should not produce errors but was called with: $errMsg")), desc) } @@ -413,7 +527,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite { val errs = new mutable.ArrayBuffer[String]() assert( DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, - StoreAssignmentPolicy.STRICT, errMsg => errs += errMsg) === false, desc) + storeAssignmentPolicy, errMsg => errs += errMsg) === false, desc) assert(errs.size === numErrs, s"Should produce $numErrs error messages") checkErrors(errs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index d0fb93896dffc..ba8c9c1339dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -330,7 +330,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with test("Throw exception on unsafe cast with ANSI casting policy") { withSQLConf( SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet", - SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.STRICT.toString) { + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { withTable("t") { sql("create table t(i int, d double) using parquet") // Calling `saveAsTable` to an existing table with append mode results in table insertion. 
@@ -339,6 +339,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with }.getMessage assert(msg.contains("Cannot cast 'i': StringType to IntegerType") && msg.contains("Cannot cast 'd': StringType to DoubleType")) + } } } From 68da9cca6030ed2bc248ccbb6d7f8731f0bae59d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 26 Aug 2019 22:30:10 +0800 Subject: [PATCH 3/6] revise --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index cb5cb72dc5a6a..549d47a9226c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -161,7 +161,8 @@ object Cast { def canANSIStoreAssign(from: DataType, to: DataType): Boolean = (from, to) match { case _ if from == to => true case (_: NumericType, _: NumericType) => true - case (_, StringType) => true + case (_: AtomicType, StringType) => true + case (_: CalendarIntervalType, StringType) => true case (DateType, TimestampType) => true case (TimestampType, DateType) => true // Spark supports casting between long and timestamp, please see `longToTimestamp` and From fcc68dc3ed9165e7a51231ace93ec8950a386a37 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 26 Aug 2019 23:42:15 +0800 Subject: [PATCH 4/6] more test cases --- .../DataTypeWriteCompatibilitySuite.scala | 19 +++++++++++++++++++ .../spark/sql/sources/InsertSuite.scala | 9 ++++++++- .../sql/test/DataFrameReaderWriterSuite.scala | 6 ++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 812fb48d1f941..342e62ba9a965 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -178,6 +178,25 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = StoreAssignmentPolicy.ANSI + test("Check atomic types: write allowed only when casting is safe") { + atomicTypes.foreach { w => + atomicTypes.foreach { r => + if ((w.isInstanceOf[NumericType] && r.isInstanceOf[NumericType]) || + Cast.canANSIStoreAssign(w, r)) { + assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") + } else { + assertSingleError(w, r, "t", + s"Should not allow writing $w to $r because cast is not safe") { err => + assert(err.contains("'t'"), "Should include the field name context") + assert(err.contains("Cannot cast"), "Should identify unsafe cast") + assert(err.contains(s"$w"), "Should include write type") + assert(err.contains(s"$r"), "Should include read type") + } + } + } + } + } + test("Check map value types: unsafe casts are not allowed") { val mapOfString = MapType(StringType, StringType) val mapOfInt = MapType(StringType, IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index ac666b61cca81..0fa7093959250 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -598,6 +598,11 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { }.getMessage assert(msg.contains("Cannot cast 'i': TimestampType to IntegerType") && msg.contains("Cannot cast 'd': TimestampType to DoubleType")) + msg = intercept[AnalysisException] { + sql("insert into t values(true, false)") + }.getMessage + assert(msg.contains("Cannot cast 'i': BooleanType to IntegerType") && + msg.contains("Cannot cast 'd': BooleanType to DoubleType")) } } } @@ -609,7 +614,9 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { withTable("t") { sql("create table t(i int, d float) using parquet") sql("insert into t values(1L, 2.0)") - checkAnswer(sql("select * from t"), Row(1, 2.0F)) + sql("insert into t values(3.0, 4)") + sql("insert into t values(5.0, 6L)") + checkAnswer(sql("select * from t"), Seq(Row(1, 2.0F), Row(3, 4.0F), Row(5, 6.0F))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index ba8c9c1339dfb..318daddf0bc20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -339,6 +339,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with }.getMessage assert(msg.contains("Cannot cast 'i': StringType to IntegerType") && msg.contains("Cannot cast 'd': StringType to DoubleType")) + + msg = intercept[AnalysisException] { + Seq((true, false)).toDF("i", "d").write.mode("append").saveAsTable("t") + }.getMessage + assert(msg.contains("Cannot cast 'i': BooleanType to IntegerType") && + msg.contains("Cannot cast 'd': BooleanType to DoubleType")) } } } From b9d49ef64d913031f3fec79c8aa5b98c39b303b3 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 27 Aug 2019 17:33:38 +0800 Subject: [PATCH 5/6] update --- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../DataTypeWriteCompatibilitySuite.scala | 105 ++++++++++-------- .../spark/sql/sources/InsertSuite.scala | 25 ++++- .../sql/test/DataFrameReaderWriterSuite.scala | 8 +- 4 files changed, 84 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 259b047a3f178..3a10a56f6937f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -458,7 +458,7 @@ object DataType { case (w: AtomicType, r: AtomicType) if storeAssignmentPolicy == ANSI => if (!Cast.canANSIStoreAssign(w, r)) { - addError(s"Cannot cast '$context': $w to $r") + addError(s"Cannot safely cast '$context': $w to $r") false } else { true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 342e62ba9a965..af5b00be55a11 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -26,27 +26,10 @@ import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite { - override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = + override def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = StoreAssignmentPolicy.STRICT - test("Check atomic types: write allowed only when casting is safe") { - atomicTypes.foreach { w => - atomicTypes.foreach { r => - if (Cast.canUpCast(w, r)) { - assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") - - } else { - assertSingleError(w, r, "t", - s"Should not allow writing $w to $r because cast is not safe") { err => - assert(err.contains("'t'"), "Should include the field name context") - assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") - assert(err.contains(s"$w"), "Should include write type") - assert(err.contains(s"$r"), "Should include read type") - } - } - } - } - } + override def canCast: (DataType, DataType) => Boolean = Cast.canUpCast test("Check struct types: unsafe casts are not allowed") { assertNumErrors(widerPoint2, point2, "t", @@ -178,24 +161,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value = StoreAssignmentPolicy.ANSI - test("Check atomic types: write allowed only when casting is safe") { - atomicTypes.foreach { w => - atomicTypes.foreach { r => - if ((w.isInstanceOf[NumericType] && r.isInstanceOf[NumericType]) || - Cast.canANSIStoreAssign(w, r)) { - assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") - } else { - assertSingleError(w, r, "t", - s"Should not allow writing $w to $r because cast is not safe") { err => - assert(err.contains("'t'"), "Should include the field name context") - assert(err.contains("Cannot cast"), "Should identify unsafe cast") - assert(err.contains(s"$w"), "Should include write type") - assert(err.contains(s"$r"), "Should include read type") - } - } - } - } - } + override def canCast: (DataType, DataType) => Boolean = Cast.canANSIStoreAssign test("Check map value types: unsafe casts are not allowed") { val mapOfString = MapType(StringType, StringType) @@ -204,7 +170,35 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase assertSingleError(mapOfString, mapOfInt, "m", "Should not allow map of strings to map of ints") { err => assert(err.contains("'m.value'"), "Should identify problem with named map's value type") - assert(err.contains("Cannot cast")) + assert(err.contains("Cannot safely cast")) + } + } + + private val stringPoint2 = StructType(Seq( + StructField("x", StringType, nullable = false), + StructField("y", StringType, nullable = false))) + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(stringPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfString = ArrayType(StringType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfString, arrayOfInt, "arr", + "Should not allow array of strings to array of ints") { err => + 
assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) } } @@ -215,7 +209,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase assertSingleError(mapKeyString, mapKeyInt, "m", "Should not allow map of string keys to map of int keys") { err => assert(err.contains("'m.key'"), "Should identify problem with named map's key type") - assert(err.contains("Cannot cast")) + assert(err.contains("Cannot safely cast")) } } @@ -246,7 +240,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => assert(errs(0).contains("'top.a.element'"), "Should identify bad type") - assert(errs(0).contains("Cannot cast")) + assert(errs(0).contains("Cannot safely cast")) assert(errs(0).contains("StringType to DoubleType")) assert(errs(1).contains("'top.a'"), "Should identify bad type") @@ -263,11 +257,11 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase assert(errs(4).contains("is incompatible with")) assert(errs(5).contains("'top.m.key'"), "Should identify bad type") - assert(errs(5).contains("Cannot cast")) + assert(errs(5).contains("Cannot safely cast")) assert(errs(5).contains("StringType to LongType")) assert(errs(6).contains("'top.m.value'"), "Should identify bad type") - assert(errs(6).contains("Cannot cast")) + assert(errs(6).contains("Cannot safely cast")) assert(errs(6).contains("BooleanType to FloatType")) assert(errs(7).contains("'top.m'"), "Should identify bad type") @@ -285,7 +279,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) assert(errs(11).contains("'top.x'"), "Should identify bad type") - assert(errs(11).contains("Cannot cast")) + assert(errs(11).contains("Cannot safely cast")) assert(errs(11).contains("StringType to IntegerType")) assert(errs(12).contains("'top'"), "Should identify bad type") @@ -300,6 +294,10 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase } abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { + protected def storeAssignmentPolicy: StoreAssignmentPolicy.Value + + protected def canCast: (DataType, DataType) => Boolean + protected val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DateType, TimestampType, StringType, BinaryType) @@ -340,6 +338,25 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { } } + test("Check atomic types: write allowed only when casting is safe") { + atomicTypes.foreach { w => + atomicTypes.foreach { r => + if (canCast(w, r)) { + assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") + + } else { + assertSingleError(w, r, "t", + s"Should not allow writing $w to $r because cast is not safe") { err => + assert(err.contains("'t'"), "Should include the field name context") + assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") + assert(err.contains(s"$w"), "Should include write type") + assert(err.contains(s"$r"), "Should include read type") + } + } + } + } + } + test("Check struct types: missing required field") { val missingRequiredField = StructType(Seq(StructField("x", FloatType, nullable = false))) assertSingleError(missingRequiredField, point2, "t", @@ -510,8 +527,6 @@ abstract class 
   // Helper functions
 
-  protected def storeAssignmentPolicy: StoreAssignmentPolicy.Value
-
   def assertAllowed(
       writeType: DataType,
       readType: DataType,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 0fa7093959250..09b99508c1182 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.sources
 import java.io.File
+import java.sql.Date
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql._
@@ -591,18 +592,18 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
         var msg = intercept[AnalysisException] {
           sql("insert into t values('a', 'b')")
         }.getMessage
-        assert(msg.contains("Cannot cast 'i': StringType to IntegerType") &&
-          msg.contains("Cannot cast 'd': StringType to DoubleType"))
+        assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") &&
+          msg.contains("Cannot safely cast 'd': StringType to DoubleType"))
         msg = intercept[AnalysisException] {
           sql("insert into t values(now(), now())")
         }.getMessage
-        assert(msg.contains("Cannot cast 'i': TimestampType to IntegerType") &&
-          msg.contains("Cannot cast 'd': TimestampType to DoubleType"))
+        assert(msg.contains("Cannot safely cast 'i': TimestampType to IntegerType") &&
+          msg.contains("Cannot safely cast 'd': TimestampType to DoubleType"))
         msg = intercept[AnalysisException] {
           sql("insert into t values(true, false)")
         }.getMessage
-        assert(msg.contains("Cannot cast 'i': BooleanType to IntegerType") &&
-          msg.contains("Cannot cast 'd': BooleanType to DoubleType"))
+        assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") &&
+          msg.contains("Cannot safely cast 'd': BooleanType to DoubleType"))
       }
     }
   }
@@ -621,6 +622,18 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
     }
   }
 
+  test("Allow writing timestamp values to date type with ANSI policy") {
+    withSQLConf(
+      SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "parquet",
+      SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) {
+      withTable("t") {
+        sql("create table t(i date) using parquet")
+        sql("insert into t values(TIMESTAMP('2010-09-02 14:10:10'))")
+        checkAnswer(sql("select * from t"), Seq(Row(Date.valueOf("2010-09-02"))))
+      }
+    }
+  }
+
   test("SPARK-24860: dynamic partition overwrite specified per source without catalog table") {
     withTempPath { path =>
       Seq((1, 1), (2, 2)).toDF("i", "part")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 318daddf0bc20..922c0ee4525d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -337,14 +337,14 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
         var msg = intercept[AnalysisException] {
           Seq(("a", "b")).toDF("i", "d").write.mode("append").saveAsTable("t")
         }.getMessage
-        assert(msg.contains("Cannot cast 'i': StringType to IntegerType") &&
-          msg.contains("Cannot cast 'd': StringType to DoubleType"))
+        assert(msg.contains("Cannot safely cast 'i': StringType to IntegerType") &&
+          msg.contains("Cannot safely cast 'd': StringType to DoubleType"))
         msg = intercept[AnalysisException] {
          Seq((true, false)).toDF("i", "d").write.mode("append").saveAsTable("t")
         }.getMessage
-        assert(msg.contains("Cannot cast 'i': BooleanType to IntegerType") &&
-          msg.contains("Cannot cast 'd': BooleanType to DoubleType"))
+        assert(msg.contains("Cannot safely cast 'i': BooleanType to IntegerType") &&
+          msg.contains("Cannot safely cast 'd': BooleanType to DoubleType"))
       }
     }
   }

From af0f7390e61455682cb5089f8e8dd3c0e60d32b1 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Tue, 27 Aug 2019 17:39:42 +0800
Subject: [PATCH 6/6] revise

---
 .../DataTypeWriteCompatibilitySuite.scala     | 237 ++++++------------
 1 file changed, 79 insertions(+), 158 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
index af5b00be55a11..784cc7a70489f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
@@ -76,85 +76,6 @@ class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBa
       assert(err.contains("Cannot safely cast"))
     }
   }
-
-  test("Check types with multiple errors") {
-    val readType = StructType(Seq(
-      StructField("a", ArrayType(DoubleType, containsNull = false)),
-      StructField("arr_of_structs", ArrayType(point2, containsNull = false)),
-      StructField("bad_nested_type", ArrayType(StringType)),
-      StructField("m", MapType(LongType, FloatType, valueContainsNull = false)),
-      StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)),
-      StructField("x", IntegerType, nullable = false),
-      StructField("missing1", StringType, nullable = false),
-      StructField("missing2", StringType)
-    ))
-
-    val missingMiddleField = StructType(Seq(
-      StructField("x", FloatType, nullable = false),
-      StructField("z", FloatType, nullable = false)))
-
-    val writeType = StructType(Seq(
-      StructField("a", ArrayType(StringType)),
-      StructField("arr_of_structs", ArrayType(point3)),
-      StructField("bad_nested_type", point3),
-      StructField("m", MapType(DoubleType, DoubleType)),
-      StructField("map_of_structs", MapType(StringType, missingMiddleField)),
-      StructField("y", LongType)
-    ))
-
-    assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs =>
-      assert(errs(0).contains("'top.a.element'"), "Should identify bad type")
-      assert(errs(0).contains("Cannot safely cast"))
-      assert(errs(0).contains("StringType to DoubleType"))
-
-      assert(errs(1).contains("'top.a'"), "Should identify bad type")
-      assert(errs(1).contains("Cannot write nullable elements to array of non-nulls"))
-
-      assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type")
-      assert(errs(2).contains("'z'"), "Should identify bad field")
-      assert(errs(2).contains("Cannot write extra fields to struct"))
-
-      assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type")
-      assert(errs(3).contains("Cannot write nullable elements to array of non-nulls"))
-
-      assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type")
-      assert(errs(4).contains("is incompatible with"))
-
-      assert(errs(5).contains("'top.m.key'"), "Should identify bad type")
-      assert(errs(5).contains("Cannot safely cast"))
-      assert(errs(5).contains("DoubleType to LongType"))
-
-      assert(errs(6).contains("'top.m.value'"), "Should identify bad type")
-      assert(errs(6).contains("Cannot safely cast"))
-      assert(errs(6).contains("DoubleType to FloatType"))
-
-      assert(errs(7).contains("'top.m'"), "Should identify bad type")
-      assert(errs(7).contains("Cannot write nullable values to map of non-nulls"))
-
-      assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type")
-      assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch")
-      assert(errs(8).contains("field name does not match"), "Should identify name problem")
-
-      assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type")
-      assert(errs(9).contains("'z'"), "Should identify missing field")
-      assert(errs(9).contains("missing fields"), "Should detect missing field")
-
-      assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type")
-      assert(errs(10).contains("Cannot write nullable values to map of non-nulls"))
-
-      assert(errs(11).contains("'top.x'"), "Should identify bad type")
-      assert(errs(11).contains("Cannot safely cast"))
-      assert(errs(11).contains("LongType to IntegerType"))
-
-      assert(errs(12).contains("'top'"), "Should identify bad type")
-      assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch")
-      assert(errs(12).contains("field name does not match"), "Should identify name problem")
-
-      assert(errs(13).contains("'top'"), "Should identify bad type")
-      assert(errs(13).contains("'missing1'"), "Should identify missing field")
-      assert(errs(13).contains("missing fields"), "Should detect missing field")
-    }
-  }
 }
 
 class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite {
@@ -212,85 +133,6 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase
       assert(err.contains("Cannot safely cast"))
     }
   }
-
-  test("Check types with multiple errors") {
-    val readType = StructType(Seq(
-      StructField("a", ArrayType(DoubleType, containsNull = false)),
-      StructField("arr_of_structs", ArrayType(point2, containsNull = false)),
-      StructField("bad_nested_type", ArrayType(StringType)),
-      StructField("m", MapType(LongType, FloatType, valueContainsNull = false)),
-      StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)),
-      StructField("x", IntegerType, nullable = false),
-      StructField("missing1", StringType, nullable = false),
-      StructField("missing2", StringType)
-    ))
-
-    val missingMiddleField = StructType(Seq(
-      StructField("x", FloatType, nullable = false),
-      StructField("z", FloatType, nullable = false)))
-
-    val writeType = StructType(Seq(
-      StructField("a", ArrayType(StringType)),
-      StructField("arr_of_structs", ArrayType(point3)),
-      StructField("bad_nested_type", point3),
-      StructField("m", MapType(StringType, BooleanType)),
-      StructField("map_of_structs", MapType(StringType, missingMiddleField)),
-      StructField("y", StringType)
-    ))
-
-    assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs =>
-      assert(errs(0).contains("'top.a.element'"), "Should identify bad type")
-      assert(errs(0).contains("Cannot safely cast"))
-      assert(errs(0).contains("StringType to DoubleType"))
-
-      assert(errs(1).contains("'top.a'"), "Should identify bad type")
-      assert(errs(1).contains("Cannot write nullable elements to array of non-nulls"))
-
-      assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type")
-      assert(errs(2).contains("'z'"), "Should identify bad field")
-      assert(errs(2).contains("Cannot write extra fields to struct"))
-
-      assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type")
-      assert(errs(3).contains("Cannot write nullable elements to array of non-nulls"))
-
-      assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type")
-      assert(errs(4).contains("is incompatible with"))
-
-      assert(errs(5).contains("'top.m.key'"), "Should identify bad type")
-      assert(errs(5).contains("Cannot safely cast"))
-      assert(errs(5).contains("StringType to LongType"))
-
-      assert(errs(6).contains("'top.m.value'"), "Should identify bad type")
-      assert(errs(6).contains("Cannot safely cast"))
-      assert(errs(6).contains("BooleanType to FloatType"))
-
-      assert(errs(7).contains("'top.m'"), "Should identify bad type")
-      assert(errs(7).contains("Cannot write nullable values to map of non-nulls"))
-
-      assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type")
-      assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch")
-      assert(errs(8).contains("field name does not match"), "Should identify name problem")
-
-      assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type")
-      assert(errs(9).contains("'z'"), "Should identify missing field")
-      assert(errs(9).contains("missing fields"), "Should detect missing field")
-
-      assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type")
-      assert(errs(10).contains("Cannot write nullable values to map of non-nulls"))
-
-      assert(errs(11).contains("'top.x'"), "Should identify bad type")
-      assert(errs(11).contains("Cannot safely cast"))
-      assert(errs(11).contains("StringType to IntegerType"))
-
-      assert(errs(12).contains("'top'"), "Should identify bad type")
-      assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch")
-      assert(errs(12).contains("field name does not match"), "Should identify name problem")
-
-      assert(errs(13).contains("'top'"), "Should identify bad type")
-      assert(errs(13).contains("'missing1'"), "Should identify missing field")
-      assert(errs(13).contains("missing fields"), "Should detect missing field")
-    }
-  }
 }
 
 abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite {
@@ -525,6 +367,85 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite {
       "Should allow map of int written to map of long column")
   }
 
+  test("Check types with multiple errors") {
+    val readType = StructType(Seq(
+      StructField("a", ArrayType(DoubleType, containsNull = false)),
+      StructField("arr_of_structs", ArrayType(point2, containsNull = false)),
+      StructField("bad_nested_type", ArrayType(StringType)),
+      StructField("m", MapType(LongType, FloatType, valueContainsNull = false)),
+      StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)),
+      StructField("x", IntegerType, nullable = false),
+      StructField("missing1", StringType, nullable = false),
+      StructField("missing2", StringType)
+    ))
+
+    val missingMiddleField = StructType(Seq(
+      StructField("x", FloatType, nullable = false),
+      StructField("z", FloatType, nullable = false)))
+
+    val writeType = StructType(Seq(
+      StructField("a", ArrayType(StringType)),
+      StructField("arr_of_structs", ArrayType(point3)),
+      StructField("bad_nested_type", point3),
+      StructField("m", MapType(StringType, BooleanType)),
+      StructField("map_of_structs", MapType(StringType, missingMiddleField)),
+      StructField("y", StringType)
+    ))
+
+    assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs =>
+      assert(errs(0).contains("'top.a.element'"), "Should identify bad type")
+      assert(errs(0).contains("Cannot safely cast"))
+      assert(errs(0).contains("StringType to DoubleType"))
+
+      assert(errs(1).contains("'top.a'"), "Should identify bad type")
+      assert(errs(1).contains("Cannot write nullable elements to array of non-nulls"))
+
+      assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type")
+      assert(errs(2).contains("'z'"), "Should identify bad field")
+      assert(errs(2).contains("Cannot write extra fields to struct"))
+
+      assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type")
+      assert(errs(3).contains("Cannot write nullable elements to array of non-nulls"))
+
+      assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type")
+      assert(errs(4).contains("is incompatible with"))
+
+      assert(errs(5).contains("'top.m.key'"), "Should identify bad type")
+      assert(errs(5).contains("Cannot safely cast"))
+      assert(errs(5).contains("StringType to LongType"))
+
+      assert(errs(6).contains("'top.m.value'"), "Should identify bad type")
+      assert(errs(6).contains("Cannot safely cast"))
+      assert(errs(6).contains("BooleanType to FloatType"))
+
+      assert(errs(7).contains("'top.m'"), "Should identify bad type")
+      assert(errs(7).contains("Cannot write nullable values to map of non-nulls"))
+
+      assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type")
+      assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch")
+      assert(errs(8).contains("field name does not match"), "Should identify name problem")
+
+      assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type")
+      assert(errs(9).contains("'z'"), "Should identify missing field")
+      assert(errs(9).contains("missing fields"), "Should detect missing field")
+
+      assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type")
+      assert(errs(10).contains("Cannot write nullable values to map of non-nulls"))
+
+      assert(errs(11).contains("'top.x'"), "Should identify bad type")
+      assert(errs(11).contains("Cannot safely cast"))
+      assert(errs(11).contains("StringType to IntegerType"))
+
+      assert(errs(12).contains("'top'"), "Should identify bad type")
+      assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch")
+      assert(errs(12).contains("field name does not match"), "Should identify name problem")
+
+      assert(errs(13).contains("'top'"), "Should identify bad type")
+      assert(errs(13).contains("'missing1'"), "Should identify missing field")
+      assert(errs(13).contains("missing fields"), "Should detect missing field")
+    }
+  }
+
   // Helper functions
 
   def assertAllowed(