diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 4df100c2a830..3b0f6883f2da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -36,6 +36,8 @@ trait CatalystConf { def warehousePath: String + def repartitionColumnarData: Boolean + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. @@ -55,5 +57,6 @@ case class SimpleCatalystConf( optimizerInSetConversionThreshold: Int = 10, maxCaseBranchesForCodegen: Int = 20, runSQLonFile: Boolean = true, - warehousePath: String = "/user/hive/warehouse") + warehousePath: String = "/user/hive/warehouse", + repartitionColumnarData: Boolean = false) extends CatalystConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1313a011c69c..6f7506b4b04d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -101,6 +101,7 @@ class Analyzer( ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: + ResolveOutputColumns :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -445,7 +446,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => + case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _, _) if child.resolved => val table = lookupTableFromCatalog(u) // adding the table's partitions or validate the query's partition info table match { @@ -499,6 +500,120 @@ class Analyzer( } } + object ResolveOutputColumns extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transform { + case ins @ InsertIntoTable(relation: LogicalPlan, partition, _, _, _, _) + if ins.childrenResolved && !ins.resolved => + resolveOutputColumns(ins, expectedColumns(relation, partition), relation.toString) + } + + private def resolveOutputColumns( + insertInto: InsertIntoTable, + columns: Seq[Attribute], + relation: String) = { + val resolved = if (insertInto.isMatchByName) { + projectAndCastOutputColumns(columns, insertInto.child, relation) + } else { + castAndRenameOutputColumns(columns, insertInto.child, relation) + } + + if (resolved == insertInto.child.output) { + insertInto + } else { + insertInto.copy(child = Project(resolved, insertInto.child)) + } + } + + /** + * Resolves output columns by input column name, adding casts if necessary. 
+ */ + private def projectAndCastOutputColumns( + output: Seq[Attribute], + data: LogicalPlan, + relation: String): Seq[NamedExpression] = { + output.map { col => + data.resolveQuoted(col.name, resolver) match { + case Some(inCol) if col.dataType != inCol.dataType => + Alias(UpCast(inCol, col.dataType, Seq()), col.name)() + case Some(inCol) => inCol + case None => + throw new AnalysisException( + s"Cannot resolve ${col.name} in ${data.output.mkString(",")}") + } + } + } + + private def castAndRenameOutputColumns( + output: Seq[Attribute], + data: LogicalPlan, + relation: String): Seq[NamedExpression] = { + val outputNames = output.map(_.name) + // incoming expressions may not have names + val inputNames = data.output.flatMap(col => Option(col.name)) + if (output.size > data.output.size) { + // always a problem + throw new AnalysisException( + s"""Not enough data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${outputNames.mkString(",")}""".stripMargin) + } else if (output.size < data.output.size) { + if (outputNames.toSet.subsetOf(inputNames.toSet)) { + throw new AnalysisException( + s"""Table column names are a subset of the input data columns: + |Data columns: ${inputNames.mkString(",")} + |Table columns: ${outputNames.mkString(",")} + |To write a subset of the columns by name, use df.write.byName.insertInto(...)""" + .stripMargin) + } else { + // be conservative and fail if there are too many columns + throw new AnalysisException( + s"""Extra data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${outputNames.mkString(",")}""".stripMargin) + } + } else { + // check for reordered names and warn. this may be on purpose, so it isn't an error. + if (outputNames.toSet == inputNames.toSet && outputNames != inputNames) { + logWarning( + s"""Data column names match the table in a different order: + |Data columns: ${inputNames.mkString(",")} + |Table columns: ${outputNames.mkString(",")} + |To map columns by name, use df.write.byName.insertInto(...)""".stripMargin) + } + } + + data.output.zip(output).map { + case (in, out) if !in.dataType.sameType(out.dataType) => + Alias(Cast(in, out.dataType), out.name)() + case (in, out) if in.name != out.name => + Alias(in, out.name)() + case (in, _) => in + } + } + + private def expectedColumns( + data: LogicalPlan, + partitionData: Map[String, Option[String]]): Seq[Attribute] = { + data match { + case partitioned: CatalogRelation => + val tablePartitionNames = partitioned.catalogTable.partitionColumns.map(_.name) + val (inputPartCols, dataColumns) = data.output.partition { attr => + tablePartitionNames.contains(attr.name) + } + // Get the dynamic partition columns in partition order + val dynamicNames = tablePartitionNames.filter( + name => partitionData.getOrElse(name, None).isEmpty) + val dynamicPartCols = dynamicNames.map { name => + inputPartCols.find(_.name == name).getOrElse( + throw new AnalysisException(s"Cannot find partition column $name")) + } + + dataColumns ++ dynamicPartCols + case _ => data.output + } + } + } + /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. 
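A minimal, hypothetical usage sketch (not part of the patch) of how the new ResolveOutputColumns rule is meant to be driven from the user side, assuming the byName option that the DataFrameWriter changes later in this patch introduce; the table and column names are illustrative only:

import org.apache.spark.sql.SparkSession

object ByNameInsertSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("by-name-insert-sketch")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    // Hypothetical target table; its column order is (id, data).
    spark.sql("CREATE TABLE IF NOT EXISTS target (id BIGINT, data STRING)")

    // Input columns are deliberately in a different order than the table schema.
    val df = Seq(("a", 1L), ("b", 2L)).toDF("data", "id")

    // Positional insert: castAndRenameOutputColumns matches columns by position, so this
    // logs the reordered-column-names warning and then casts `data` into the `id` column.
    // df.write.insertInto("target")

    // By-name insert: projectAndCastOutputColumns resolves each table column in the input
    // by name and adds an UpCast where the data types differ.
    df.write.byName.insertInto("target")

    spark.stop()
  }
}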
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 3fdd411ac4cc..3fbb9bb416ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -118,6 +118,16 @@ case class CatalogTable( def partitionColumns: Seq[CatalogColumn] = schema.filter { c => partitionColumnNames.contains(c.name) } + /** Columns this table is bucketed by. */ + private[sql] val bucketColumns: Seq[CatalogColumn] = bucketColumnNames.flatMap { name => + schema.find(_.name == name) + } + + /** Columns this table is sorted by. */ + private[sql] val sortColumns: Seq[CatalogColumn] = sortColumnNames.flatMap { name => + schema.find(_.name == name) + } + + /** Return the database this table was specified to belong to, assuming it exists. */ def database: String = identifier.database.getOrElse { throw new AnalysisException(s"table $identifier did not specify database") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 2ca990d19a2c..a5c2fd038ace 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, overwrite, ifNotExists = false, Map.empty) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 42a8be6b1b1e..ab562e426dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator @@ -63,6 +64,21 @@ case class SortOrder(child: Expression, direction: SortDirection) def isAscending: Boolean = direction == Ascending } +// TODO: should this be an implicit class somewhere? 
+object SortOrder { + def satisfies(order: Seq[SortOrder], distribution: Distribution): Boolean = { + distribution match { + case c @ ClusteredDistribution(exprs) => + // Zip discards extra order by expressions + (order.size >= exprs.size) && exprs.zip(order.map(_.child)).forall { + case (clusterExpr, orderExpr) => clusterExpr.semanticEquals(orderExpr) + case _ => false + } + case _ => false + } + } +} + /** * An expression to generate a 64-bit long prefix used in sorting. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6825b65e2b28..9db70012d434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,13 +23,14 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -72,6 +73,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: + Batch("Output Layout Optimizations", Once, + DistributeAndSortOutputData(conf)) :: Batch("Operator Optimizations", fixedPoint, // Operator push down SetOperationPushDown, @@ -1737,3 +1740,82 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { } } } + +case class DistributeAndSortOutputData(conf: CatalystConf) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _) + if insertInto.resolved && insertInto.writersPerPartition.isDefined => + insertInto.copy(child = + buildRepartitionAndSort(rel.catalogTable, data, insertInto.writersPerPartition)) + + case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _) + if insertInto.resolved && requiresSort(rel.catalogTable) => + insertInto.copy(child = buildSort(rel.catalogTable, data)) + + case insertInto @ InsertIntoTable(rel: CatalogRelation, partition, data, _, _, _) + if insertInto.resolved && isColumnar(rel.catalogTable) && shouldRepartition(data) => + insertInto.copy(child = buildRepartitionAndSort(rel.catalogTable, data, None)) + } + + private def isColumnar(table: CatalogTable): Boolean = { + table.storage.serde.map(_.toLowerCase) + .forall(serde => serde.contains("parquet") || serde.contains("orc")) + } + + private def shouldRepartition(plan: LogicalPlan): Boolean = { + // automatically add repartitioning for columnar formats if enabled and doesn't conflict + conf.repartitionColumnarData && !hasSortOrRepartition(plan); + } + + private def hasSortOrRepartition(plan: LogicalPlan): 
Boolean = { + plan.collectFirst { + case _: RepartitionByExpression => true + case _: Sort => true + }.getOrElse(false) + } + + private def requiresSort(table: CatalogTable): Boolean = { + (table.bucketColumnNames.size + table.sortColumnNames.size) > 0 + } + + private def buildSort(table: CatalogTable, data: LogicalPlan): LogicalPlan = { + val partitionExprs = asExpr(table.partitionColumns, data) + val bucketExpr = asBucketExpr(table.bucketColumns, table.numBuckets, data) + val sortExprs = partitionExprs ++ bucketExpr ++ asExpr(table.sortColumns, data) + // add a sort without a repartition + Sort(sortExprs.map(expr => SortOrder(expr, Ascending)), global = false, data) + } + + private def buildRepartitionAndSort( + table: CatalogTable, + data: LogicalPlan, + numWriters: Option[Int]): LogicalPlan = { + val partitionExprs = asExpr(table.partitionColumns, data) ++ asDistributeExpr(numWriters) + val bucketExpr = asBucketExpr(table.bucketColumns, table.numBuckets, data) + val sortExprs = partitionExprs ++ bucketExpr ++ asExpr(table.sortColumns, data) + + // add a sort with an inner repartition + Sort( + sortExprs.map(expr => SortOrder(expr, Ascending)), + global = false, + RepartitionByExpression(partitionExprs, data, None)) + } + + private def asExpr(columns: Seq[CatalogColumn], data: LogicalPlan): Seq[Attribute] = { + columns.map(col => data.output.find(_.name == col.name).get) + } + + private def asDistributeExpr(numWriters: Option[Int]): Option[Expression] = { + numWriters.map(n => Pmod(Cast(Multiply(Rand(0L), Literal(n)), IntegerType), Literal(n))) + } + + private def asBucketExpr(columns: Seq[CatalogColumn], numBuckets: Int, + data: LogicalPlan): Option[Expression] = { + if (columns.isEmpty) { + None + } else { + Some(HashPartitioning(asExpr(columns, data), numBuckets).partitionIdExpression) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2d7d0f903295..4d2ee6f25a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -211,8 +211,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { UnresolvedRelation(tableIdent, None), partitionKeys, query, - ctx.OVERWRITE != null, - ctx.EXISTS != null) + overwrite = ctx.OVERWRITE != null, + ifNotExists = ctx.EXISTS != null, + Map.empty /* SQL always matches by position */) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 732b0d7919c3..f39abfe19145 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -348,28 +348,45 @@ case class InsertIntoTable( partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifNotExists: Boolean, + options: Map[String, String]) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty + private[spark] def isMatchByName: Boolean = { + options.get("matchByName").map(_.toBoolean).getOrElse(false) + } + + private[spark] def writersPerPartition: Option[Int] = { + 
options.get("writersPerPartition").map(_.toInt) + } + private[spark] lazy val expectedColumns = { if (table.output.isEmpty) { None } else { - val numDynamicPartitions = partition.values.count(_.isEmpty) + val dynamicPartitionNames = partition.filter { + case (name, Some(_)) => false + case (name, None) => true + }.keySet val (partitionColumns, dataColumns) = table.output .partition(a => partition.keySet.contains(a.name)) - Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions)) + Some(dataColumns ++ partitionColumns.filter(col => dynamicPartitionNames.contains(col.name))) } } assert(overwrite || !ifNotExists) - override lazy val resolved: Boolean = childrenResolved && expectedColumns.forall { expected => - child.output.size == expected.size && child.output.zip(expected).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) + override lazy val resolved: Boolean = childrenResolved && { + expectedColumns match { + case Some(expected) => + child.output.size == expected.size && child.output.zip(expected).forall { + case (childAttr, tableAttr) => + childAttr.name == tableAttr.name && // required by some relations + DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) + } + case None => true } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 25d87d93bec4..015b8b413f64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -191,7 +191,7 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists, Map.empty) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -209,9 +209,11 @@ class PlanParserSuite extends PlanTest { val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false, + Map.empty).union( InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false, + Map.empty))) } test("aggregation") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 6f5fb69ea377..39ac9b8cad9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -224,6 +224,20 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Hint to distribute the output for each partition randomly across `numWriters` tasks. + * + * This is applicable for Parquet, JSON, ORC, and Hive. 
+ * + * @param numWriters number of writers to use for each partition + * @return this DataFrameWriter for method chaining + * @since 2.0 + */ + def writersPerPartition(numWriters: Int): DataFrameWriter = { + option("writersPerPartition", numWriters) + this + } + /** * Saves the content of the [[DataFrame]] at the specified path. * @@ -241,6 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ def save(): Unit = { assertNotBucketed() + assertNotSorted() assertNotStreaming("save() can only be called on non-continuous queries") val dataSource = DataSource( df.sparkSession, @@ -290,6 +305,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { @Experimental def startStream(): ContinuousQuery = { assertNotBucketed() + assertNotSorted() assertStreaming("startStream() can only be called on continuous queries") if (source == "memory") { @@ -408,7 +424,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { partitions.getOrElse(Map.empty[String, Option[String]]), input, overwrite, - ifNotExists = false)).toRdd + ifNotExists = false, + options = extraOptions.toMap)).toRdd } private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => @@ -458,12 +475,24 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def assertNotBucketed(): Unit = { - if (numBuckets.isDefined || sortColumnNames.isDefined) { + if (bucketColumnNames.isDefined) { throw new IllegalArgumentException( "Currently we don't support writing bucketed data to this data source.") } } + private def assertNotSorted(): Unit = { + if (sortColumnNames.isDefined) { + throw new IllegalArgumentException( + "Currently we don't support writing sorted data to this data source.") + } + } + + def byName: DataFrameWriter = { + extraOptions.put("matchByName", "true") + this + } + /** * Saves the content of the [[DataFrame]] as the specified table. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 0494fafb0e42..e2d9cc95515b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -46,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false) + l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false, _) if query.resolved && t.schema.asNullable == query.schema.asNullable => // Sanity checks @@ -110,7 +110,7 @@ private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[ } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) + case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _) if DDLUtils.isDatasourceTable(s.metadata) => i.copy(table = readDataSourceTable(sparkSession, s.metadata)) @@ -152,7 +152,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), - part, query, overwrite, false) if part.isEmpty => + part, query, overwrite, false, _) if part.isEmpty => ExecutedCommandExec(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 4921e4ca6bb7..bfad55936e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -130,9 +130,10 @@ private[sql] case class InsertIntoHadoopFsRelation( partitionColumns = partitionColumns, dataColumns = dataColumns, inputSchema = query.output, - PartitioningUtils.DEFAULT_PARTITION_NAME, - sparkSession.conf.get(SQLConf.PARTITION_MAX_FILES), - isAppend) + defaultPartitionName = PartitioningUtils.DEFAULT_PARTITION_NAME, + incomingOrder = queryExecution.sparkPlan.outputOrdering, + maxOpenFiles = sparkSession.conf.get(SQLConf.PARTITION_MAX_FILES), + isAppend = isAppend) } // This call shouldn't be put into the `try` block below because it only initializes and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 7e12bbb2128b..06908dd3a6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import 
org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.internal.SQLConf @@ -300,6 +300,7 @@ private[sql] class DynamicPartitionWriterContainer( dataColumns: Seq[Attribute], inputSchema: Seq[Attribute], defaultPartitionName: String, + incomingOrder: Seq[SortOrder], maxOpenFiles: Int, isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { @@ -371,13 +372,9 @@ private[sql] class DynamicPartitionWriterContainer( // We should first sort by partition columns, then bucket id, and finally sorting columns. val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. - case _ => StructField("bucketId", IntegerType, nullable = false) - }) + // If the data is already sorted correctly, avoid sorting it again + val isSorted = SortOrder.satisfies(incomingOrder, ClusteredDistribution(sortingExpressions)) // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) @@ -386,53 +383,15 @@ private[sql] class DynamicPartitionWriterContainer( val getPartitionString = UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - // Sorts the data before write, so that we only need one writer at the same time. - // TODO: inject a local sort operator in planning. - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) - }) - } - - val sortedIterator = sorter.sortedIterator() - // If anything below fails, we should abort the task. var currentWriter: OutputWriter = null try { Utils.tryWithSafeFinallyAndFailureCallbacks { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") - - currentWriter = newOutputWriter(currentKey, getPartitionString) - } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null + if (isSorted) { + sortedWrite(iterator) + } else { + // TODO: inject a local sort operator in planning. 
+ sortAndWrite(iterator) } commitTask() @@ -446,6 +405,86 @@ private[sql] class DynamicPartitionWriterContainer( case t: Throwable => throw new SparkException("Task failed while writing rows", t) } + + def sortAndWrite(iterator: Iterator[InternalRow]): Unit = { + val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) + + // Sorts the data before write, so that we only need one writer at the same time. + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity + } else { + UnsafeProjection.create( + sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) + } + + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. + case _ => StructField("bucketId", IntegerType, nullable = false) + }) + + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val sortedIterator = sorter.sortedIterator() + + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + } + + def sortedWrite(iterator: Iterator[InternalRow]): Unit = { + val getBucketingKey: InternalRow => InternalRow = + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length), inputSchema) + var currentKey: UnsafeRow = null + while (iterator.hasNext) { + val currentRow = iterator.next() + val nextKey = getBucketingKey(currentRow).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + currentWriter.writeInternal(getOutputRow(currentRow)) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index b622f859413a..d6e8bcff303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -49,55 +49,6 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo } } -/** - * A rule to do pre-insert data type casting and field renaming. 
Before we insert into - * an [[InsertableRelation]], we will use this rule to make sure that - * the columns to be inserted have the correct data type and fields have the correct names. - */ -private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - // We are inserting into an InsertableRelation or HadoopFsRelation. - case i @ InsertIntoTable( - l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) => - // First, make sure the data to be inserted have the same number of fields with the - // schema of the relation. - if (l.output.size != child.output.size) { - sys.error( - s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " + - s"statement generates the same number of columns as its schema.") - } - castAndRenameChildOutput(i, l.output, child) - } - - /** If necessary, cast data types and rename fields to the expected types and names. */ - def castAndRenameChildOutput( - insertInto: InsertIntoTable, - expectedOutput: Seq[Attribute], - child: LogicalPlan): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(child.output).map { - case (expected, actual) => - val needCast = !expected.dataType.sameType(actual.dataType) - // We want to make sure the filed names in the data to be inserted exactly match - // names in the schema. - val needRename = expected.name != actual.name - (needCast, needRename) match { - case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)() - case (false, true) => Alias(actual, expected.name)() - case (_, _) => actual - } - } - - if (newChildOutput == child.output) { - insertInto - } else { - insertInto.copy(child = Project(newChildOutput, child)) - } - } -} - /** * A rule to do various checks before inserting into or writing to a data source table. */ @@ -110,7 +61,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) plan.foreach { case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation, _, _), - partition, query, overwrite, ifNotExists) => + partition, query, overwrite, ifNotExists, _) => // Right now, we do not support insert into a data source table with partition specs. if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") @@ -128,7 +79,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) } case logical.InsertIntoTable( - LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _) => + LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _, _) => // We need to make sure the partition columns specified by users do match partition // columns of the relation. val existingPartitionColumns = r.partitionSchema.fieldNames.toSet @@ -156,11 +107,11 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // OK } - case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => + case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _, _) => // The relation in l is not an InsertableRelation. 
failAnalysis(s"$l does not allow insertion.") - case logical.InsertIntoTable(t, _, _, _, _) => + case logical.InsertIntoTable(t, _, _, _, _, _) => if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) { failAnalysis(s"Inserting into an RDD-based table is not allowed.") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 518430f16d71..cfb18a6eb5ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -521,6 +521,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(60 * 1000L) // 10 minutes + val REPARTITION_COLUMNAR_DATA = + SQLConfigBuilder("spark.sql.files.columnar.insertRepartition") + .internal() + .doc("Whether to automatically add a repartition step before writing columnar data " + + "formats, such as Parquet and Orc, to minimize output files and memory consumption.") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -652,6 +660,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH) + override def repartitionColumnarData: Boolean = getConf(REPARTITION_COLUMNAR_DATA) + def warehousePath: String = { getConf(WAREHOUSE_PATH).replace("${system:user.dir}", System.getProperty("user.dir")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f0b8a83dee8c..46c5f634dd11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.AnalyzeTable -import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, ResolveDataSource} import org.apache.spark.sql.util.ExecutionListenerManager @@ -109,7 +109,6 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = - PreInsertCastAndRename :: new FindDataSourceTable(sparkSession) :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 854fec5b22f7..9e07e66ac48d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -87,15 +87,15 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } test("SELECT clause generating a different number of columns is not allowed.") { - val message = intercept[RuntimeException] { + val message = intercept[AnalysisException] { sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt """.stripMargin) }.getMessage assert( - message.contains("generates the same 
number of columns as its schema"), - "SELECT clause generating a different number of columns should not be not allowed." + message.contains("Not enough data columns to write"), + "SELECT clause must generate all of a table's columns to write" ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index cb53b2b1aac1..9754bd2a7b66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -432,7 +432,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B .stream() val w = df.write val e = intercept[IllegalArgumentException](w.sortBy("text").startStream()) - assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") + assert(e.getMessage == "Currently we don't support writing sorted data to this data source.") } test("check save(path) can only be called on non-continuous queries") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4c528fbbbeef..c9c025c2b879 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.command.CreateTableAsSelectLogicalPlan @@ -370,16 +369,20 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, + options) // Inserting into partitioned table is not supported in Parquet data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, + options) // Write path - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoHiveTable(r: MetastoreRelation, + partition, child, overwrite, ifNotExists, options) // Inserting into partitioned table is not supported in Parquet data source (yet). 
if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, + options) // Read path case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => @@ -414,16 +417,20 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, + options) // Inserting into partitioned table is not supported in Orc data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, + options) // Write path - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoHiveTable(r: MetastoreRelation, + partition, child, overwrite, ifNotExists, options) // Inserting into partitioned table is not supported in Orc data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, + options) // Read path case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => @@ -492,49 +499,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } } - - /** - * Casts input data to correct data types according to table definition before inserting into - * that table. - */ - object PreInsertionCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => - castChildOutput(p, table, child) - } - - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) - : LogicalPlan = { - val childOutputDataTypes = child.output.map(_.dataType) - val numDynamicPartitions = p.partition.values.count(_.isEmpty) - val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) - .take(child.output.length).map(_.dataType) - - if (childOutputDataTypes == tableOutputDataTypes) { - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else if (childOutputDataTypes.size == tableOutputDataTypes.size && - childOutputDataTypes.zip(tableOutputDataTypes) - .forall { case (left, right) => left.sameType(right) }) { - // If both types ignoring nullability of ArrayType, MapType, StructType are the same, - // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else { - // Only do the casting when child output data types differ from table output data types. 
- val castedChildOutput = child.output.zip(table.output).map { - case (input, output) if input.dataType != output.dataType => - Alias(Cast(input, output.dataType), input.name)() - case (input, _) => input - } - - p.copy(child = logical.Project(castedChildOutput, child)) - } - } - } - } /** @@ -577,7 +541,8 @@ private[hive] case class InsertIntoHiveTable( partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifNotExists: Boolean, + options: Map[String, String]) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 4f8aac8c2fcd..2f6a2207855e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -87,7 +87,6 @@ private[sql] class HiveSessionCatalog( val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables - val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts override def refreshTable(name: TableIdentifier): Unit = { metastoreCatalog.refreshTable(name) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 46579ecd85ca..9ee029100f5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -65,8 +65,6 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) catalog.ParquetConversions :: catalog.OrcConversions :: catalog.CreateTables :: - catalog.PreInsertionCasts :: - PreInsertCastAndRename :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 71b180e55b58..7d1daa496f09 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -43,11 +43,11 @@ private[hive] trait HiveStrategies { object DataSinks extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable( - table: MetastoreRelation, partition, child, overwrite, ifNotExists) => + table: MetastoreRelation, partition, child, overwrite, ifNotExists, _) => execution.InsertIntoHiveTable( table, partition, planLater(child), overwrite, ifNotExists) :: Nil case hive.InsertIntoHiveTable( - table: MetastoreRelation, partition, child, overwrite, ifNotExists) => + table: MetastoreRelation, partition, child, overwrite, ifNotExists, _) => execution.InsertIntoHiveTable( table, partition, planLater(child), overwrite, ifNotExists) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 9dfbafae872f..6b298cc15a64 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -86,8 +86,8 @@ case class CreateTableAsSelect( throw new AnalysisException(s"$tableIdentifier already exists.") } } else { - sparkSession.executePlan(InsertIntoTable( - metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd + sparkSession.executePlan(InsertIntoTable(metastoreRelation, Map(), query, + overwrite = true, ifNotExists = false, Map.empty)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3805674d3958..1a942c543eee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -34,7 +34,8 @@ import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -218,10 +219,14 @@ case class InsertIntoHiveTable( val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) + val partitionClustering = ClusteredDistribution(child.output.takeRight(numDynamicPartitions)) + val isSorted = SortOrder.satisfies(child.outputOrdering, partitionClustering) + new SparkHiveDynamicPartitionWriterContainer( jobConf, fileSinkConf, dynamicPartColNames, + isSorted, child.output, table) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 706fdbc2604f..556d95a60375 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -218,6 +218,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String], + isSorted: Boolean, inputSchema: Seq[Attribute], table: MetastoreRelation) extends SparkHiveWriterContainer(jobConf, fileSinkConf, inputSchema, table) { @@ -277,6 +278,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { + if (isSorted) { + sortedWrite(iterator) + } else { + sortAndWrite(iterator) + } + commit() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + def sortAndWrite(iterator: Iterator[InternalRow]): Unit = { val sorter: UnsafeKVExternalSorter = new UnsafeKVExternalSorter( StructType.fromAttributes(partitionOutput), StructType.fromAttributes(dataOutput), @@ -291,6 +306,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } logInfo(s"Sorting complete. 
Writing out partition files one at a time.") + val sortedIterator = sorter.sortedIterator() var currentKey: InternalRow = null var currentWriter: FileSinkOperator.RecordWriter = null @@ -305,29 +321,53 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( currentWriter = newOutputWriter(currentKey) } - var i = 0 - while (i < fieldOIs.length) { - outputData(i) = if (sortedIterator.getValue.isNullAt(i)) { - null - } else { - wrappers(i)(sortedIterator.getValue.get(i, dataTypes(i))) + currentWriter.write(serialize(sortedIterator.getValue)) + } + } finally { + if (currentWriter != null) { + currentWriter.close(false) + } + } + } + + def sortedWrite(iterator: Iterator[InternalRow]): Unit = { + var currentKey: InternalRow = null + var currentWriter: FileSinkOperator.RecordWriter = null + try { + while (iterator.hasNext) { + val inputRow = iterator.next() + val rowKey = getPartitionKey(inputRow) + if (currentKey != rowKey) { + if (currentWriter != null) { + currentWriter.close(false) } - i += 1 + currentKey = rowKey + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey) } - currentWriter.write(serializer.serialize(outputData, standardOI)) + + currentWriter.write(serialize(getOutputRow(inputRow))) } } finally { if (currentWriter != null) { currentWriter.close(false) } } - commit() - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - abortTask() - throw new SparkException("Task failed while writing rows.", cause) } + + def serialize(data: InternalRow) = { + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (data.isNullAt(i)) { + null + } else { + wrappers(i)(data.get(i, dataTypes(i))) + } + i += 1 + } + serializer.serialize(outputData, standardOI) + } + /** Open and returns a new OutputWriter given a partition key. 
*/ def newOutputWriter(key: InternalRow): FileSinkOperator.RecordWriter = { val partitionPath = getPartitionString(key).getString(0) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index b25684562075..0a17df9541ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -313,8 +313,128 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, - Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) + Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false, + Map("matchByName" -> "true")) assert(!logical.resolved, "Should not resolve: missing partition data") } } + + test("Insert unnamed expressions by position") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val expected = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + val data = expected.select("id", "part") + + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + // should be able to insert an expression when NOT mapping columns by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id)") + .write.insertInto("partitioned") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Insert expression by name") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val expected = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + val data = expected.select("id", "part") + + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + intercept[AnalysisException] { + // also a problem when mapping by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id)") + .write.byName.insertInto("partitioned") + } + + // should be able to insert an expression using AS when mapping columns by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id) as data") + .write.byName.insertInto("partitioned") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Reject missing columns") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + intercept[AnalysisException] { + spark.table("source").write.insertInto("partitioned") + } + + intercept[AnalysisException] { + // also a problem when mapping by name + spark.table("source").write.byName.insertInto("partitioned") + } + } + } + + test("Reject extra columns") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, data string, extra string, part string)") + sql("CREATE TABLE partitioned (id bigint, 
data string) PARTITIONED BY (part string)") + + intercept[AnalysisException] { + spark.table("source").write.insertInto("partitioned") + } + + val data = (1 to 10) + .map(i => (i, s"data-$i", s"${i * i}", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "extra", "part") + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + spark.table("source").write.byName.insertInto("partitioned") + + val expected = data.select("id", "data", "part") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Ignore names when writing by position") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string, data string)") // part, data transposed + sql("CREATE TABLE destination (id bigint, data string, part string)") + + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + + // write into the reordered table by name + data.write.byName.insertInto("source") + checkAnswer(sql("SELECT id, data, part FROM source"), data.collect().toSeq) + + val expected = data.select($"id", $"part" as "data", $"data" as "part") + + // this produces a warning, but writes src.part -> dest.data and src.data -> dest.part + spark.table("source").write.insertInto("destination") + checkAnswer(sql("SELECT id, data, part FROM destination"), expected.collect().toSeq) + } + } + + test("Reorder columns by name") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (data string, part string, id bigint)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val data = (1 to 10).map(i => (s"data-$i", if ((i % 2) == 0) "even" else "odd", i)) + .toDF("data", "part", "id") + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + spark.table("source").write.byName.insertInto("partitioned") + + val expected = data.select("id", "data", "part") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index b12f3aafefb8..319e4e15bf62 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -348,6 +348,7 @@ abstract class HiveComparisonTest val containsCommands = originalQuery.analyzed.collectFirst { case _: Command => () case _: LogicalInsertIntoHiveTable => () + case _: InsertIntoTable => () }.nonEmpty if (containsCommands) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2aaaaadb6afa..690538b1f31a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1039,7 +1039,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SET hive.exec.dynamic.partition.mode=nonstrict") sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") - sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value, ds, hr 
FROM srcpart") .queryExecution.analyzed } @@ -1050,6 +1050,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + test("SPARK-14543: AnalysisException for missing partition columns") { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + + intercept[AnalysisException] { + // src doesn't have ds and hr partition columns + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + intercept[AnalysisException] { + // ds and hr partition columns aren't selected + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM srcpart") + .queryExecution.analyzed + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly"