diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 90c7cf6f082c..d23d43bef76e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -176,6 +176,7 @@ class Analyzer( ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: + ResolveOutputRelation :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -2227,6 +2228,98 @@ class Analyzer( } } + /** + * Resolves columns of an output table from the data in a logical plan. This rule will: + * + * - Reorder columns when the write is by name + * - Insert safe casts when data types do not match + * - Insert aliases when column names do not match + * - Detect plans that are not compatible with the output table and throw AnalysisException + */ + object ResolveOutputRelation extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case append @ AppendData(table, query, isByName) + if table.resolved && query.resolved && !append.resolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + append.copy(query = projection) + } else { + append + } + } + + def resolveOutputColumns( + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan, + byName: Boolean): LogicalPlan = { + + if (expected.size < query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', too many data columns: + |Table columns: ${expected.map(_.name).mkString(", ")} + |Data columns: ${query.output.map(_.name).mkString(", ")}""".stripMargin) + } + + val errors = new mutable.ArrayBuffer[String]() + val resolved: Seq[NamedExpression] = if (byName) { + expected.flatMap { tableAttr => + query.resolveQuoted(tableAttr.name, resolver) match { + case Some(queryExpr) => + checkField(tableAttr, queryExpr, err => errors += err) + case None => + errors += s"Cannot find data for output column '${tableAttr.name}'" + None + } + } + + } else { + if (expected.size > query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', not enough data columns: + |Table columns: ${expected.map(_.name).mkString(", ")} + |Data columns: ${query.output.map(_.name).mkString(", ")}""".stripMargin) + } + + query.output.zip(expected).flatMap { + case (queryExpr, tableAttr) => + checkField(tableAttr, queryExpr, err => errors += err) + } + } + + if (errors.nonEmpty) { + throw new AnalysisException( + s"Cannot write incompatible data to table '$tableName':\n- ${errors.mkString("\n- ")}") + } + + Project(resolved, query) + } + + private def checkField( + tableAttr: Attribute, + queryExpr: NamedExpression, + addError: String => Unit): Option[NamedExpression] = { + + if (queryExpr.nullable && !tableAttr.nullable) { + addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'") + None + + } else if (!DataType.canWrite( + tableAttr.dataType, queryExpr.dataType, resolver, tableAttr.name, addError)) { + None + + } else { + // always add an UpCast. it will be removed in the optimizer if it is unnecessary. 
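+        // The Alias below re-labels the query expression with the table column's name and copies
+        // the table column's metadata (explicitMetadata), so the data written carries the table's
+        // column names and metadata even when the incoming query used different ones.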
+ Some(Alias( + UpCast(queryExpr, tableAttr.dataType, Seq()), tableAttr.name + )( + explicitMetadata = Option(tableAttr.metadata) + )) + } + } + } + private def commonNaturalJoinProcessing( left: LogicalPlan, right: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala new file mode 100644 index 000000000000..ad201f947b67 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +trait NamedRelation extends LogicalPlan { + def name: String +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 19779b73d6db..0d31c6f6b9c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.{AliasIdentifier} -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, - RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -353,6 +352,37 @@ case class Join( } } +/** + * Append data to an existing table. 
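+ *
+ * The node reports itself as unresolved until the query's output matches the table's output in
+ * column count, name, type (up to compatible nullability), and nullability, so the analyzer
+ * (see ResolveOutputRelation) must first insert the projection that aligns the two sides.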
+ */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Seq.empty + + override lazy val resolved: Boolean = { + query.output.size == table.output.size && query.output.zip(table.output).forall { + case (inAttr, outAttr) => + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && + (outAttr.nullable || !inAttr.nullable) + } + } +} + +object AppendData { + def byName(table: NamedRelation, df: LogicalPlan): AppendData = { + new AppendData(table, df, true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { + new AppendData(table, query, false) + } +} + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index fd40741cfb5f..50f2a9df522c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -27,7 +27,8 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -336,4 +337,124 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + private val SparkGeneratedName = """col\d+""".r + private def isSparkGeneratedName(name: String): Boolean = name match { + case SparkGeneratedName(_*) => true + case _ => false + } + + /** + * Returns true if the write data type can be read using the read data type. + * + * The write type is compatible with the read type if: + * - Both types are arrays, the array element types are compatible, and element nullability is + * compatible (read allows nulls or write does not contain nulls). + * - Both types are maps and the map key and value types are compatible, and value nullability + * is compatible (read allows nulls or write does not contain nulls). + * - Both types are structs and each field in the read struct is present in the write struct and + * compatible (including nullability), or is nullable if the write struct does not contain the + * field. Write-side structs are not compatible if they contain fields that are not present in + * the read-side struct. + * - Both types are atomic and the write type can be safely cast to the read type. + * + * Extra fields in write-side structs are not allowed to avoid accidentally writing data that + * the read schema will not read, and to ensure map key equality is not changed when data is read. 
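+   *
+   * An illustrative sketch of the atomic-type rule, using the case-sensitive resolver that the
+   * test suite below also uses:
+   * {{{
+   *   import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+   *
+   *   DataType.canWrite(IntegerType, LongType, caseSensitiveResolution, "col")  // true: safe widening cast
+   *   DataType.canWrite(LongType, IntegerType, caseSensitiveResolution, "col")  // false: may truncate
+   * }}}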
+ * + * @param write a write-side data type to validate against the read type + * @param read a read-side data type + * @return true if data written with the write type can be read using the read type + */ + def canWrite( + write: DataType, + read: DataType, + resolver: Resolver, + context: String, + addError: String => Unit = (_: String) => {}): Boolean = { + (write, read) match { + case (wArr: ArrayType, rArr: ArrayType) => + // run compatibility check first to produce all error messages + val typesCompatible = + canWrite(wArr.elementType, rArr.elementType, resolver, context + ".element", addError) + + if (wArr.containsNull && !rArr.containsNull) { + addError(s"Cannot write nullable elements to array of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (wMap: MapType, rMap: MapType) => + // map keys cannot include data fields not in the read schema without changing equality when + // read. map keys can be missing fields as long as they are nullable in the read schema. + + // run compatibility check first to produce all error messages + val keyCompatible = + canWrite(wMap.keyType, rMap.keyType, resolver, context + ".key", addError) + val valueCompatible = + canWrite(wMap.valueType, rMap.valueType, resolver, context + ".value", addError) + val typesCompatible = keyCompatible && valueCompatible + + if (wMap.valueContainsNull && !rMap.valueContainsNull) { + addError(s"Cannot write nullable values to map of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (StructType(writeFields), StructType(readFields)) => + var fieldCompatible = true + readFields.zip(writeFields).foreach { + case (rField, wField) => + val namesMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name) + val fieldContext = s"$context.${rField.name}" + val typesCompatible = + canWrite(wField.dataType, rField.dataType, resolver, fieldContext, addError) + + if (!namesMatch) { + addError(s"Struct '$context' field name does not match (may be out of order): " + + s"expected '${rField.name}', found '${wField.name}'") + fieldCompatible = false + } else if (!rField.nullable && wField.nullable) { + addError(s"Cannot write nullable values to non-null field: '$fieldContext'") + fieldCompatible = false + } else if (!typesCompatible) { + // errors are added in the recursive call to canWrite above + fieldCompatible = false + } + } + + if (readFields.size > writeFields.size) { + val missingFieldsStr = readFields.takeRight(readFields.size - writeFields.size) + .map(f => s"'${f.name}'").mkString(", ") + if (missingFieldsStr.nonEmpty) { + addError(s"Struct '$context' missing fields: $missingFieldsStr") + fieldCompatible = false + } + + } else if (writeFields.size > readFields.size) { + val extraFieldsStr = writeFields.takeRight(writeFields.size - readFields.size) + .map(f => s"'${f.name}'").mkString(", ") + addError(s"Cannot write extra fields to struct '$context': $extraFieldsStr") + fieldCompatible = false + } + + fieldCompatible + + case (w: AtomicType, r: AtomicType) => + if (!Cast.canSafeCast(w, r)) { + addError(s"Cannot safely cast '$context': $w to $r") + false + } else { + true + } + + case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] => + true + + case (w, r) => + addError(s"Cannot write '$context': $w is incompatible with $r") + false + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala new file mode 
100644 index 000000000000..d92f52f3248a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.Cast + +class DataTypeWriteCompatibilitySuite extends SparkFunSuite { + private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DateType, TimestampType, StringType, BinaryType) + + private val point2 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + private val widerPoint2 = StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType, nullable = false))) + + private val point3 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType))) + + private val simpleContainerTypes = Seq( + ArrayType(LongType), ArrayType(LongType, containsNull = false), MapType(StringType, DoubleType), + MapType(StringType, DoubleType, valueContainsNull = false), point2, point3) + + private val nestedContainerTypes = Seq(ArrayType(point2, containsNull = false), + MapType(StringType, point3, valueContainsNull = false)) + + private val allNonNullTypes = Seq( + atomicTypes, simpleContainerTypes, nestedContainerTypes, Seq(CalendarIntervalType)).flatten + + test("Check NullType is incompatible with all other types") { + allNonNullTypes.foreach { t => + assertSingleError(NullType, t, "nulls", s"Should not allow writing None to type $t") { err => + assert(err.contains(s"incompatible with $t")) + } + } + } + + test("Check each type with itself") { + allNonNullTypes.foreach { t => + assertAllowed(t, t, "t", s"Should allow writing type to itself $t") + } + } + + test("Check atomic types: write allowed only when casting is safe") { + atomicTypes.foreach { w => + atomicTypes.foreach { r => + if (Cast.canSafeCast(w, r)) { + assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") + + } else { + assertSingleError(w, r, "t", + s"Should not allow writing $w to $r because cast is not safe") { err => + assert(err.contains("'t'"), "Should include the field name context") + assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") + assert(err.contains(s"$w"), "Should include write type") + assert(err.contains(s"$r"), "Should include read type") + } + } + } + } + } + + test("Check struct types: missing required field") { + val missingRequiredField = StructType(Seq(StructField("x", 
FloatType, nullable = false))) + assertSingleError(missingRequiredField, point2, "t", + "Should fail because required field 'y' is missing") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'y'"), "Should include the nested field name") + assert(err.contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing starting field, matched by position") { + val missingRequiredField = StructType(Seq(StructField("y", FloatType, nullable = false))) + + // should have 2 errors: names x and y don't match, and field y is missing + assertNumErrors(missingRequiredField, point2, "t", + "Should fail because field 'x' is matched to field 'y' and required field 'y' is missing", 2) + { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'y'"), "Should include the _last_ nested fields of the read schema") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing middle field, matched by position") { + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType, nullable = true))) + + // types are compatible: (req int, req int) => (req int, req int, opt int) + // but this should still fail because the names do not match. 
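+    // Note: canWrite matches struct fields by position (readFields.zip(writeFields)), so the
+    // write field 'z' is compared against the read field 'y' (name mismatch) and the trailing
+    // read field 'z' is then reported as missing.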
+ + assertNumErrors(missingMiddleField, expectedStruct, "t", + "Should fail because field 'y' is matched to field 'z'", 2) { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'z'"), "Should include the nested field name") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: generic colN names are ignored") { + val missingMiddleField = StructType(Seq( + StructField("col1", FloatType, nullable = false), + StructField("col2", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + // types are compatible: (req int, req int) => (req int, req int) + // names don't match, but match the naming convention used by Spark to fill in names + + assertAllowed(missingMiddleField, expectedStruct, "t", + "Should succeed because column names are ignored") + } + + test("Check struct types: required field is optional") { + val requiredFieldIsOptional = StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType, nullable = false))) + + assertSingleError(requiredFieldIsOptional, point2, "t", + "Should fail because required field 'x' is optional") { err => + assert(err.contains("'t.x'"), "Should include the nested field name context") + assert(err.contains("Cannot write nullable values to non-null field")) + } + } + + test("Check struct types: data field would be dropped") { + assertSingleError(point3, point2, "t", + "Should fail because field 'z' would be dropped") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'z'"), "Should include the extra field name") + assert(err.contains("Cannot write extra fields")) + } + } + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(widerPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check struct types: type promotion is allowed") { + assertAllowed(point2, widerPoint2, "t", + "Should allow widening float fields x and y to double") + } + + ignore("Check struct types: missing optional field is allowed") { + // built-in data sources do not yet support missing fields when optional + assertAllowed(point2, point3, "t", + "Should allow writing point (x,y) to point(x,y,z=null)") + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfLong, arrayOfInt, "arr", + "Should not allow array of longs to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check array types: type promotion is allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + assertAllowed(arrayOfInt, arrayOfLong, 
"arr", + "Should allow array of int written to array of long column") + } + + test("Check array types: cannot write optional to required elements") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertSingleError(arrayOfOptional, arrayOfRequired, "arr", + "Should not allow array of optional elements to array of required elements") { err => + assert(err.contains("'arr'"), "Should include type name context") + assert(err.contains("Cannot write nullable elements to array of non-nulls")) + } + } + + test("Check array types: writing required to optional elements is allowed") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertAllowed(arrayOfRequired, arrayOfOptional, "arr", + "Should allow array of required elements to array of optional elements") + } + + test("Check map value types: unsafe casts are not allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfLong, mapOfInt, "m", + "Should not allow map of longs to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map value types: type promotion is allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertAllowed(mapOfInt, mapOfLong, "m", "Should allow map of int written to map of long column") + } + + test("Check map value types: cannot write optional to required values") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertSingleError(mapOfOptional, mapOfRequired, "m", + "Should not allow map of optional values to map of required values") { err => + assert(err.contains("'m'"), "Should include type name context") + assert(err.contains("Cannot write nullable values to map of non-nulls")) + } + } + + test("Check map value types: writing required to optional values is allowed") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertAllowed(mapOfRequired, mapOfOptional, "m", + "Should allow map of required elements to map of optional elements") + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyLong, mapKeyInt, "m", + "Should not allow map of long keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: type promotion is allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertAllowed(mapKeyInt, mapKeyLong, "m", + "Should allow map of int written to map of long column") + } + + test("Check types with multiple errors") { + val readType = StructType(Seq( + StructField("a", ArrayType(DoubleType, containsNull = false)), + StructField("arr_of_structs", ArrayType(point2, containsNull = false)), + StructField("bad_nested_type", ArrayType(StringType)), + StructField("m", MapType(LongType, FloatType, valueContainsNull = false)), + StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)), + 
StructField("x", IntegerType, nullable = false), + StructField("missing1", StringType, nullable = false), + StructField("missing2", StringType) + )) + + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val writeType = StructType(Seq( + StructField("a", ArrayType(StringType)), + StructField("arr_of_structs", ArrayType(point3)), + StructField("bad_nested_type", point3), + StructField("m", MapType(DoubleType, DoubleType)), + StructField("map_of_structs", MapType(StringType, missingMiddleField)), + StructField("y", LongType) + )) + + assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => + assert(errs(0).contains("'top.a.element'"), "Should identify bad type") + assert(errs(0).contains("Cannot safely cast")) + assert(errs(0).contains("StringType to DoubleType")) + + assert(errs(1).contains("'top.a'"), "Should identify bad type") + assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type") + assert(errs(2).contains("'z'"), "Should identify bad field") + assert(errs(2).contains("Cannot write extra fields to struct")) + + assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type") + assert(errs(3).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type") + assert(errs(4).contains("is incompatible with")) + + assert(errs(5).contains("'top.m.key'"), "Should identify bad type") + assert(errs(5).contains("Cannot safely cast")) + assert(errs(5).contains("DoubleType to LongType")) + + assert(errs(6).contains("'top.m.value'"), "Should identify bad type") + assert(errs(6).contains("Cannot safely cast")) + assert(errs(6).contains("DoubleType to FloatType")) + + assert(errs(7).contains("'top.m'"), "Should identify bad type") + assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(8).contains("field name does not match"), "Should identify name problem") + + assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(9).contains("'z'"), "Should identify missing field") + assert(errs(9).contains("missing fields"), "Should detect missing field") + + assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type") + assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(11).contains("'top.x'"), "Should identify bad type") + assert(errs(11).contains("Cannot safely cast")) + assert(errs(11).contains("LongType to IntegerType")) + + assert(errs(12).contains("'top'"), "Should identify bad type") + assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(12).contains("field name does not match"), "Should identify name problem") + + assert(errs(13).contains("'top'"), "Should identify bad type") + assert(errs(13).contains("'missing1'"), "Should identify missing field") + assert(errs(13).contains("missing fields"), "Should detect missing field") + } + } + + // Helper functions + + def assertAllowed(writeType: DataType, readType: DataType, name: String, desc: String): Unit = { + assert( + DataType.canWrite(writeType, readType, 
analysis.caseSensitiveResolution, name, + errMsg => fail(s"Should not produce errors but was called with: $errMsg")) === true, desc) + } + + def assertSingleError( + writeType: DataType, + readType: DataType, + name: String, + desc: String) + (errFunc: String => Unit): Unit = { + assertNumErrors(writeType, readType, name, desc, 1) { errs => + errFunc(errs.head) + } + } + + def assertNumErrors( + writeType: DataType, + readType: DataType, + name: String, + desc: String, + numErrs: Int) + (errFunc: Seq[String] => Unit): Unit = { + val errs = new mutable.ArrayBuffer[String]() + assert( + DataType.canWrite(writeType, readType, analysis.caseSensitiveResolution, name, + errMsg => errs += errMsg) === false, desc) + assert(errs.size === numErrs, s"Should produce $numErrs error messages") + errFunc(errs) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 83aeec0c4785..048787a7a0a0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -38,15 +38,16 @@ public interface WriteSupport extends DataSourceV2 { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * - * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceWriter} can - * use this job id to distinguish itself from other jobs. + * @param writeUUID A unique string for the writing job. It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceWriter} can + * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. 
+ * @return a writer to append data to this data source */ Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceOptions options); + String writeUUID, StructType schema, SaveMode mode, DataSourceOptions options); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index cd7dc2a2727e..db2a1e742619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import java.text.SimpleDateFormat -import java.util.{Date, Locale, Properties, UUID} +import java.util.{Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -26,12 +25,11 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType @@ -240,21 +238,27 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - ds match { + val source = cls.newInstance().asInstanceOf[DataSourceV2] + source match { case ws: WriteSupport => - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = df.sparkSession.sessionState.conf)).asJava) - // Using a timestamp and a random UUID to distinguish different writing jobs. This is good - // enough as there won't be tons of writing jobs created at the same second. 
- val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) - .format(new Date()) + "-" + UUID.randomUUID() - val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) - if (writer.isPresent) { + val options = extraOptions ++ + DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) + + val relation = DataSourceV2Relation.create(source, options.toMap) + if (mode == SaveMode.Append) { runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.logicalPlan) + AppendData.byName(relation, df.logicalPlan) + } + + } else { + val writer = ws.createWriter( + UUID.randomUUID.toString, df.logicalPlan.output.toStructType, mode, + new DataSourceOptions(options.asJava)) + + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get, df.logicalPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 46166928f449..a4bfc861cc9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.collection.JavaConverters._ -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter import org.apache.spark.sql.types.StructType /** @@ -40,17 +44,24 @@ case class DataSourceV2Relation( source: DataSourceV2, output: Seq[AttributeReference], options: Map[String, String], - userSpecifiedSchema: Option[StructType]) - extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + tableIdent: Option[TableIdentifier] = None, + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with NamedRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ + override def name: String = { + tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown") + } + override def pushedFilters: Seq[Expression] = Seq.empty override def simpleString: String = "RelationV2 " + metadataString def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) + def newWriter(): DataSourceWriter = source.createWriter(options, schema) + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -115,6 +126,15 @@ object DataSourceV2Relation { } } + def asWriteSupport: WriteSupport = { + source match { 
+ case support: WriteSupport => + support + case _ => + throw new AnalysisException(s"Data source is not writable: $name") + } + } + def name: String = { source match { case registered: DataSourceRegister => @@ -135,14 +155,29 @@ object DataSourceV2Relation { asReadSupport.createReader(v2Options) } } + + def createWriter( + options: Map[String, String], + schema: StructType): DataSourceWriter = { + val v2Options = new DataSourceOptions(options.asJava) + asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get + } } def create( source: DataSourceV2, options: Map[String, String], - userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + tableIdent: Option[TableIdentifier] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { val reader = source.createReader(options, userSpecifiedSchema) + val ident = tableIdent.orElse(tableFromOptions(options)) DataSourceV2Relation( - source, reader.readSchema().toAttributes, options, userSpecifiedSchema) + source, reader.readSchema().toAttributes, options, ident, userSpecifiedSchema) + } + + private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { + options + .get(DataSourceOptions.TABLE_KEY) + .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 9414e68155b9..6daaa4c65c33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.{sources, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -136,6 +136,9 @@ object DataSourceV2Strategy extends Strategy { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case AppendData(r: DataSourceV2Relation, query, _) => + WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 0399970495be..59ebb9bc5431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -35,8 +35,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** - * The logical plan for writing data into data source v2. + * Deprecated logical plan for writing data into data source v2. This is being replaced by more + * specific logical plans, like [[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. */ +@deprecated("Use specific logical plans like AppendData instead", "2.4.0") case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index fef53e6f7b6f..aa5f723365d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -192,33 +192,33 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val path = file.getCanonicalPath assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - spark.range(10).select('id, -'id).write.format(cls.getName) + spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).select('id, -'id)) // test with different save modes - spark.range(10).select('id, -'id).write.format(cls.getName) + spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).union(spark.range(10)).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("overwrite").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("ignore").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) val e = intercept[Exception] { - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("error").save() } assert(e.getMessage.contains("data already exists")) @@ -235,7 +235,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } // this input data will fail to read middle way. 
- val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) val e2 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } @@ -253,7 +253,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) val numPartition = 6 - spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(),