diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 5abc6f3ed5769..8e2afbcff4252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.StructType /** * Performs (external) sorting. @@ -71,36 +72,8 @@ case class SortExec( * should make it public. */ def createSorter(): UnsafeExternalRowSorter = { - val ordering = RowOrdering.create(sortOrder, output) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - val canUseRadixSort = enableRadixSort && sortOrder.length == 1 && - SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) - - // The generator for prefix - val prefixExpr = SortPrefix(boundSortExpression) - val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix - override def computePrefix(row: InternalRow): - UnsafeExternalRowSorter.PrefixComputer.Prefix = { - val prefix = prefixProjection.apply(row) - result.isNull = prefix.isNullAt(0) - result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) - result - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - rowSorter = UnsafeExternalRowSorter.create( - schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) - - if (testSpillFrequency > 0) { - rowSorter.setTestSpillFrequency(testSpillFrequency) - } + rowSorter = SortExec.createSorter( + sortOrder, output, schema, enableRadixSort, testSpillFrequency) rowSorter } @@ -206,3 +179,43 @@ case class SortExec( override protected def withNewChildInternal(newChild: SparkPlan): SortExec = copy(child = newChild) } +object SortExec { + def createSorter( + sortOrder: Seq[SortOrder], + output: Seq[Attribute], + schema: StructType, + enableRadixSort: Boolean, + testSpillFrequency: Int = 0): UnsafeExternalRowSorter = { + val ordering = RowOrdering.create(sortOrder, output) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + val canUseRadixSort = enableRadixSort && sortOrder.length == 1 && + SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) + + // The generator for prefix + val prefixExpr = SortPrefix(boundSortExpression) + val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + val prefix = prefixProjection.apply(row) + result.isNull = prefix.isNullAt(0) + result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) + result + } + } + + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes 
+ val rowSorter = UnsafeExternalRowSorter.create( + schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) + + if (testSpillFrequency > 0) { + rowSorter.setTestSpillFrequency(testSpillFrequency) + } + rowSorter + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index bfe4bd2924118..62539c5463316 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -23,8 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.SchemaPruning +import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes} import org.apache.spark.sql.execution.datasources.v2.{V2ScanPartitioning, V2ScanRelationPushDown, V2Writes} import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} @@ -38,6 +37,7 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst Seq(SchemaPruning) :+ + V1Writes :+ V2ScanRelationPushDown :+ V2ScanPartitioning :+ V2Writes :+ @@ -78,6 +78,7 @@ class SparkOptimizer( ExtractPythonUDFFromJoinCondition.ruleName :+ ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+ ExtractPythonUDFs.ruleName :+ + V1Writes.ruleName :+ V2ScanRelationPushDown.ruleName :+ V2ScanPartitioning.ruleName :+ V2Writes.ruleName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index e64426f8de8f3..7e37727e30f41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -21,6 +21,7 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -141,7 +142,18 @@ case class CreateDataSourceTableAsSelectCommand( mode: SaveMode, query: LogicalPlan, outputColumnNames: Seq[String]) - extends DataWritingCommand { + extends V1Write { + + override lazy val partitionColumns: Seq[Attribute] = { + table.partitionColumnNames.map { name => + query.resolve(name :: Nil, SparkSession.active.sessionState.analyzer.resolver).getOrElse { + throw QueryCompilationErrors.cannotResolveAttributeError( + name, query.output.map(_.name).mkString(", ")) + }.asInstanceOf[Attribute] + } + } + override lazy val bucketSpec: Option[BucketSpec] = table.bucketSpec + override lazy val options: Map[String, String] = table.storage.properties override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { 
assert(table.tableType != CatalogTableType.VIEW) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 643902e7cbcb2..369f9cf97630c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -34,9 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter} @@ -47,7 +45,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} /** A helper object for writing FileFormat data out to a location. */ -object FileFormatWriter extends Logging { +object FileFormatWriter extends Logging with V1WritesHelper { /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( outputPath: String, @@ -78,6 +76,7 @@ object FileFormatWriter extends Logging { maxWriters: Int, createSorter: () => UnsafeExternalRowSorter) + // scalastyle:off argcount /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -100,6 +99,7 @@ object FileFormatWriter extends Logging { outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], + staticPartitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], statsTrackers: Seq[WriteJobStatsTracker], options: Map[String, String]) @@ -126,39 +126,7 @@ object FileFormatWriter extends Logging { } val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan - val writerBucketSpec = bucketSpec.map { spec => - val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) - - if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") == - "true") { - // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression. - // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of - // columns is negative. See Hive implementation in - // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. - val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) - val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets)) - - // The bucket file name prefix is following Hive, Presto and Trino conversion, so this - // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. - // - // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. - // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. 
- val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" - WriterBucketSpec(bucketIdExpression, fileNamePrefix) - } else { - // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id - // expression, so that we can guarantee the data distribution is same between shuffle and - // bucketed data source, which enables us to only shuffle one side when join a bucketed - // table and a normal one. - val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) - .partitionIdExpression - WriterBucketSpec(bucketIdExpression, (_: Int) => "") - } - } - val sortColumns = bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) - } - + val writerBucketSpec = getWriterBucketSpec(bucketSpec, dataColumns, options) val caseInsensitiveOptions = CaseInsensitiveMap(options) val dataSchema = dataColumns.toStructType @@ -184,20 +152,6 @@ object FileFormatWriter extends Logging { statsTrackers = statsTrackers ) - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = - partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns - // the sort order doesn't matter - val actualOrdering = empty2NullPlan.outputOrdering.map(_.child) - val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { - false - } else { - requiredOrdering.zip(actualOrdering).forall { - case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) - } - } - SQLExecution.checkSQLExecutionId(sparkSession) // propagate the description UUID into the jobs, so that committers @@ -208,29 +162,26 @@ object FileFormatWriter extends Logging { // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. committer.setupJob(job) + val sortColumns = getBucketSortColumns(bucketSpec, dataColumns) try { - val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { - (empty2NullPlan.execute(), None) + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + val concurrentOutputWriterSpec = if (concurrentWritersEnabled) { + val output = empty2NullPlan.output + val enableRadixSort = sparkSession.sessionState.conf.enableRadixSort + val outputSchema = empty2NullPlan.schema + Some(ConcurrentOutputWriterSpec(maxWriters, + () => SortExec.createSorter( + getSortOrder(output, partitionColumns, staticPartitionColumns.size, + bucketSpec, options), + output, + outputSchema, + enableRadixSort + ))) } else { - // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and - // the physical plan may have different attribute ids due to optimizer removing some - // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
- val orderingExpr = bindReferences( - requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns) - val sortPlan = SortExec( - orderingExpr, - global = false, - child = empty2NullPlan) - - val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters - val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty - if (concurrentWritersEnabled) { - (empty2NullPlan.execute(), - Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter()))) - } else { - (sortPlan.execute(), None) - } + None } + val rdd = empty2NullPlan.execute() // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single // partition rdd to make sure we at least set up one write task to write the metadata. @@ -278,6 +229,7 @@ object FileFormatWriter extends Logging { throw QueryExecutionErrors.jobAbortedError(cause) } } + // scalastyle:on argcount /** Writes data out in a single Spark task. */ private def executeTask( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 74be483cd7c37..cc4f5f289bef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -48,16 +48,16 @@ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, staticPartitions: TablePartitionSpec, ifPartitionNotExists: Boolean, - partitionColumns: Seq[Attribute], - bucketSpec: Option[BucketSpec], + override val partitionColumns: Seq[Attribute], + override val bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - options: Map[String, String], + override val options: Map[String, String], query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex], outputColumnNames: Seq[String]) - extends DataWritingCommand { + extends V1Write { private lazy val parameters = CaseInsensitiveMap(options) @@ -74,6 +74,8 @@ case class InsertIntoHadoopFsRelationCommand( staticPartitions.size < partitionColumns.length } + override lazy val numStaticPartitions: Int = staticPartitions.size + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkColumnNameDuplication( @@ -181,6 +183,7 @@ case class InsertIntoHadoopFsRelationCommand( committerOutputPath.toString, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionColumns, + staticPartitionColumns = partitionColumns.take(staticPartitions.size), bucketSpec = bucketSpec, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), options = options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala new file mode 100644 index 0000000000000..24265cde98513 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, BitwiseAnd, HiveHash, Literal, Pmod, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.command.DataWritingCommand
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * A V1 write command, covering both data source v1 and Hive writes, which requires a specific
+ * ordering of the data. The required ordering is enforced by the [[V1Writes]] rule.
+ *
+ * TODO(SPARK-37333): Specify the required distribution at V1Write
+ */
+trait V1Write extends DataWritingCommand with V1WritesHelper {
+  def partitionColumns: Seq[Attribute] = Seq.empty
+  def numStaticPartitions: Int = 0
+  def bucketSpec: Option[BucketSpec] = None
+  def options: Map[String, String] = Map.empty
+
+  final def requiredOrdering: Seq[SortOrder] = {
+    getSortOrder(
+      outputColumns,
+      partitionColumns,
+      numStaticPartitions,
+      bucketSpec,
+      options)
+  }
+}
+
+/**
+ * A rule that enforces the ordering required by a [[V1Write]] by adding a [[Sort]] if needed.
+ */
+object V1Writes extends Rule[LogicalPlan] with V1WritesHelper {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case write: V1Write =>
+      val partitionSet = AttributeSet(write.partitionColumns)
+      val dataColumns = write.outputColumns.filterNot(partitionSet.contains)
+      val sortColumns = getBucketSortColumns(write.bucketSpec, dataColumns)
+      val newQuery = prepareQuery(write.query, write.requiredOrdering, sortColumns)
+      write.withNewChildren(newQuery :: Nil)
+
+    case _ => plan
+  }
+}
+
+trait V1WritesHelper {
+
+  def getWriterBucketSpec(
+      bucketSpec: Option[BucketSpec],
+      dataColumns: Seq[Attribute],
+      options: Map[String, String]): Option[WriterBucketSpec] = {
+    bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
+        "true") {
+        // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
+        // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
+        // columns is negative. See Hive implementation in
+        // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
+        val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
+        val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))
+
+        // The bucket file name prefix follows the Hive, Presto and Trino convention, so this
+        // makes sure a Hive bucketed table written by Spark can be read by other SQL engines.
+        //
+        // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
+        // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
+ val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" + WriterBucketSpec(bucketIdExpression, fileNamePrefix) + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. + val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) + .partitionIdExpression + WriterBucketSpec(bucketIdExpression, (_: Int) => "") + } + } + } + + def getBucketSortColumns( + bucketSpec: Option[BucketSpec], dataColumns: Seq[Attribute]): Seq[Attribute] = { + bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + } + } + + def getSortOrder( + outputColumns: Seq[Attribute], + partitionColumns: Seq[Attribute], + numStaticPartitions: Int, + bucketSpec: Option[BucketSpec], + options: Map[String, String]): Seq[SortOrder] = { + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = outputColumns.filterNot(partitionSet.contains) + val writerBucketSpec = getWriterBucketSpec(bucketSpec, dataColumns, options) + val sortColumns = getBucketSortColumns(bucketSpec, dataColumns) + + assert(partitionColumns.size >= numStaticPartitions) + // We should first sort by partition columns, then bucket id, and finally sorting columns. + (partitionColumns.takeRight(partitionColumns.size - numStaticPartitions) ++ + writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns) + .map(SortOrder(_, Ascending)) + } + + def prepareQuery( + query: LogicalPlan, + requiredOrdering: Seq[SortOrder], + sortColumns: Seq[Attribute]): LogicalPlan = { + val actualOrdering = query.outputOrdering + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + + if (orderingMatched || + (SQLConf.get.maxConcurrentOutputFileWriters > 0 && sortColumns.isEmpty)) { + query + } else { + Sort(requiredOrdering, false, query) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 5058a1dfc3baf..91dff9b4339d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -176,6 +176,7 @@ class FileStreamSink( outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output), hadoopConf = hadoopConf, partitionColumns = partitionColumns, + staticPartitionColumns = Seq.empty, bucketSpec = None, statsTrackers = Seq(basicWriteJobStatsTracker), options = options) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 96b41dd8e35fa..05e01080b23be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,22 +20,39 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{Row, SaveMode, SparkSession} -import 
org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, SessionCatalog}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1Write}
 import org.apache.spark.sql.hive.HiveSessionCatalog
 import org.apache.spark.util.Utils
 
-trait CreateHiveTableAsSelectBase extends DataWritingCommand {
+trait CreateHiveTableAsSelectBase extends V1Write with V1HiveWritesHelper {
   val tableDesc: CatalogTable
   val query: LogicalPlan
   val outputColumnNames: Seq[String]
   val mode: SaveMode
 
+  override lazy val partitionColumns: Seq[Attribute] = {
+    // If the table does not exist, its schema should always be empty, so take it from the query.
+    val table = if (tableDesc.schema.isEmpty) {
+      val tableSchema = CharVarcharUtils.getRawSchema(outputColumns.toStructType, conf)
+      tableDesc.copy(schema = tableSchema)
+    } else {
+      tableDesc
+    }
+    // For CTAS, there are no static partition values to insert.
+    val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap
+    getDynamicPartitionColumns(table, query, partition)
+  }
+  override lazy val bucketSpec: Option[BucketSpec] = tableDesc.bucketSpec
+  override lazy val options: Map[String, String] =
+    getOptionsWithHiveBucketWrite(tableDesc.bucketSpec)
+
   protected val tableIdentifier = tableDesc.identifier
 
   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 8fca95130dd8f..6631065250623 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -17,23 +17,20 @@ package org.apache.spark.sql.hive.execution
 
-import java.util.Locale
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.ErrorMsg
 import org.apache.hadoop.hive.ql.plan.TableDesc
 
-import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.CommandUtils
+import org.apache.spark.sql.execution.datasources.V1Write
 import org.apache.spark.sql.hive.HiveExternalCatalog
 import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
 import org.apache.spark.sql.hive.client.HiveClientImpl
@@ -76,7 +73,13 @@ case class 
InsertIntoHiveTable( query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean, - outputColumnNames: Seq[String]) extends SaveAsHiveFile { + outputColumnNames: Seq[String]) + extends SaveAsHiveFile with V1Write with V1HiveWritesHelper { + + override lazy val partitionColumns: Seq[Attribute] = + getDynamicPartitionColumns(table, query, partition) + override lazy val bucketSpec: Option[BucketSpec] = table.bucketSpec + override lazy val options: Map[String, String] = getOptionsWithHiveBucketWrite(table.bucketSpec) /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the @@ -131,55 +134,7 @@ case class InsertIntoHiveTable( tmpLocation: Path, child: SparkPlan): Unit = { val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - - val numDynamicPartitions = partition.values.count(_.isEmpty) - val numStaticPartitions = partition.values.count(_.nonEmpty) - val partitionSpec = partition.map { - case (key, Some(null)) => key -> ExternalCatalogUtils.DEFAULT_PARTITION_NAME - case (key, Some(value)) => key -> value - case (key, None) => key -> "" - } - - // All partition column names in the format of "//..." - val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") - val partitionColumnNames = Option(partitionColumns).map(_.split("/")).getOrElse(Array.empty) - - // By this time, the partition map must match the table's partition columns - if (partitionColumnNames.toSet != partition.keySet) { - throw QueryExecutionErrors.requestedPartitionsMismatchTablePartitionsError(table, partition) - } - - // Validate partition spec if there exist any dynamic partitions - if (numDynamicPartitions > 0) { - // Report error if dynamic partitioning is not enabled - if (!hadoopConf.get("hive.exec.dynamic.partition", "true").toBoolean) { - throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) - } - - // Report error if dynamic partition strict mode is on but no static partition is found - if (numStaticPartitions == 0 && - hadoopConf.get("hive.exec.dynamic.partition.mode", "strict").equalsIgnoreCase("strict")) { - throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) - } - - // Report error if any static partition appears after a dynamic partition - val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { - throw new AnalysisException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) - } - } - - val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name => - val attr = query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse { - throw QueryCompilationErrors.cannotResolveAttributeError( - name, query.output.map(_.name).mkString(", ")) - }.asInstanceOf[Attribute] - // SPARK-28054: Hive metastore is not case preserving and keeps partition columns - // with lower cased names. Hive will validate the column names in the partition directories - // during `loadDynamicPartitions`. Spark needs to write partition directories with lower-cased - // column names in order to make `loadDynamicPartitions` work. 
- attr.withName(name.toLowerCase(Locale.ROOT)) - } + val dynamicPartitionAttributes = getDynamicPartitionColumns(table, query, partition) val writtenParts = saveAsHiveFile( sparkSession = sparkSession, @@ -187,9 +142,11 @@ case class InsertIntoHiveTable( hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, - partitionAttributes = partitionAttributes, + partitionAttributes = dynamicPartitionAttributes, bucketSpec = table.bucketSpec) + val partitionSpec = getPartitionSpec(partition) + val numDynamicPartitions = partition.values.count(_.isEmpty) if (partition.nonEmpty) { if (numDynamicPartitions > 0) { if (overwrite && table.tableType == CatalogTableType.EXTERNAL) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 7f885729bd2be..66ff1c0ab799d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -37,13 +37,13 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.DataWritingCommand -import org.apache.spark.sql.execution.datasources.{BucketingUtils, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveVersion // Base trait from which all hive insert statement physical execution extends. -private[hive] trait SaveAsHiveFile extends DataWritingCommand { +private[hive] trait SaveAsHiveFile extends DataWritingCommand with V1HiveWritesHelper { var createdTempDir: Option[Path] = None @@ -86,10 +86,6 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { jobId = java.util.UUID.randomUUID().toString, outputPath = outputLocation) - val options = bucketSpec - .map(_ => Map(BucketingUtils.optionForHiveCompatibleBucketWrite -> "true")) - .getOrElse(Map.empty) - FileFormatWriter.write( sparkSession = sparkSession, plan = plan, @@ -99,9 +95,10 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionAttributes, + staticPartitionColumns = Seq.empty, bucketSpec = bucketSpec, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), - options = options) + options = getOptionsWithHiveBucketWrite(bucketSpec)) } protected def getExternalTmpPath( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1HiveWritesHelper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1HiveWritesHelper.scala new file mode 100644 index 0000000000000..5ef5e7c04e6f3 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1HiveWritesHelper.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.util.Locale
+
+import org.apache.hadoop.hive.ql.ErrorMsg
+import org.apache.hadoop.hive.ql.plan.TableDesc
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.datasources.BucketingUtils
+import org.apache.spark.sql.hive.client.HiveClientImpl
+
+trait V1HiveWritesHelper {
+  def getOptionsWithHiveBucketWrite(bucketSpec: Option[BucketSpec]): Map[String, String] = {
+    bucketSpec
+      .map(_ => Map(BucketingUtils.optionForHiveCompatibleBucketWrite -> "true"))
+      .getOrElse(Map.empty)
+  }
+
+  def getPartitionSpec(partition: Map[String, Option[String]]): Map[String, String] = {
+    partition.map {
+      case (key, Some(null)) => key -> ExternalCatalogUtils.DEFAULT_PARTITION_NAME
+      case (key, Some(value)) => key -> value
+      case (key, None) => key -> ""
+    }
+  }
+
+  def getDynamicPartitionColumns(
+      table: CatalogTable,
+      query: LogicalPlan,
+      partition: Map[String, Option[String]]): Seq[Attribute] = {
+    val hadoopConf = SparkSession.active.sessionState.newHadoopConf()
+    val numStaticPartitions = partition.values.count(_.nonEmpty)
+    val numDynamicPartitions = partition.values.count(_.isEmpty)
+
+    val hiveQlTable = HiveClientImpl.toHiveTable(table)
+    val tableDesc = new TableDesc(
+      hiveQlTable.getInputFormatClass,
+      hiveQlTable.getOutputFormatClass,
+      hiveQlTable.getMetadata
+    )
+    // All partition column names in the format of "<column>/<column>/..."
+ val partitionColumns = tableDesc.getProperties.getProperty("partition_columns") + val partitionColumnNames = Option(partitionColumns).map(_.split("/")).getOrElse(Array.empty) + val partitionSpec = getPartitionSpec(partition) + + // By this time, the partition map must match the table's partition columns + if (partitionColumnNames.toSet != partition.keySet) { + throw QueryExecutionErrors.requestedPartitionsMismatchTablePartitionsError(table, partition) + } + + // Validate partition spec if there exist any dynamic partitions + if (numDynamicPartitions > 0) { + // Report error if dynamic partitioning is not enabled + if (!hadoopConf.get("hive.exec.dynamic.partition", "true").toBoolean) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) + } + + // Report error if dynamic partition strict mode is on but no static partition is found + if (numStaticPartitions == 0 && + hadoopConf.get("hive.exec.dynamic.partition.mode", "strict").equalsIgnoreCase("strict")) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) + } + + // Report error if any static partition appears after a dynamic partition + val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) + if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { + throw new AnalysisException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) + } + } + + partitionColumnNames.takeRight(numDynamicPartitions).map { name => + val attr = query.resolve(name :: Nil, SparkSession.active.sessionState.analyzer.resolver) + .getOrElse { + throw QueryCompilationErrors.cannotResolveAttributeError( + name, query.output.map(_.name).mkString(", ")) + }.asInstanceOf[Attribute] + // SPARK-28054: Hive metastore is not case preserving and keeps partition columns + // with lower cased names. Hive will validate the column names in the partition directories + // during `loadDynamicPartitions`. Spark needs to write partition directories with lower-cased + // column names in order to make `loadDynamicPartitions` work. + attr.withName(name.toLowerCase(Locale.ROOT)) + } + } +}
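
As a quick illustration of what the V1Writes rule changes (not part of the patch; the table and column names below are made up, and it assumes a local SparkSession), the sketch creates a partitioned data source table and runs EXPLAIN EXTENDED on an INSERT OVERWRITE. With this change the sort on the dynamic partition column is expected to appear as a logical Sort in the optimized plan, instead of being injected as a physical SortExec inside FileFormatWriter.write.

// Illustrative sketch only; hypothetical tables `src` and `t` with partition column `p`.
import org.apache.spark.sql.SparkSession

object V1WritesOrderingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("V1WritesOrderingExample")
      .getOrCreate()

    spark.sql("CREATE TABLE src (c1 INT, c2 STRING, p STRING) USING parquet")
    spark.sql("CREATE TABLE t (c1 INT, c2 STRING, p STRING) USING parquet PARTITIONED BY (p)")

    // With V1Writes registered as an early scan-pushdown rule in SparkOptimizer, the
    // required ordering on the dynamic partition column `p` should show up as a logical
    // Sort below the write command in the optimized plan printed here.
    spark.sql("EXPLAIN EXTENDED INSERT OVERWRITE TABLE t SELECT c1, c2, p FROM src")
      .show(truncate = false)

    spark.stop()
  }
}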