From 65f253ea9a731135a952c04253e15cf5eff59151 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 5 Apr 2016 13:35:50 -0700 Subject: [PATCH 1/6] SPARK-14543: Update InsertInto column resolution. This combines Hive's pre-insertion casts (without renames) that handle partitioning with the pre-insertion casts/renames in core. The combined rule, ResolveOutputColumns, will resolve columns by name or by position. Resolving by position will detect cases where the number of columns is incorrect or where the input columns are a permutation of the output columns and fail. When resolving by name, each output column is located by name in the child plan. This handles cases where a subset of a data frame is written out. --- .../sql/catalyst/analysis/Analyzer.scala | 117 ++++++++++++++++- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 5 +- .../plans/logical/basicLogicalOperators.scala | 23 +++- .../sql/catalyst/parser/PlanParserSuite.scala | 8 +- .../apache/spark/sql/DataFrameWriter.scala | 12 +- .../datasources/DataSourceStrategy.scala | 6 +- .../sql/execution/datasources/rules.scala | 57 +------- .../spark/sql/internal/SessionState.scala | 3 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 66 ++-------- .../spark/sql/hive/HiveSessionCatalog.scala | 1 - .../spark/sql/hive/HiveSessionState.scala | 2 - .../spark/sql/hive/HiveStrategies.scala | 4 +- .../hive/execution/CreateTableAsSelect.scala | 4 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 122 +++++++++++++++++- .../hive/execution/HiveComparisonTest.scala | 1 + 16 files changed, 300 insertions(+), 133 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1313a011c69c..6f7506b4b04d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -101,6 +101,7 @@ class Analyzer( ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: + ResolveOutputColumns :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -445,7 +446,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => + case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _, _) if child.resolved => val table = lookupTableFromCatalog(u) // adding the table's partitions or validate the query's partition info table match { @@ -499,6 +500,120 @@ class Analyzer( } } + object ResolveOutputColumns extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transform { + case ins @ InsertIntoTable(relation: LogicalPlan, partition, _, _, _, _) + if ins.childrenResolved && !ins.resolved => + resolveOutputColumns(ins, expectedColumns(relation, partition), relation.toString) + } + + private def resolveOutputColumns( + insertInto: InsertIntoTable, + columns: Seq[Attribute], + relation: String) = { + val resolved = if (insertInto.isMatchByName) { + projectAndCastOutputColumns(columns, insertInto.child, relation) + } else { + castAndRenameOutputColumns(columns, insertInto.child, relation) + } + + if (resolved == insertInto.child.output) { + insertInto + } else { + insertInto.copy(child = Project(resolved, insertInto.child)) + } + } + + /** + * Resolves output columns by input column name, adding casts if necessary. + */ + private def projectAndCastOutputColumns( + output: Seq[Attribute], + data: LogicalPlan, + relation: String): Seq[NamedExpression] = { + output.map { col => + data.resolveQuoted(col.name, resolver) match { + case Some(inCol) if col.dataType != inCol.dataType => + Alias(UpCast(inCol, col.dataType, Seq()), col.name)() + case Some(inCol) => inCol + case None => + throw new AnalysisException( + s"Cannot resolve ${col.name} in ${data.output.mkString(",")}") + } + } + } + + private def castAndRenameOutputColumns( + output: Seq[Attribute], + data: LogicalPlan, + relation: String): Seq[NamedExpression] = { + val outputNames = output.map(_.name) + // incoming expressions may not have names + val inputNames = data.output.flatMap(col => Option(col.name)) + if (output.size > data.output.size) { + // always a problem + throw new AnalysisException( + s"""Not enough data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${outputNames.mkString(",")}""".stripMargin) + } else if (output.size < data.output.size) { + if (outputNames.toSet.subsetOf(inputNames.toSet)) { + throw new AnalysisException( + s"""Table column names are a subset of the input data columns: + |Data columns: ${inputNames.mkString(",")} + |Table columns: ${outputNames.mkString(",")} + |To write a subset of the columns by name, use df.write.byName.insertInto(...)""" + .stripMargin) + } else { + // be conservative and fail if there are too many columns + throw new AnalysisException( + s"""Extra data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${outputNames.mkString(",")}""".stripMargin) + } + } else { + // check for reordered names and warn. this may be on purpose, so it isn't an error. + if (outputNames.toSet == inputNames.toSet && outputNames != inputNames) { + logWarning( + s"""Data column names match the table in a different order: + |Data columns: ${inputNames.mkString(",")} + |Table columns: ${outputNames.mkString(",")} + |To map columns by name, use df.write.byName.insertInto(...)""".stripMargin) + } + } + + data.output.zip(output).map { + case (in, out) if !in.dataType.sameType(out.dataType) => + Alias(Cast(in, out.dataType), out.name)() + case (in, out) if in.name != out.name => + Alias(in, out.name)() + case (in, _) => in + } + } + + private def expectedColumns( + data: LogicalPlan, + partitionData: Map[String, Option[String]]): Seq[Attribute] = { + data match { + case partitioned: CatalogRelation => + val tablePartitionNames = partitioned.catalogTable.partitionColumns.map(_.name) + val (inputPartCols, dataColumns) = data.output.partition { attr => + tablePartitionNames.contains(attr.name) + } + // Get the dynamic partition columns in partition order + val dynamicNames = tablePartitionNames.filter( + name => partitionData.getOrElse(name, None).isEmpty) + val dynamicPartCols = dynamicNames.map { name => + inputPartCols.find(_.name == name).getOrElse( + throw new AnalysisException(s"Cannot find partition column $name")) + } + + dataColumns ++ dynamicPartCols + case _ => data.output + } + } + } + /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 2ca990d19a2c..076ae596d99e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, overwrite, ifNotExists = false, isMatchByName = false) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2d7d0f903295..904963b3fcad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -211,8 +211,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { UnresolvedRelation(tableIdent, None), partitionKeys, query, - ctx.OVERWRITE != null, - ctx.EXISTS != null) + overwrite = ctx.OVERWRITE != null, + ifNotExists = ctx.EXISTS != null, + isMatchByName = false /* SQL always matches by position */) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 732b0d7919c3..d89a7333dc37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -348,7 +348,8 @@ case class InsertIntoTable( partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifNotExists: Boolean, + isMatchByName: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil @@ -358,18 +359,26 @@ case class InsertIntoTable( if (table.output.isEmpty) { None } else { - val numDynamicPartitions = partition.values.count(_.isEmpty) + val dynamicPartitionNames = partition.filter { + case (name, Some(_)) => false + case (name, None) => true + }.keySet val (partitionColumns, dataColumns) = table.output .partition(a => partition.keySet.contains(a.name)) - Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions)) + Some(dataColumns ++ partitionColumns.filter(col => dynamicPartitionNames.contains(col.name))) } } assert(overwrite || !ifNotExists) - override lazy val resolved: Boolean = childrenResolved && expectedColumns.forall { expected => - child.output.size == expected.size && child.output.zip(expected).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) + override lazy val resolved: Boolean = childrenResolved && { + expectedColumns match { + case Some(expected) => + child.output.size == expected.size && child.output.zip(expected).forall { + case (childAttr, tableAttr) => + childAttr.name == tableAttr.name && // required by some relations + DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) + } + case None => true } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 25d87d93bec4..e9c75e30b71e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -191,7 +191,7 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists, isMatchByName = false) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -209,9 +209,11 @@ class PlanParserSuite extends PlanTest { val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false, + isMatchByName = false).union( InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false, + isMatchByName = false))) } test("aggregation") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 6f5fb69ea377..adc3726b7a1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -408,7 +408,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { partitions.getOrElse(Map.empty[String, Option[String]]), input, overwrite, - ifNotExists = false)).toRdd + ifNotExists = false, + isMatchByName = matchOutputColumnsByName)).toRdd } private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => @@ -464,6 +465,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { } } + def byName: DataFrameWriter = { + extraOptions.put("matchByName", "true") + this + } + + private def matchOutputColumnsByName: Boolean = { + extraOptions.getOrElse("matchByName", "false").toBoolean + } + /** * Saves the content of the [[DataFrame]] as the specified table. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 0494fafb0e42..e2d9cc95515b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -46,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false) + l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false, _) if query.resolved && t.schema.asNullable == query.schema.asNullable => // Sanity checks @@ -110,7 +110,7 @@ private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[ } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) + case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _) if DDLUtils.isDatasourceTable(s.metadata) => i.copy(table = readDataSourceTable(sparkSession, s.metadata)) @@ -152,7 +152,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), - part, query, overwrite, false) if part.isEmpty => + part, query, overwrite, false, _) if part.isEmpty => ExecutedCommandExec(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index b622f859413a..d6e8bcff303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -49,55 +49,6 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo } } -/** - * A rule to do pre-insert data type casting and field renaming. Before we insert into - * an [[InsertableRelation]], we will use this rule to make sure that - * the columns to be inserted have the correct data type and fields have the correct names. - */ -private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - // We are inserting into an InsertableRelation or HadoopFsRelation. - case i @ InsertIntoTable( - l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) => - // First, make sure the data to be inserted have the same number of fields with the - // schema of the relation. - if (l.output.size != child.output.size) { - sys.error( - s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " + - s"statement generates the same number of columns as its schema.") - } - castAndRenameChildOutput(i, l.output, child) - } - - /** If necessary, cast data types and rename fields to the expected types and names. */ - def castAndRenameChildOutput( - insertInto: InsertIntoTable, - expectedOutput: Seq[Attribute], - child: LogicalPlan): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(child.output).map { - case (expected, actual) => - val needCast = !expected.dataType.sameType(actual.dataType) - // We want to make sure the filed names in the data to be inserted exactly match - // names in the schema. - val needRename = expected.name != actual.name - (needCast, needRename) match { - case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)() - case (false, true) => Alias(actual, expected.name)() - case (_, _) => actual - } - } - - if (newChildOutput == child.output) { - insertInto - } else { - insertInto.copy(child = Project(newChildOutput, child)) - } - } -} - /** * A rule to do various checks before inserting into or writing to a data source table. */ @@ -110,7 +61,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) plan.foreach { case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation, _, _), - partition, query, overwrite, ifNotExists) => + partition, query, overwrite, ifNotExists, _) => // Right now, we do not support insert into a data source table with partition specs. if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") @@ -128,7 +79,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) } case logical.InsertIntoTable( - LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _) => + LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _, _) => // We need to make sure the partition columns specified by users do match partition // columns of the relation. val existingPartitionColumns = r.partitionSchema.fieldNames.toSet @@ -156,11 +107,11 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // OK } - case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => + case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _, _) => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") - case logical.InsertIntoTable(t, _, _, _, _) => + case logical.InsertIntoTable(t, _, _, _, _, _) => if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) { failAnalysis(s"Inserting into an RDD-based table is not allowed.") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f0b8a83dee8c..46c5f634dd11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.AnalyzeTable -import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, ResolveDataSource} import org.apache.spark.sql.util.ExecutionListenerManager @@ -109,7 +109,6 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = - PreInsertCastAndRename :: new FindDataSourceTable(sparkSession) :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4c528fbbbeef..0b524a76005a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -370,16 +370,19 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, byName) // Inserting into partitioned table is not supported in Parquet data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, + byName) // Write path - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoHiveTable(r: MetastoreRelation, + partition, child, overwrite, ifNotExists, byName) // Inserting into partitioned table is not supported in Parquet data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, + byName) // Read path case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => @@ -414,16 +417,17 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, byName) // Inserting into partitioned table is not supported in Orc data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, byName) // Write path - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoHiveTable(r: MetastoreRelation, + partition, child, overwrite, ifNotExists, byName) // Inserting into partitioned table is not supported in Orc data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, byName) // Read path case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => @@ -492,49 +496,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } } - - /** - * Casts input data to correct data types according to table definition before inserting into - * that table. - */ - object PreInsertionCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => - castChildOutput(p, table, child) - } - - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) - : LogicalPlan = { - val childOutputDataTypes = child.output.map(_.dataType) - val numDynamicPartitions = p.partition.values.count(_.isEmpty) - val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) - .take(child.output.length).map(_.dataType) - - if (childOutputDataTypes == tableOutputDataTypes) { - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else if (childOutputDataTypes.size == tableOutputDataTypes.size && - childOutputDataTypes.zip(tableOutputDataTypes) - .forall { case (left, right) => left.sameType(right) }) { - // If both types ignoring nullability of ArrayType, MapType, StructType are the same, - // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else { - // Only do the casting when child output data types differ from table output data types. - val castedChildOutput = child.output.zip(table.output).map { - case (input, output) if input.dataType != output.dataType => - Alias(Cast(input, output.dataType), input.name)() - case (input, _) => input - } - - p.copy(child = logical.Project(castedChildOutput, child)) - } - } - } - } /** @@ -577,7 +538,8 @@ private[hive] case class InsertIntoHiveTable( partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifNotExists: Boolean, + matchByName: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 4f8aac8c2fcd..2f6a2207855e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -87,7 +87,6 @@ private[sql] class HiveSessionCatalog( val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables - val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts override def refreshTable(name: TableIdentifier): Unit = { metastoreCatalog.refreshTable(name) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 46579ecd85ca..9ee029100f5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -65,8 +65,6 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) catalog.ParquetConversions :: catalog.OrcConversions :: catalog.CreateTables :: - catalog.PreInsertionCasts :: - PreInsertCastAndRename :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 71b180e55b58..7d1daa496f09 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -43,11 +43,11 @@ private[hive] trait HiveStrategies { object DataSinks extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable( - table: MetastoreRelation, partition, child, overwrite, ifNotExists) => + table: MetastoreRelation, partition, child, overwrite, ifNotExists, _) => execution.InsertIntoHiveTable( table, partition, planLater(child), overwrite, ifNotExists) :: Nil case hive.InsertIntoHiveTable( - table: MetastoreRelation, partition, child, overwrite, ifNotExists) => + table: MetastoreRelation, partition, child, overwrite, ifNotExists, _) => execution.InsertIntoHiveTable( table, partition, planLater(child), overwrite, ifNotExists) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 9dfbafae872f..7a999c291a41 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -86,8 +86,8 @@ case class CreateTableAsSelect( throw new AnalysisException(s"$tableIdentifier already exists.") } } else { - sparkSession.executePlan(InsertIntoTable( - metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd + sparkSession.executePlan(InsertIntoTable(metastoreRelation, Map(), query, + overwrite = true, ifNotExists = false, isMatchByName = false)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index b25684562075..26c1cb49c1f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -313,8 +313,128 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, - Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) + Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false, + isMatchByName = false) assert(!logical.resolved, "Should not resolve: missing partition data") } } + + test("Insert unnamed expressions by position") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val expected = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + val data = expected.select("id", "part") + + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + // should be able to insert an expression when NOT mapping columns by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id)") + .write.insertInto("partitioned") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Insert expression by name") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val expected = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + val data = expected.select("id", "part") + + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + intercept[AnalysisException] { + // also a problem when mapping by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id)") + .write.byName.insertInto("partitioned") + } + + // should be able to insert an expression using AS when mapping columns by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id) as data") + .write.byName.insertInto("partitioned") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Reject missing columns") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + intercept[AnalysisException] { + spark.table("source").write.insertInto("partitioned") + } + + intercept[AnalysisException] { + // also a problem when mapping by name + spark.table("source").write.byName.insertInto("partitioned") + } + } + } + + test("Reject extra columns") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, data string, extra string, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + intercept[AnalysisException] { + spark.table("source").write.insertInto("partitioned") + } + + val data = (1 to 10) + .map(i => (i, s"data-$i", s"${i * i}", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "extra", "part") + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + spark.table("source").write.byName.insertInto("partitioned") + + val expected = data.select("id", "data", "part") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Ignore names when writing by position") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string, data string)") // part, data transposed + sql("CREATE TABLE destination (id bigint, data string, part string)") + + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + + // write into the reordered table by name + data.write.byName.insertInto("source") + checkAnswer(sql("SELECT id, data, part FROM source"), data.collect().toSeq) + + val expected = data.select($"id", $"part" as "data", $"data" as "part") + + // this produces a warning, but writes src.part -> dest.data and src.data -> dest.part + spark.table("source").write.insertInto("destination") + checkAnswer(sql("SELECT id, data, part FROM destination"), expected.collect().toSeq) + } + } + + test("Reorder columns by name") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (data string, part string, id bigint)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val data = (1 to 10).map(i => (s"data-$i", if ((i % 2) == 0) "even" else "odd", i)) + .toDF("data", "part", "id") + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + spark.table("source").write.byName.insertInto("partitioned") + + val expected = data.select("id", "data", "part") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index b12f3aafefb8..319e4e15bf62 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -348,6 +348,7 @@ abstract class HiveComparisonTest val containsCommands = originalQuery.analyzed.collectFirst { case _: Command => () case _: LogicalInsertIntoHiveTable => () + case _: InsertIntoTable => () }.nonEmpty if (containsCommands) { From d1339490c8421539a2f800b0b89562493414e794 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 20 Apr 2016 14:14:44 -0700 Subject: [PATCH 2/6] SPARK-14543: Fix bad SQL in HiveQuerySuite test. --- .../sql/hive/execution/HiveQuerySuite.scala | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2aaaaadb6afa..690538b1f31a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1039,7 +1039,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SET hive.exec.dynamic.partition.mode=nonstrict") sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") - sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value, ds, hr FROM srcpart") .queryExecution.analyzed } @@ -1050,6 +1050,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + test("SPARK-14543: AnalysisException for missing partition columns") { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + + intercept[AnalysisException] { + // src doesn't have ds and hr partition columns + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + intercept[AnalysisException] { + // ds and hr partition columns aren't selected + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM srcpart") + .queryExecution.analyzed + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" From 2b1193504412e5942642f45dfc641039f190f310 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 21 Apr 2016 16:18:08 -0700 Subject: [PATCH 3/6] SPARK-14543: Update InsertSuite test for too few columns. This PR now catches this problem during analysis and has a better error message. This commit updates the test for the new message and exception type. --- .../scala/org/apache/spark/sql/sources/InsertSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 854fec5b22f7..9e07e66ac48d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -87,15 +87,15 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } test("SELECT clause generating a different number of columns is not allowed.") { - val message = intercept[RuntimeException] { + val message = intercept[AnalysisException] { sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt """.stripMargin) }.getMessage assert( - message.contains("generates the same number of columns as its schema"), - "SELECT clause generating a different number of columns should not be not allowed." + message.contains("Not enough data columns to write"), + "SELECT clause must generate all of a table's columns to write" ) } From 3a24e36ceb9b815e8c933723e22dbdbfed35c840 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 9 May 2016 11:01:19 -0700 Subject: [PATCH 4/6] SPARK-14543: Update new InsertIntoTable parameter to Map. Adding new argumetns to InsertIntoTable requires changes to several files. Instead of adding a long list of optional args, this adds an options map, like the one passed to DataSource. Future options can be added and used only where they are needed. --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 6 ++++- .../sql/catalyst/parser/PlanParserSuite.scala | 6 ++--- .../apache/spark/sql/DataFrameWriter.scala | 6 +---- .../spark/sql/hive/HiveMetastoreCatalog.scala | 23 +++++++++++-------- .../hive/execution/CreateTableAsSelect.scala | 2 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- 8 files changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 076ae596d99e..a5c2fd038ace 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, ifNotExists = false, isMatchByName = false) + Map.empty, logicalPlan, overwrite, ifNotExists = false, Map.empty) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 904963b3fcad..4d2ee6f25a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -213,7 +213,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { query, overwrite = ctx.OVERWRITE != null, ifNotExists = ctx.EXISTS != null, - isMatchByName = false /* SQL always matches by position */) + Map.empty /* SQL always matches by position */) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d89a7333dc37..3b4f66ee5685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -349,12 +349,16 @@ case class InsertIntoTable( child: LogicalPlan, overwrite: Boolean, ifNotExists: Boolean, - isMatchByName: Boolean) + options: Map[String, String]) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty + private[spark] def isMatchByName: Boolean = { + options.get("matchByName").map(_.toBoolean).getOrElse(false) + } + private[spark] lazy val expectedColumns = { if (table.output.isEmpty) { None diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index e9c75e30b71e..015b8b413f64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -191,7 +191,7 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists, isMatchByName = false) + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists, Map.empty) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -210,10 +210,10 @@ class PlanParserSuite extends PlanTest { assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false, - isMatchByName = false).union( + Map.empty).union( InsertIntoTable( table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false, - isMatchByName = false))) + Map.empty))) } test("aggregation") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index adc3726b7a1c..70ec1a730939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -409,7 +409,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { input, overwrite, ifNotExists = false, - isMatchByName = matchOutputColumnsByName)).toRdd + options = extraOptions.toMap)).toRdd } private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => @@ -470,10 +470,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } - private def matchOutputColumnsByName: Boolean = { - extraOptions.getOrElse("matchByName", "false").toBoolean - } - /** * Saves the content of the [[DataFrame]] as the specified table. * diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0b524a76005a..c9c025c2b879 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.command.CreateTableAsSelectLogicalPlan @@ -370,19 +369,20 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, byName) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, + options) // Inserting into partitioned table is not supported in Parquet data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, - byName) + options) // Write path case InsertIntoHiveTable(r: MetastoreRelation, - partition, child, overwrite, ifNotExists, byName) + partition, child, overwrite, ifNotExists, options) // Inserting into partitioned table is not supported in Parquet data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, - byName) + options) // Read path case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => @@ -417,17 +417,20 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, byName) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, + options) // Inserting into partitioned table is not supported in Orc data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, byName) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, + options) // Write path case InsertIntoHiveTable(r: MetastoreRelation, - partition, child, overwrite, ifNotExists, byName) + partition, child, overwrite, ifNotExists, options) // Inserting into partitioned table is not supported in Orc data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, byName) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, + options) // Read path case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => @@ -539,7 +542,7 @@ private[hive] case class InsertIntoHiveTable( child: LogicalPlan, overwrite: Boolean, ifNotExists: Boolean, - matchByName: Boolean) + options: Map[String, String]) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 7a999c291a41..6b298cc15a64 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -87,7 +87,7 @@ case class CreateTableAsSelect( } } else { sparkSession.executePlan(InsertIntoTable(metastoreRelation, Map(), query, - overwrite = true, ifNotExists = false, isMatchByName = false)).toRdd + overwrite = true, ifNotExists = false, Map.empty)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 26c1cb49c1f5..0a17df9541ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -314,7 +314,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false, - isMatchByName = false) + Map("matchByName" -> "true")) assert(!logical.resolved, "Should not resolve: missing partition data") } } From 3cdbfa83d4b064fbaf9d50b3bec51f4645dad0fb Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 22 Apr 2016 14:15:06 -0700 Subject: [PATCH 5/6] SPARK-15420: Detect sorting and do not sort in WriteContainers. This avoids an extra sort in the WriterContainer when data has already been sorted as part of the query plan. This fixes writes for both HadoopFsRelation and MetastoreRelation. --- .../sql/catalyst/expressions/SortOrder.scala | 16 ++ .../InsertIntoHadoopFsRelation.scala | 7 +- .../datasources/WriterContainer.scala | 139 +++++++++++------- .../hive/execution/InsertIntoHiveTable.scala | 7 +- .../spark/sql/hive/hiveWriterContainers.scala | 68 +++++++-- 5 files changed, 169 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 42a8be6b1b1e..ab562e426dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator @@ -63,6 +64,21 @@ case class SortOrder(child: Expression, direction: SortDirection) def isAscending: Boolean = direction == Ascending } +// TODO: should this be an implicit class somewhere? +object SortOrder { + def satisfies(order: Seq[SortOrder], distribution: Distribution): Boolean = { + distribution match { + case c @ ClusteredDistribution(exprs) => + // Zip discards extra order by expressions + (order.size >= exprs.size) && exprs.zip(order.map(_.child)).forall { + case (clusterExpr, orderExpr) => clusterExpr.semanticEquals(orderExpr) + case _ => false + } + case _ => false + } + } +} + /** * An expression to generate a 64-bit long prefix used in sorting. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 4921e4ca6bb7..bfad55936e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -130,9 +130,10 @@ private[sql] case class InsertIntoHadoopFsRelation( partitionColumns = partitionColumns, dataColumns = dataColumns, inputSchema = query.output, - PartitioningUtils.DEFAULT_PARTITION_NAME, - sparkSession.conf.get(SQLConf.PARTITION_MAX_FILES), - isAppend) + defaultPartitionName = PartitioningUtils.DEFAULT_PARTITION_NAME, + incomingOrder = queryExecution.sparkPlan.outputOrdering, + maxOpenFiles = sparkSession.conf.get(SQLConf.PARTITION_MAX_FILES), + isAppend = isAppend) } // This call shouldn't be put into the `try` block below because it only initializes and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 7e12bbb2128b..06908dd3a6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.internal.SQLConf @@ -300,6 +300,7 @@ private[sql] class DynamicPartitionWriterContainer( dataColumns: Seq[Attribute], inputSchema: Seq[Attribute], defaultPartitionName: String, + incomingOrder: Seq[SortOrder], maxOpenFiles: Int, isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { @@ -371,13 +372,9 @@ private[sql] class DynamicPartitionWriterContainer( // We should first sort by partition columns, then bucket id, and finally sorting columns. val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. - case _ => StructField("bucketId", IntegerType, nullable = false) - }) + // If the data is already sorted correctly, avoid sorting it again + val isSorted = SortOrder.satisfies(incomingOrder, ClusteredDistribution(sortingExpressions)) // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) @@ -386,53 +383,15 @@ private[sql] class DynamicPartitionWriterContainer( val getPartitionString = UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - // Sorts the data before write, so that we only need one writer at the same time. - // TODO: inject a local sort operator in planning. - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) - }) - } - - val sortedIterator = sorter.sortedIterator() - // If anything below fails, we should abort the task. var currentWriter: OutputWriter = null try { Utils.tryWithSafeFinallyAndFailureCallbacks { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") - - currentWriter = newOutputWriter(currentKey, getPartitionString) - } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null + if (isSorted) { + sortedWrite(iterator) + } else { + // TODO: inject a local sort operator in planning. + sortAndWrite(iterator) } commitTask() @@ -446,6 +405,86 @@ private[sql] class DynamicPartitionWriterContainer( case t: Throwable => throw new SparkException("Task failed while writing rows", t) } + + def sortAndWrite(iterator: Iterator[InternalRow]): Unit = { + val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) + + // Sorts the data before write, so that we only need one writer at the same time. + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity + } else { + UnsafeProjection.create( + sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) + } + + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. + case _ => StructField("bucketId", IntegerType, nullable = false) + }) + + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val sortedIterator = sorter.sortedIterator() + + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + } + + def sortedWrite(iterator: Iterator[InternalRow]): Unit = { + val getBucketingKey: InternalRow => InternalRow = + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length), inputSchema) + var currentKey: UnsafeRow = null + while (iterator.hasNext) { + val currentRow = iterator.next() + val nextKey = getBucketingKey(currentRow).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + currentWriter.writeInternal(getOutputRow(currentRow)) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + } } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3805674d3958..1a942c543eee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -34,7 +34,8 @@ import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -218,10 +219,14 @@ case class InsertIntoHiveTable( val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) + val partitionClustering = ClusteredDistribution(child.output.takeRight(numDynamicPartitions)) + val isSorted = SortOrder.satisfies(child.outputOrdering, partitionClustering) + new SparkHiveDynamicPartitionWriterContainer( jobConf, fileSinkConf, dynamicPartColNames, + isSorted, child.output, table) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 706fdbc2604f..556d95a60375 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -218,6 +218,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String], + isSorted: Boolean, inputSchema: Seq[Attribute], table: MetastoreRelation) extends SparkHiveWriterContainer(jobConf, fileSinkConf, inputSchema, table) { @@ -277,6 +278,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { + if (isSorted) { + sortedWrite(iterator) + } else { + sortAndWrite(iterator) + } + commit() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + def sortAndWrite(iterator: Iterator[InternalRow]): Unit = { val sorter: UnsafeKVExternalSorter = new UnsafeKVExternalSorter( StructType.fromAttributes(partitionOutput), StructType.fromAttributes(dataOutput), @@ -291,6 +306,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } logInfo(s"Sorting complete. Writing out partition files one at a time.") + val sortedIterator = sorter.sortedIterator() var currentKey: InternalRow = null var currentWriter: FileSinkOperator.RecordWriter = null @@ -305,29 +321,53 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( currentWriter = newOutputWriter(currentKey) } - var i = 0 - while (i < fieldOIs.length) { - outputData(i) = if (sortedIterator.getValue.isNullAt(i)) { - null - } else { - wrappers(i)(sortedIterator.getValue.get(i, dataTypes(i))) + currentWriter.write(serialize(sortedIterator.getValue)) + } + } finally { + if (currentWriter != null) { + currentWriter.close(false) + } + } + } + + def sortedWrite(iterator: Iterator[InternalRow]): Unit = { + var currentKey: InternalRow = null + var currentWriter: FileSinkOperator.RecordWriter = null + try { + while (iterator.hasNext) { + val inputRow = iterator.next() + val rowKey = getPartitionKey(inputRow) + if (currentKey != rowKey) { + if (currentWriter != null) { + currentWriter.close(false) } - i += 1 + currentKey = rowKey + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey) } - currentWriter.write(serializer.serialize(outputData, standardOI)) + + currentWriter.write(serialize(getOutputRow(inputRow))) } } finally { if (currentWriter != null) { currentWriter.close(false) } } - commit() - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - abortTask() - throw new SparkException("Task failed while writing rows.", cause) } + + def serialize(data: InternalRow) = { + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (data.isNullAt(i)) { + null + } else { + wrappers(i)(data.get(i, dataTypes(i))) + } + i += 1 + } + serializer.serialize(outputData, standardOI) + } + /** Open and returns a new OutputWriter given a partition key. */ def newOutputWriter(key: InternalRow): FileSinkOperator.RecordWriter = { val partitionPath = getPartitionString(key).getString(0) From a64be8a91ddadcd7acbbd08956f214b3c40f0dca Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 22 Apr 2016 16:36:24 -0700 Subject: [PATCH 6/6] SPARK-15420: Add repartition, sort optimization. This adds an optimizer rule that will add repartition and sort operations to the logical plan. Sort is added when the table has sort or bucketing columns. Repartition is added when writing columnar formats and the option "spark.sql.files.columnar.insertRepartition" is enabled. This also adds a `writersPerPartition(numTasks: Int)` option when writing that controls the number of files in each output table partition. The optimizer rule adds a repartition step that distributes output by partition and a random value in [0, numTasks). --- .../spark/sql/catalyst/CatalystConf.scala | 5 +- .../sql/catalyst/catalog/interface.scala | 10 +++ .../sql/catalyst/optimizer/Optimizer.scala | 84 ++++++++++++++++++- .../plans/logical/basicLogicalOperators.scala | 4 + .../apache/spark/sql/DataFrameWriter.scala | 25 +++++- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../DataFrameReaderWriterSuite.scala | 2 +- 7 files changed, 136 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 4df100c2a830..3b0f6883f2da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -36,6 +36,8 @@ trait CatalystConf { def warehousePath: String + def repartitionColumnarData: Boolean + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. @@ -55,5 +57,6 @@ case class SimpleCatalystConf( optimizerInSetConversionThreshold: Int = 10, maxCaseBranchesForCodegen: Int = 20, runSQLonFile: Boolean = true, - warehousePath: String = "/user/hive/warehouse") + warehousePath: String = "/user/hive/warehouse", + repartitionColumnarData: Boolean = false) extends CatalystConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 3fdd411ac4cc..3fbb9bb416ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -118,6 +118,16 @@ case class CatalogTable( def partitionColumns: Seq[CatalogColumn] = schema.filter { c => partitionColumnNames.contains(c.name) } + /** Columns this table is bucketed by. */ + private[sql] val bucketColumns: Seq[CatalogColumn] = bucketColumnNames.flatMap { name => + schema.find(_.name == name) + } + + /** Columns this table is bucketed by. */ + private[sql] val sortColumns: Seq[CatalogColumn] = sortColumnNames.flatMap { name => + schema.find(_.name == name) + } + /** Return the database this table was specified to belong to, assuming it exists. */ def database: String = identifier.database.getOrElse { throw new AnalysisException(s"table $identifier did not specify database") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6825b65e2b28..9db70012d434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,13 +23,14 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -72,6 +73,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: + Batch("Output Layout Optimizations", Once, + DistributeAndSortOutputData(conf)) :: Batch("Operator Optimizations", fixedPoint, // Operator push down SetOperationPushDown, @@ -1737,3 +1740,82 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { } } } + +case class DistributeAndSortOutputData(conf: CatalystConf) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _) + if insertInto.resolved && insertInto.writersPerPartition.isDefined => + insertInto.copy(child = + buildRepartitionAndSort(rel.catalogTable, data, insertInto.writersPerPartition)) + + case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _) + if insertInto.resolved && requiresSort(rel.catalogTable) => + insertInto.copy(child = buildSort(rel.catalogTable, data)) + + case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _) + if insertInto.resolved && isColumnar(rel.catalogTable) && shouldRepartition(data) => + insertInto.copy(child = buildRepartitionAndSort(rel.catalogTable, data, None)) + } + + private def isColumnar(table: CatalogTable): Boolean = { + table.storage.serde.map(_.toLowerCase) + .forall(serde => serde.contains("parquet") || serde.contains("orc")) + } + + private def shouldRepartition(plan: LogicalPlan): Boolean = { + // automatically add repartitioning for columnar formats if enabled and doesn't conflict + conf.repartitionColumnarData && !hasSortOrRepartition(plan); + } + + private def hasSortOrRepartition(plan: LogicalPlan): Boolean = { + plan.collectFirst { + case _: RepartitionByExpression => true + case _: Sort => true + }.getOrElse(false) + } + + private def requiresSort(table: CatalogTable): Boolean = { + (table.bucketColumnNames.size + table.sortColumnNames.size) > 0 + } + + private def buildSort(table: CatalogTable, data: LogicalPlan): LogicalPlan = { + val partitionExprs = asExpr(table.partitionColumns, data) + val bucketExpr = asBucketExpr(table.bucketColumns, table.numBuckets, data) + val sortExprs = partitionExprs ++ bucketExpr ++ asExpr(table.sortColumns, data) + // add a sort without a repartition + Sort(sortExprs.map(expr => SortOrder(expr, Ascending)), global = false, data) + } + + private def buildRepartitionAndSort( + table: CatalogTable, + data: LogicalPlan, + numWriters: Option[Int]): LogicalPlan = { + val partitionExprs = asExpr(table.partitionColumns, data) ++ asDistributeExpr(numWriters) + val bucketExpr = asBucketExpr(table.bucketColumns, table.numBuckets, data) + val sortExprs = partitionExprs ++ bucketExpr ++ asExpr(table.sortColumns, data) + + // add a sort with an inner repartition + Sort( + sortExprs.map(expr => SortOrder(expr, Ascending)), + global = false, + RepartitionByExpression(partitionExprs, data, None)) + } + + private def asExpr(columns: Seq[CatalogColumn], data: LogicalPlan): Seq[Attribute] = { + columns.map(col => data.output.find(_.name == col.name).get) + } + + private def asDistributeExpr(numWriters: Option[Int]): Option[Expression] = { + numWriters.map(n => Pmod(Cast(Multiply(Rand(0L), Literal(n)), IntegerType), Literal(n))) + } + + private def asBucketExpr(columns: Seq[CatalogColumn], numBuckets: Int, + data: LogicalPlan): Option[Expression] = { + if (columns.isEmpty) { + None + } else { + Some(HashPartitioning(asExpr(columns, data), numBuckets).partitionIdExpression) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3b4f66ee5685..f39abfe19145 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -359,6 +359,10 @@ case class InsertIntoTable( options.get("matchByName").map(_.toBoolean).getOrElse(false) } + private[spark] def writersPerPartition: Option[Int] = { + options.get("writersPerPartition").map(_.toInt) + } + private[spark] lazy val expectedColumns = { if (table.output.isEmpty) { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 70ec1a730939..39ac9b8cad9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -224,6 +224,20 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Hint to distribute the output for each partition randomly across `numWriters` tasks. + * + * This is applicable for Parquet, JSON, ORC, and Hive. + * + * @param numWriters number of writers to use for each partition + * @return this DataFrameWriter for method chaining + * @since 2.0 + */ + def writersPerPartition(numWriters: Int): DataFrameWriter = { + option("writersPerPartition", numWriters) + this + } + /** * Saves the content of the [[DataFrame]] at the specified path. * @@ -241,6 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ def save(): Unit = { assertNotBucketed() + assertNotSorted() assertNotStreaming("save() can only be called on non-continuous queries") val dataSource = DataSource( df.sparkSession, @@ -290,6 +305,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { @Experimental def startStream(): ContinuousQuery = { assertNotBucketed() + assertNotSorted() assertStreaming("startStream() can only be called on continuous queries") if (source == "memory") { @@ -459,12 +475,19 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def assertNotBucketed(): Unit = { - if (numBuckets.isDefined || sortColumnNames.isDefined) { + if (bucketColumnNames.isDefined) { throw new IllegalArgumentException( "Currently we don't support writing bucketed data to this data source.") } } + private def assertNotSorted(): Unit = { + if (sortColumnNames.isDefined) { + throw new IllegalArgumentException( + "Currently we don't support writing sorted data to this data source.") + } + } + def byName: DataFrameWriter = { extraOptions.put("matchByName", "true") this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 518430f16d71..cfb18a6eb5ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -521,6 +521,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(60 * 1000L) // 10 minutes + val REPARTITION_COLUMNAR_DATA = + SQLConfigBuilder("spark.sql.files.columnar.insertRepartition") + .internal() + .doc("Whether to automatically add a repartition step before writing columnar data " + + "formats, such as Parquet and Orc, to minimize output files and memory consumption.") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -652,6 +660,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH) + override def repartitionColumnarData: Boolean = getConf(REPARTITION_COLUMNAR_DATA) + def warehousePath: String = { getConf(WAREHOUSE_PATH).replace("${system:user.dir}", System.getProperty("user.dir")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index cb53b2b1aac1..9754bd2a7b66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -432,7 +432,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .stream() val w = df.write val e = intercept[IllegalArgumentException](w.sortBy("text").startStream()) - assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") + assert(e.getMessage == "Currently we don't support writing sorted data to this data source.") } test("check save(path) can only be called on non-continuous queries") {