@@ -452,42 +452,7 @@ class Analyzer(

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
val table = lookupTableFromCatalog(u)
// adding the table's partitions or validate the query's partition info
table match {
case relation: CatalogRelation if relation.catalogTable.partitionColumns.nonEmpty =>
val tablePartitionNames = relation.catalogTable.partitionColumns.map(_.name)
if (parts.keys.nonEmpty) {
// the query's partitioning must match the table's partitioning
// this is set for queries like: insert into ... partition (one = "a", two = <expr>)
// TODO: add better checking to pre-inserts to avoid needing this here
if (tablePartitionNames.size != parts.keySet.size) {
throw new AnalysisException(
s"""Requested partitioning does not match the ${u.tableIdentifier} table:
|Requested partitions: ${parts.keys.mkString(",")}
|Table partitions: ${tablePartitionNames.mkString(",")}""".stripMargin)
}
// Assume partition columns are correctly placed at the end of the child's output
i.copy(table = EliminateSubqueryAliases(table))
} else {
// Set up the table's partition scheme with all dynamic partitions by moving partition
// columns to the end of the column list, in partition order.
val (inputPartCols, columns) = child.output.partition { attr =>
tablePartitionNames.contains(attr.name)
}
// All partition columns are dynamic because this InsertIntoTable had no partitioning
val partColumns = tablePartitionNames.map { name =>
inputPartCols.find(_.name == name).getOrElse(
throw new AnalysisException(s"Cannot find partition column $name"))
}
i.copy(
table = EliminateSubqueryAliases(table),
partition = tablePartitionNames.map(_ -> None).toMap,
child = Project(columns ++ partColumns, child))
}
case _ =>
i.copy(table = EliminateSubqueryAliases(table))
}
i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
case u: UnresolvedRelation =>
val table = u.tableIdentifier
if (table.database.isDefined && conf.runSQLonFile &&
@@ -369,10 +369,8 @@ case class InsertIntoTable(
if (table.output.isEmpty) {
None
} else {
val numDynamicPartitions = partition.values.count(_.isEmpty)
val (partitionColumns, dataColumns) = table.output
.partition(a => partition.keySet.contains(a.name))
Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions))
val staticPartCols = partition.filter(_._2.isDefined).keySet
Some(table.output.filterNot(a => staticPartCols.contains(a.name)))
Contributor: Seems this `contains` does not work for case-insensitive resolution. We can fix it in a separate PR.
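A possible case-insensitive variant of this filtering, for reference. This is only a sketch: it assumes a resolver is available at this point (`expectedColumns` currently has no access to the session's `Resolver`), with `equalsIgnoreCase` standing in for it here.

```scala
// Sketch only: case-insensitive filtering of static partition columns.
// A real fix would reuse the analyzer's Resolver instead of equalsIgnoreCase.
val staticPartCols = partition.filter(_._2.isDefined).keySet
Some(table.output.filterNot { a =>
  staticPartCols.exists(col => col.equalsIgnoreCase(a.name))
})
```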

}
}

@@ -21,7 +21,7 @@ import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
@@ -62,53 +62,79 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo
}

/**
* A rule to do pre-insert data type casting and field renaming. Before we insert into
* an [[InsertableRelation]], we will use this rule to make sure that
* the columns to be inserted have the correct data type and fields have the correct names.
* Preprocess the [[InsertIntoTable]] plan. Throws an exception if the number of columns does not
* match, or if the specified partition columns differ from the partition columns of the target
* table. It also does data type casting and field renaming, to make sure that the columns to be
* inserted have the correct data type and fields have the correct names.
*/
private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Wait until children are resolved.
case p: LogicalPlan if !p.childrenResolved => p

// We are inserting into an InsertableRelation or HadoopFsRelation.
case i @ InsertIntoTable(
l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) =>
// First, make sure the data to be inserted have the same number of fields with the
// schema of the relation.
if (l.output.size != child.output.size) {
sys.error(
s"$l requires that the data to be inserted have the same number of columns as the " +
s"target table: target table has ${l.output.size} column(s) but " +
s"the inserted data has ${child.output.size} column(s).")
}
castAndRenameChildOutput(i, l.output, child)
private[sql] object PreprocessTableInsertion extends Rule[LogicalPlan] {
private def preprocess(
insert: InsertIntoTable,
tblName: String,
partColNames: Seq[String]): InsertIntoTable = {

val expectedColumns = insert.expectedColumns
if (expectedColumns.isDefined && expectedColumns.get.length != insert.child.schema.length) {
throw new AnalysisException(
s"Cannot insert into table $tblName because the number of columns are different: " +
s"need ${expectedColumns.get.length} columns, " +
s"but query has ${insert.child.schema.length} columns.")
}

if (insert.partition.nonEmpty) {
// the query's partitioning must match the table's partitioning
// this is set for queries like: insert into ... partition (one = "a", two = <expr>)
if (insert.partition.keySet != partColNames.toSet) {
Contributor: This check is case-sensitive.
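One way the check could be made case-insensitive, as a rough sketch; lower-casing both sides is shown for brevity, though reusing the analyzer's resolver would be more consistent with the other analysis rules.

```scala
// Sketch only: compare the requested and table partition columns case-insensitively.
val requestedPartCols = insert.partition.keySet.map(_.toLowerCase)
val tablePartCols = partColNames.map(_.toLowerCase).toSet
if (requestedPartCols != tablePartCols) {
  // report the mismatch as below
}
```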

throw new AnalysisException(
s"""
|Requested partitioning does not match the table $tblName:
|Requested partitions: ${insert.partition.keys.mkString(",")}
|Table partitions: ${partColNames.mkString(",")}
""".stripMargin)
}
expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert)
} else {
// All partition columns are dynamic because this InsertIntoTable had no partitioning
Contributor: because the InsertIntoTable command does not explicitly specify partitioning columns.
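For illustration, the behavior the new rule enforces at the SQL level might look like the following sketch (the tables `t` and `src` and the `spark` session variable are hypothetical, not from this PR):

```scala
// Assume: CREATE TABLE t (id INT, data STRING) PARTITIONED BY (p STRING)

// OK: the static partition spec matches the table's partition column, and the
// query supplies exactly the remaining (non-partition) columns.
spark.sql("INSERT INTO TABLE t PARTITION (p = 'a') SELECT id, data FROM src")

// OK: no partition spec, so every partition column becomes dynamic and must
// appear at the end of the SELECT list, in partition order.
spark.sql("INSERT INTO TABLE t SELECT id, data, p FROM src")

// AnalysisException: requested partitioning (p2) does not match the table's (p).
spark.sql("INSERT INTO TABLE t PARTITION (p2 = 'a') SELECT id, data FROM src")

// AnalysisException: the query produces fewer columns than the table expects.
spark.sql("INSERT INTO TABLE t PARTITION (p = 'a') SELECT id FROM src")
```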

expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert)
.copy(partition = partColNames.map(_ -> None).toMap)
}
}

/** If necessary, cast data types and rename fields to the expected types and names. */
// TODO: do we really need to rename?
def castAndRenameChildOutput(
insertInto: InsertIntoTable,
expectedOutput: Seq[Attribute],
child: LogicalPlan): InsertIntoTable = {
val newChildOutput = expectedOutput.zip(child.output).map {
insert: InsertIntoTable,
expectedOutput: Seq[Attribute]): InsertIntoTable = {
val newChildOutput = expectedOutput.zip(insert.child.output).map {
case (expected, actual) =>
val needCast = !expected.dataType.sameType(actual.dataType)
// We want to make sure the filed names in the data to be inserted exactly match
// names in the schema.
val needRename = expected.name != actual.name
(needCast, needRename) match {
case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)()
case (false, true) => Alias(actual, expected.name)()
case (_, _) => actual
if (expected.dataType.sameType(actual.dataType) && expected.name == actual.name) {
actual
} else {
Alias(Cast(actual, expected.dataType), expected.name)()
}
}

if (newChildOutput == child.output) {
insertInto
if (newChildOutput == insert.child.output) {
insert
} else {
insertInto.copy(child = Project(newChildOutput, child))
insert.copy(child = Project(newChildOutput, insert.child))
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i @ InsertIntoTable(table, partition, child, _, _) if table.resolved && child.resolved =>
table match {
case relation: CatalogRelation =>
val metadata = relation.catalogTable
preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames)
case LogicalRelation(h: HadoopFsRelation, _, identifier) =>
val tblName = identifier.map(_.quotedString).getOrElse("unknown")
preprocess(i, tblName, h.partitionSchema.map(_.name))
case LogicalRelation(_: InsertableRelation, _, identifier) =>
val tblName = identifier.map(_.quotedString).getOrElse("unknown")
preprocess(i, tblName, Nil)
case other => i
}
}
}

/**
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource}
import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreprocessTableInsertion, ResolveDataSource}
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
import org.apache.spark.sql.util.ExecutionListenerManager

@@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) {
lazy val analyzer: Analyzer = {
new Analyzer(catalog, conf) {
override val extendedResolutionRules =
PreInsertCastAndRename ::
PreprocessTableInsertion ::
new FindDataSourceTable(sparkSession) ::
DataSourceAnalysis ::
(if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
@@ -88,15 +88,13 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
}

test("SELECT clause generating a different number of columns is not allowed.") {
val message = intercept[RuntimeException] {
val message = intercept[AnalysisException] {
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt
""".stripMargin)
}.getMessage
assert(
message.contains("requires that the data to be inserted have the same number of columns"),
"SELECT clause generating a different number of columns should not be not allowed."
assert(message.contains("the number of columns are different")
)
}

@@ -457,49 +457,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
allowExisting)
}
}

/**
* Casts input data to correct data types according to table definition before inserting into
* that table.
*/
object PreInsertionCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
// Wait until children are resolved.
case p: LogicalPlan if !p.childrenResolved => p

case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) =>
castChildOutput(p, table, child)
}

def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan)
: LogicalPlan = {
val childOutputDataTypes = child.output.map(_.dataType)
val numDynamicPartitions = p.partition.values.count(_.isEmpty)
val tableOutputDataTypes =
(table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions))
.take(child.output.length).map(_.dataType)

if (childOutputDataTypes == tableOutputDataTypes) {
InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists)
} else if (childOutputDataTypes.size == tableOutputDataTypes.size &&
childOutputDataTypes.zip(tableOutputDataTypes)
.forall { case (left, right) => left.sameType(right) }) {
// If both types ignoring nullability of ArrayType, MapType, StructType are the same,
// use InsertIntoHiveTable instead of InsertIntoTable.
InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists)
} else {
// Only do the casting when child output data types differ from table output data types.
val castedChildOutput = child.output.zip(table.output).map {
case (input, output) if input.dataType != output.dataType =>
Alias(Cast(input, output.dataType), input.name)()
case (input, _) => input
}

p.copy(child = logical.Project(castedChildOutput, child))
}
}
}

}

/**
@@ -87,7 +87,6 @@ private[sql] class HiveSessionCatalog(
val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions
val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables
val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts

override def refreshTable(name: TableIdentifier): Unit = {
metastoreCatalog.refreshTable(name)
@@ -65,8 +65,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
catalog.ParquetConversions ::
catalog.OrcConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
PreInsertCastAndRename ::
PreprocessTableInsertion ::
DataSourceAnalysis ::
(if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)

@@ -325,27 +325,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
}
}

test("Detect table partitioning with correct partition order") {
withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
sql("CREATE TABLE source (id bigint, part2 string, part1 string, data string)")
val data = (1 to 10).map(i => (i, if ((i % 2) == 0) "even" else "odd", "p", s"data-$i"))
.toDF("id", "part2", "part1", "data")

data.write.insertInto("source")
checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)

// the original data with part1 and part2 at the end
val expected = data.select("id", "data", "part1", "part2")

sql(
"""CREATE TABLE partitioned (id bigint, data string)
|PARTITIONED BY (part1 string, part2 string)""".stripMargin)
spark.table("source").write.insertInto("partitioned")

checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq)
}
}

private def testPartitionedHiveSerDeTable(testName: String)(f: String => Unit): Unit = {
Contributor: I think it is better to have this test. We can change the expected answers to respect the semantics.

Contributor: nvm, we do not need this test anymore because the semantics have changed (we will not adjust the column ordering).
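For reference, a sketch of how the removed test could be adapted if the by-position semantics were to be covered explicitly (the source table already lists partition columns last, so no reordering is expected). Table names follow the removed test; this is only an illustration, not part of the PR.

```scala
test("Insert into partitioned table with partition columns listed last") {
  withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
    sql("CREATE TABLE source (id bigint, data string, part1 string, part2 string)")
    val data = (1 to 10)
      .map(i => (i, s"data-$i", "p", if ((i % 2) == 0) "even" else "odd"))
      .toDF("id", "data", "part1", "part2")
    data.write.insertInto("source")

    sql(
      """CREATE TABLE partitioned (id bigint, data string)
        |PARTITIONED BY (part1 string, part2 string)""".stripMargin)

    // Columns are now matched by position, so the source must already place the
    // partition columns at the end, in partition order; no Project is inserted
    // to reorder them.
    spark.table("source").write.insertInto("partitioned")

    checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq)
  }
}
```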

test(s"Hive SerDe table - $testName") {
val hiveTable = "hive_table"
@@ -348,6 +348,7 @@ abstract class HiveComparisonTest
queryString.replace("../../data", testDataPath))
val containsCommands = originalQuery.analyzed.collectFirst {
case _: Command => ()
case _: InsertIntoTable => ()
Contributor (author): We should not have InsertIntoTable inside the plan tree when running a Hive query; it looks like this PR breaks something, and I need some more time to investigate it.

Contributor (author): Ah, it's because I removed the PreInsertionCasts rule, which turned InsertIntoTable into InsertIntoHiveTable. This conversion doesn't matter, as the Hive planner will plan InsertIntoTable into a physical InsertIntoHiveTable anyway.

So adding a case here is a reasonable fix.

case _: LogicalInsertIntoHiveTable => ()
}.nonEmpty

@@ -1033,41 +1033,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
sql("SELECT * FROM boom").queryExecution.analyzed
}

test("SPARK-3810: PreInsertionCasts static partitioning support") {
val analyzedPlan = {
loadTestTable("srcpart")
sql("DROP TABLE IF EXISTS withparts")
sql("CREATE TABLE withparts LIKE srcpart")
sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src")
.queryExecution.analyzed
}

assertResult(1, "Duplicated project detected\n" + analyzedPlan) {
analyzedPlan.collect {
case _: Project => ()
}.size
}
}

test("SPARK-3810: PreInsertionCasts dynamic partitioning support") {
val analyzedPlan = {
loadTestTable("srcpart")
sql("DROP TABLE IF EXISTS withparts")
sql("CREATE TABLE withparts LIKE srcpart")
sql("SET hive.exec.dynamic.partition.mode=nonstrict")

sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart")
sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src")
.queryExecution.analyzed
}

assertResult(1, "Duplicated project detected\n" + analyzedPlan) {
analyzedPlan.collect {
case _: Project => ()
}.size
}
}

test("parse HQL set commands") {
// Adapted from its SQL counterpart.
val testKey = "spark.sql.key.usedfortestonly"