
Commit 3d010c8

cloud-fan authored and yhuai committed
[SPARK-16036][SPARK-16037][SQL] fix various table insertion problems
## What changes were proposed in this pull request?

The current table insertion has some weird behaviours:

1. Inserting into a partitioned table with mismatched columns produces a confusing error message for a Hive table, and a wrong result for a data source table.
2. Inserting into a partitioned table without a partition list produces a wrong result for a Hive table.

This PR fixes these two problems.

## How was this patch tested?

New test in the Hive `SQLQuerySuite`.

Author: Wenchen Fan <[email protected]>

Closes #13754 from cloud-fan/insert2.
1 parent e574c99 commit 3d010c8

File tree

12 files changed: +104 additions, -185 deletions
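Before the per-file diffs, here is a minimal sketch of the two scenarios the summary describes (hypothetical table and column names, assuming a Hive-enabled SparkSession named `spark`; the quoted error text comes from the new `PreprocessTableInsertion` rule shown below):

```scala
// Scenario 1: the PARTITION clause does not match the table's partition
// columns. Previously this gave a confusing Hive error or a wrong data
// source result; with this patch it fails fast with an AnalysisException
// like "Requested partitioning does not match the table ...".
spark.sql("CREATE TABLE t (a INT) PARTITIONED BY (b INT, c INT)")
spark.sql("INSERT INTO TABLE t PARTITION (b = 1, d = 2) SELECT 1, 2")

// Scenario 2: no PARTITION clause at all. Both partition columns are now
// treated as dynamic and are fed positionally by the trailing query columns
// (for Hive tables this may require hive.exec.dynamic.partition.mode=nonstrict).
spark.sql("INSERT INTO TABLE t SELECT 1, 2, 3")
```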

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 36 deletions
@@ -452,42 +452,7 @@ class Analyzer(

     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
-        val table = lookupTableFromCatalog(u)
-        // adding the table's partitions or validate the query's partition info
-        table match {
-          case relation: CatalogRelation if relation.catalogTable.partitionColumns.nonEmpty =>
-            val tablePartitionNames = relation.catalogTable.partitionColumns.map(_.name)
-            if (parts.keys.nonEmpty) {
-              // the query's partitioning must match the table's partitioning
-              // this is set for queries like: insert into ... partition (one = "a", two = <expr>)
-              // TODO: add better checking to pre-inserts to avoid needing this here
-              if (tablePartitionNames.size != parts.keySet.size) {
-                throw new AnalysisException(
-                  s"""Requested partitioning does not match the ${u.tableIdentifier} table:
-                     |Requested partitions: ${parts.keys.mkString(",")}
-                     |Table partitions: ${tablePartitionNames.mkString(",")}""".stripMargin)
-              }
-              // Assume partition columns are correctly placed at the end of the child's output
-              i.copy(table = EliminateSubqueryAliases(table))
-            } else {
-              // Set up the table's partition scheme with all dynamic partitions by moving partition
-              // columns to the end of the column list, in partition order.
-              val (inputPartCols, columns) = child.output.partition { attr =>
-                tablePartitionNames.contains(attr.name)
-              }
-              // All partition columns are dynamic because this InsertIntoTable had no partitioning
-              val partColumns = tablePartitionNames.map { name =>
-                inputPartCols.find(_.name == name).getOrElse(
-                  throw new AnalysisException(s"Cannot find partition column $name"))
-              }
-              i.copy(
-                table = EliminateSubqueryAliases(table),
-                partition = tablePartitionNames.map(_ -> None).toMap,
-                child = Project(columns ++ partColumns, child))
-            }
-          case _ =>
-            i.copy(table = EliminateSubqueryAliases(table))
-        }
+        i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
       case u: UnresolvedRelation =>
         val table = u.tableIdentifier
         if (table.database.isDefined && conf.runSQLonFile &&

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 2 additions & 4 deletions
@@ -369,10 +369,8 @@ case class InsertIntoTable(
     if (table.output.isEmpty) {
       None
     } else {
-      val numDynamicPartitions = partition.values.count(_.isEmpty)
-      val (partitionColumns, dataColumns) = table.output
-        .partition(a => partition.keySet.contains(a.name))
-      Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions))
+      val staticPartCols = partition.filter(_._2.isDefined).keySet
+      Some(table.output.filterNot(a => staticPartCols.contains(a.name)))
     }
   }
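The new `expectedColumns` logic is small enough to replay standalone. A hedged sketch in plain Scala with illustrative column names: only static partition columns (those given an explicit value in the PARTITION clause) are excluded from what the query must supply; dynamic partition columns remain expected, in table order.

```scala
// Table output: data column "a" plus partition columns "b" and "c".
val tableOutput = Seq("a", "b", "c")
// PARTITION (b = 1, c): "b" is static, "c" is dynamic.
val partition = Map("b" -> Some("1"), "c" -> None)

// Mirrors the new code: drop only the static partition columns.
val staticPartCols = partition.filter(_._2.isDefined).keySet
val expected = tableOutput.filterNot(staticPartCols.contains)
// expected == List("a", "c"): the query supplies "a" and a value for the
// dynamic partition column "c"; "b" comes from the PARTITION clause.
```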

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala

Lines changed: 63 additions & 37 deletions
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal

 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -62,53 +62,79 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo
 }

 /**
- * A rule to do pre-insert data type casting and field renaming. Before we insert into
- * an [[InsertableRelation]], we will use this rule to make sure that
- * the columns to be inserted have the correct data type and fields have the correct names.
+ * Preprocess the [[InsertIntoTable]] plan. Throws exception if the number of columns mismatch, or
+ * specified partition columns are different from the existing partition columns in the target
+ * table. It also does data type casting and field renaming, to make sure that the columns to be
+ * inserted have the correct data type and fields have the correct names.
  */
-private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Wait until children are resolved.
-    case p: LogicalPlan if !p.childrenResolved => p
-
-    // We are inserting into an InsertableRelation or HadoopFsRelation.
-    case i @ InsertIntoTable(
-    l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) =>
-      // First, make sure the data to be inserted have the same number of fields with the
-      // schema of the relation.
-      if (l.output.size != child.output.size) {
-        sys.error(
-          s"$l requires that the data to be inserted have the same number of columns as the " +
-            s"target table: target table has ${l.output.size} column(s) but " +
-            s"the inserted data has ${child.output.size} column(s).")
-      }
-      castAndRenameChildOutput(i, l.output, child)
+private[sql] object PreprocessTableInsertion extends Rule[LogicalPlan] {
+  private def preprocess(
+      insert: InsertIntoTable,
+      tblName: String,
+      partColNames: Seq[String]): InsertIntoTable = {
+
+    val expectedColumns = insert.expectedColumns
+    if (expectedColumns.isDefined && expectedColumns.get.length != insert.child.schema.length) {
+      throw new AnalysisException(
+        s"Cannot insert into table $tblName because the number of columns are different: " +
+          s"need ${expectedColumns.get.length} columns, " +
+          s"but query has ${insert.child.schema.length} columns.")
+    }
+
+    if (insert.partition.nonEmpty) {
+      // the query's partitioning must match the table's partitioning
+      // this is set for queries like: insert into ... partition (one = "a", two = <expr>)
+      if (insert.partition.keySet != partColNames.toSet) {
+        throw new AnalysisException(
+          s"""
+             |Requested partitioning does not match the table $tblName:
+             |Requested partitions: ${insert.partition.keys.mkString(",")}
+             |Table partitions: ${partColNames.mkString(",")}
+           """.stripMargin)
+      }
+      expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert)
+    } else {
+      // All partition columns are dynamic because this InsertIntoTable had no partitioning
+      expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert)
+        .copy(partition = partColNames.map(_ -> None).toMap)
+    }
   }

-  /** If necessary, cast data types and rename fields to the expected types and names. */
+  // TODO: do we really need to rename?
   def castAndRenameChildOutput(
-      insertInto: InsertIntoTable,
-      expectedOutput: Seq[Attribute],
-      child: LogicalPlan): InsertIntoTable = {
-    val newChildOutput = expectedOutput.zip(child.output).map {
+      insert: InsertIntoTable,
+      expectedOutput: Seq[Attribute]): InsertIntoTable = {
+    val newChildOutput = expectedOutput.zip(insert.child.output).map {
       case (expected, actual) =>
-        val needCast = !expected.dataType.sameType(actual.dataType)
-        // We want to make sure the filed names in the data to be inserted exactly match
-        // names in the schema.
-        val needRename = expected.name != actual.name
-        (needCast, needRename) match {
-          case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)()
-          case (false, true) => Alias(actual, expected.name)()
-          case (_, _) => actual
+        if (expected.dataType.sameType(actual.dataType) && expected.name == actual.name) {
+          actual
+        } else {
+          Alias(Cast(actual, expected.dataType), expected.name)()
         }
     }

-    if (newChildOutput == child.output) {
-      insertInto
+    if (newChildOutput == insert.child.output) {
+      insert
     } else {
-      insertInto.copy(child = Project(newChildOutput, child))
+      insert.copy(child = Project(newChildOutput, insert.child))
     }
   }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case i @ InsertIntoTable(table, partition, child, _, _) if table.resolved && child.resolved =>
+      table match {
+        case relation: CatalogRelation =>
+          val metadata = relation.catalogTable
+          preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames)
+        case LogicalRelation(h: HadoopFsRelation, _, identifier) =>
+          val tblName = identifier.map(_.quotedString).getOrElse("unknown")
+          preprocess(i, tblName, h.partitionSchema.map(_.name))
+        case LogicalRelation(_: InsertableRelation, _, identifier) =>
+          val tblName = identifier.map(_.quotedString).getOrElse("unknown")
+          preprocess(i, tblName, Nil)
+        case other => i
+      }
  }
 }

 /**
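The rewritten `castAndRenameChildOutput` collapses the old needCast/needRename case analysis into a single branch: pass the attribute through only when both type and name already match, otherwise wrap it in one Cast plus Alias. A hedged sketch of that decision using Catalyst's expression DSL (an internal API, with illustrative attribute names):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast}

// The table expects column "a" of IntegerType; the query produced "x" of
// StringType, so both the data type and the name disagree.
val expected = 'a.int
val actual = 'x.string

// Same decision as castAndRenameChildOutput:
val aligned =
  if (expected.dataType.sameType(actual.dataType) && expected.name == actual.name) {
    actual
  } else {
    Alias(Cast(actual, expected.dataType), expected.name)()
  }
// aligned is roughly Alias(Cast(x, IntegerType), "a") here, which the rule
// would place under a Project on top of the insert's child plan.
```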

sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.AnalyzeTableCommand
-import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource}
+import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreprocessTableInsertion, ResolveDataSource}
 import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
 import org.apache.spark.sql.util.ExecutionListenerManager

@@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) {
   lazy val analyzer: Analyzer = {
     new Analyzer(catalog, conf) {
       override val extendedResolutionRules =
-        PreInsertCastAndRename ::
+        PreprocessTableInsertion ::
         new FindDataSourceTable(sparkSession) ::
         DataSourceAnalysis ::
         (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)

sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala

Lines changed: 2 additions & 4 deletions
@@ -88,15 +88,13 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
   }

   test("SELECT clause generating a different number of columns is not allowed.") {
-    val message = intercept[RuntimeException] {
+    val message = intercept[AnalysisException] {
       sql(
         s"""
            |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt
         """.stripMargin)
     }.getMessage
-    assert(
-      message.contains("requires that the data to be inserted have the same number of columns"),
-      "SELECT clause generating a different number of columns should not be not allowed."
+    assert(message.contains("the number of columns are different")
     )
   }

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 0 additions & 43 deletions
@@ -457,49 +457,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
       allowExisting)
   }
 }
-
-  /**
-   * Casts input data to correct data types according to table definition before inserting into
-   * that table.
-   */
-  object PreInsertionCasts extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
-      // Wait until children are resolved.
-      case p: LogicalPlan if !p.childrenResolved => p
-
-      case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) =>
-        castChildOutput(p, table, child)
-    }
-
-    def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan)
-      : LogicalPlan = {
-      val childOutputDataTypes = child.output.map(_.dataType)
-      val numDynamicPartitions = p.partition.values.count(_.isEmpty)
-      val tableOutputDataTypes =
-        (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions))
-          .take(child.output.length).map(_.dataType)
-
-      if (childOutputDataTypes == tableOutputDataTypes) {
-        InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists)
-      } else if (childOutputDataTypes.size == tableOutputDataTypes.size &&
-        childOutputDataTypes.zip(tableOutputDataTypes)
-          .forall { case (left, right) => left.sameType(right) }) {
-        // If both types ignoring nullability of ArrayType, MapType, StructType are the same,
-        // use InsertIntoHiveTable instead of InsertIntoTable.
-        InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists)
-      } else {
-        // Only do the casting when child output data types differ from table output data types.
-        val castedChildOutput = child.output.zip(table.output).map {
-          case (input, output) if input.dataType != output.dataType =>
-            Alias(Cast(input, output.dataType), input.name)()
-          case (input, _) => input
-        }
-
-        p.copy(child = logical.Project(castedChildOutput, child))
-      }
-    }
-  }
-
 }

 /**

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala

Lines changed: 0 additions & 1 deletion
@@ -87,7 +87,6 @@ private[sql] class HiveSessionCatalog(
   val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
   val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions
   val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables
-  val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts

   override def refreshTable(name: TableIdentifier): Unit = {
     metastoreCatalog.refreshTable(name)

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala

Lines changed: 1 addition & 2 deletions
@@ -65,8 +65,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
       catalog.ParquetConversions ::
       catalog.OrcConversions ::
       catalog.CreateTables ::
-      catalog.PreInsertionCasts ::
-      PreInsertCastAndRename ::
+      PreprocessTableInsertion ::
       DataSourceAnalysis ::
       (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)

sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala

Lines changed: 0 additions & 21 deletions
@@ -325,27 +325,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
     }
   }

-  test("Detect table partitioning with correct partition order") {
-    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
-      sql("CREATE TABLE source (id bigint, part2 string, part1 string, data string)")
-      val data = (1 to 10).map(i => (i, if ((i % 2) == 0) "even" else "odd", "p", s"data-$i"))
-        .toDF("id", "part2", "part1", "data")
-
-      data.write.insertInto("source")
-      checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
-
-      // the original data with part1 and part2 at the end
-      val expected = data.select("id", "data", "part1", "part2")
-
-      sql(
-        """CREATE TABLE partitioned (id bigint, data string)
-          |PARTITIONED BY (part1 string, part2 string)""".stripMargin)
-      spark.table("source").write.insertInto("partitioned")
-
-      checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq)
-    }
-  }
-
   private def testPartitionedHiveSerDeTable(testName: String)(f: String => Unit): Unit = {
     test(s"Hive SerDe table - $testName") {
       val hiveTable = "hive_table"
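The deleted test encoded the old behaviour, where the analyzer re-detected partition columns by name and moved them to the end. With that reordering gone, a dynamic-partition insertInto is positional, so the caller lines the columns up with the table schema. A hedged sketch (assumes a Hive-enabled SparkSession named `spark` and nonstrict dynamic partition mode):

```scala
import spark.implicits._

spark.sql(
  """CREATE TABLE partitioned (id bigint, data string)
    |PARTITIONED BY (part1 string, part2 string)""".stripMargin)

// Columns must already be in table order (data columns first, partition
// columns last); nothing reorders them by name anymore.
val df = Seq((1L, "data-1", "p", "odd")).toDF("id", "data", "part1", "part2")
df.write.insertInto("partitioned")
```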

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala

Lines changed: 1 addition & 0 deletions
@@ -348,6 +348,7 @@ abstract class HiveComparisonTest
         queryString.replace("../../data", testDataPath))
       val containsCommands = originalQuery.analyzed.collectFirst {
         case _: Command => ()
+        case _: InsertIntoTable => ()
         case _: LogicalInsertIntoHiveTable => ()
       }.nonEmpty
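The detection idiom above is ordinary Scala: `collectFirst` returns `Some(())` at the first node matching any listed case, so `.nonEmpty` asks whether any such node exists. A self-contained sketch with stand-in node types:

```scala
sealed trait Node
case object CommandNode extends Node
case object InsertNode extends Node
case object QueryNode extends Node

val nodes: Seq[Node] = Seq(QueryNode, InsertNode)
val containsCommands = nodes.collectFirst {
  case CommandNode => ()
  case InsertNode => ()
}.nonEmpty
// true: InsertNode matches the second case, just as the suite now treats
// a plain InsertIntoTable like a command.
```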
