@@ -412,6 +412,14 @@ object SQLConf {
.longConf
.createWithDefault(67108864L)

val PLANNED_WRITE_ENABLED = buildConf("spark.sql.optimizer.plannedWrite.enabled")
.internal()
.doc("When set to true, Spark optimizer will add logical sort operators to V1 write commands " +
"if needed so that `FileFormatWriter` does not need to insert physical sorts.")
.version("3.4.0")
.booleanConf
.createWithDefault(true)

val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
"column based on statistics of the data.")
@@ -4617,6 +4625,8 @@ class SQLConf extends Serializable with Logging {

def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)

def plannedWriteEnabled: Boolean = getConf(SQLConf.PLANNED_WRITE_ENABLED)

def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)

def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
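For context, a quick way to exercise the new flag from spark-shell. This is a minimal sketch, not part of the change: the output path and column names are made up, and the key is the config added above.

```scala
// Planned writes are on by default in this change; flip the flag to compare plans.
spark.conf.set("spark.sql.optimizer.plannedWrite.enabled", "true")

// A partitioned write whose required ordering (the partition column `p`) is now
// satisfied by a logical Sort added by the optimizer instead of a physical sort
// inserted by FileFormatWriter.
spark.range(0, 100)
  .selectExpr("id", "id % 10 AS p")
  .write
  .partitionBy("p")
  .mode("overwrite")
  .parquet("/tmp/planned-write-demo")   // hypothetical path
```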
@@ -23,8 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.SchemaPruning
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -39,6 +38,7 @@ class SparkOptimizer(
// TODO: move SchemaPruning into catalyst
Seq(SchemaPruning) :+
GroupBasedRowLevelOperationScanPlanning :+
V1Writes :+
V2ScanRelationPushDown :+
V2ScanPartitioningAndOrdering :+
V2Writes :+
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.V1WriteCommand
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.internal.SQLConf
@@ -46,8 +47,9 @@ case class InsertAdaptiveSparkPlan(
case _ if !conf.adaptiveExecutionEnabled => plan
case _: ExecutedCommandExec => plan
case _: CommandResultExec => plan
case c: DataWritingCommandExec => c.copy(child = apply(c.child))
case c: V2CommandExec => c.withNewChildren(c.children.map(apply))
case c: DataWritingCommandExec if !c.cmd.isInstanceOf[V1WriteCommand] =>
c.copy(child = apply(c.child))
case _ if shouldApplyAQE(plan, isSubquery) =>
if (supportAdaptive(plan)) {
try {
@@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.command
import java.net.URI

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -141,7 +143,18 @@ case class CreateDataSourceTableAsSelectCommand(
mode: SaveMode,
query: LogicalPlan,
outputColumnNames: Seq[String])
extends DataWritingCommand {
extends V1WriteCommand {

override def requiredOrdering: Seq[SortOrder] = {
val unresolvedPartitionColumns = table.partitionColumnNames.map(UnresolvedAttribute.quoted)
val partitionColumns = DataSource.resolvePartitionColumns(
unresolvedPartitionColumns,
outputColumns,
query,
SparkSession.active.sessionState.conf.resolver)
val options = table.storage.properties
V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
}

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
assert(table.tableType != CatalogTableType.VIEW)
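To make the derived ordering concrete, here is a small sketch (not part of the change; column names and bucket spec are hypothetical) of what `V1WritesUtils.getSortOrder`, added later in this diff, returns for a partitioned and bucketed table: partition columns first, then the bucket id expression, then the bucket sort columns.

```scala
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.V1WritesUtils
import org.apache.spark.sql.types.{IntegerType, StringType}

// Hypothetical output columns of the write query: two data columns and one partition column.
val id   = AttributeReference("id", IntegerType)()
val name = AttributeReference("name", StringType)()
val p    = AttributeReference("p", StringType)()

// Hypothetical bucket spec: 4 buckets by `id`, sorted by `name` within each bucket.
val bucketSpec = Some(BucketSpec(4, bucketColumnNames = Seq("id"), sortColumnNames = Seq("name")))

val ordering = V1WritesUtils.getSortOrder(
  outputColumns = Seq(id, name, p),
  partitionColumns = Seq(p),
  bucketSpec = bucketSpec,
  options = Map.empty)

// With the default config (concurrent writers disabled), `ordering` is roughly:
//   p ASC NULLS FIRST, <bucket id expression over `id`> ASC NULLS FIRST, name ASC NULLS FIRST
ordering.map(_.sql).foreach(println)
```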
@@ -29,8 +29,9 @@ import org.apache.spark.SparkException
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils}
import org.apache.spark.sql.connector.catalog.TableProvider
@@ -519,18 +520,8 @@ case class DataSource(
case format: FileFormat =>
disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false)
val cmd = planForWritingFileFormat(format, mode, data)
val resolvedPartCols = cmd.partitionColumns.map { col =>
// The partition columns created in `planForWritingFileFormat` should always be
// `UnresolvedAttribute` with a single name part.
assert(col.isInstanceOf[UnresolvedAttribute])
val unresolved = col.asInstanceOf[UnresolvedAttribute]
assert(unresolved.nameParts.length == 1)
val name = unresolved.nameParts.head
outputColumns.find(a => equality(a.name, name)).getOrElse {
throw QueryCompilationErrors.cannotResolveAttributeError(
name, data.output.map(_.name).mkString(", "))
}
}
val resolvedPartCols =
DataSource.resolvePartitionColumns(cmd.partitionColumns, outputColumns, data, equality)
val resolved = cmd.copy(
partitionColumns = resolvedPartCols,
outputColumnNames = outputColumnNames)
@@ -836,4 +827,26 @@ object DataSource extends Logging {
throw QueryCompilationErrors.writeEmptySchemasUnsupportedByDataSourceError()
}
}

/**
* Resolve partition columns using output columns of the query plan.
*/
def resolvePartitionColumns(
partitionColumns: Seq[Attribute],
outputColumns: Seq[Attribute],
plan: LogicalPlan,
resolver: Resolver): Seq[Attribute] = {
partitionColumns.map { col =>
// The partition columns created in `planForWritingFileFormat` should always be
// `UnresolvedAttribute` with a single name part.
assert(col.isInstanceOf[UnresolvedAttribute])
val unresolved = col.asInstanceOf[UnresolvedAttribute]
assert(unresolved.nameParts.length == 1)
val name = unresolved.nameParts.head
outputColumns.find(a => resolver(a.name, name)).getOrElse {
throw QueryCompilationErrors.cannotResolveAttributeError(
name, plan.output.map(_.name).mkString(", "))
}
}
}
}
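A minimal sketch (not part of the change) of the extracted helper, using made-up attributes and catalyst's built-in `caseInsensitiveResolution`: single-part partition column names are matched against the query's output columns.

```scala
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.types.{IntegerType, StringType}

// Hypothetical query output: (id INT, p STRING).
val id = AttributeReference("id", IntegerType)()
val p  = AttributeReference("p", StringType)()
val plan = LocalRelation(id, p)

// `planForWritingFileFormat` produces single-part UnresolvedAttributes for the
// partition columns; the helper matches them against the output columns.
val resolved = DataSource.resolvePartitionColumns(
  partitionColumns = Seq(UnresolvedAttribute.quoted("P")), // resolves to `p` case-insensitively
  outputColumns = Seq(id, p),
  plan = plan,
  resolver = caseInsensitiveResolution)

assert(resolved == Seq(p))
```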
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
@@ -78,6 +77,12 @@ object FileFormatWriter extends Logging {
maxWriters: Int,
createSorter: () => UnsafeExternalRowSorter)

/**
* A variable used in tests to check whether the output ordering of the query matches the
* required ordering of the write command.
*/
private[sql] var outputOrderingMatched: Boolean = false

/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -126,38 +131,8 @@
}
val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan

val writerBucketSpec = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)

if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
"true") {
// Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
// Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
// columns is negative. See Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))

// The bucket file name prefix is following Hive, Presto and Trino conversion, so this
// makes sure Hive bucketed table written by Spark, can be read by other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
// expression, so that we can guarantee the data distribution is same between shuffle and
// bucketed data source, which enables us to only shuffle one side when join a bucketed
// table and a normal one.
val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
.partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}
val writerBucketSpec = V1WritesUtils.getWriterBucketSpec(bucketSpec, dataColumns, options)
val sortColumns = V1WritesUtils.getBucketSortColumns(bucketSpec, dataColumns)

val caseInsensitiveOptions = CaseInsensitiveMap(options)

@@ -209,6 +184,16 @@
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

// When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add a logical
// sort operator based on the required ordering of the V1 write command, so the output
// ordering of the physical plan should always match the required ordering. Here
// we set the variable to verify this behavior in tests.
// There are two cases where FileFormatWriter still needs to add physical sort:
// 1) When the planned write config is disabled.
// 2) When the concurrent writers are enabled (in this case the required ordering of a
// V1 write command will be empty).
if (Utils.isTesting) outputOrderingMatched = orderingMatched

try {
val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) {
(empty2NullPlan.execute(), None)
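The flag above is test-only. A sketch of how a suite might consume it, assuming the usual `SharedSparkSession`/`SQLTestUtils` helpers (`test`, `withSQLConf`, `withTempPath`) and a suite under `org.apache.spark.sql` so the `private[sql]` member is visible:

```scala
import org.apache.spark.sql.execution.datasources.FileFormatWriter
import org.apache.spark.sql.internal.SQLConf

test("planned write avoids a physical sort in FileFormatWriter") {
  withSQLConf(SQLConf.PLANNED_WRITE_ENABLED.key -> "true") {
    withTempPath { dir =>
      spark.range(100)
        .selectExpr("id", "id % 5 AS p")
        .write
        .partitionBy("p")
        .mode("overwrite")
        .parquet(dir.getCanonicalPath)
      // With the V1Writes rule adding the logical sort, the ordering seen by
      // FileFormatWriter should already match the required ordering.
      assert(FileFormatWriter.outputOrderingMatched)
    }
  }
}
```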
@@ -24,7 +24,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -57,7 +57,7 @@ case class InsertIntoHadoopFsRelationCommand(
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumnNames: Seq[String])
extends DataWritingCommand {
extends V1WriteCommand {

private lazy val parameters = CaseInsensitiveMap(options)

@@ -74,6 +74,9 @@
staticPartitions.size < partitionColumns.length
}

override def requiredOrdering: Seq[SortOrder] =
V1WritesUtils.getSortOrder(outputColumns, partitionColumns, bucketSpec, options)

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
SchemaUtils.checkColumnNameDuplication(
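One interaction worth calling out: when concurrent output writers are enabled and there are no bucket sort columns, `requiredOrdering` comes back empty, so no logical sort is planned and `FileFormatWriter` keeps its concurrent-writer path. A hedged sketch (output path is made up):

```scala
import org.apache.spark.sql.internal.SQLConf

// With concurrent writers enabled (and no bucket sort columns), the required
// ordering of the V1 write command is empty, so the optimizer adds no Sort.
spark.conf.set(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key, "8")

spark.range(1000)
  .selectExpr("id", "id % 7 AS p")
  .write
  .partitionBy("p")
  .mode("overwrite")
  .parquet("/tmp/concurrent-writers-demo")   // hypothetical path
```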
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, BitwiseAnd, HiveHash, Literal, Pmod, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.internal.SQLConf

trait V1WriteCommand extends DataWritingCommand {
// Specify the required ordering for the V1 write command. `FileFormatWriter` will
// add SortExec if necessary when the requiredOrdering is empty.
def requiredOrdering: Seq[SortOrder]
}

Review thread on `requiredOrdering`:
- Contributor: Just brainstorming here: if we later add a requirement for partitioning, e.g. to support a shuffle before writing a bucketed table, do we want to add something similar to V2's RequiresDistributionAndOrdering now, or not?
- Contributor: I think we can add one more method: requiredPartitioning.

/**
* A rule that adds logical sorts to V1 data writing commands.
*/
object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
override def apply(plan: LogicalPlan): LogicalPlan = {
if (conf.plannedWriteEnabled) {
plan.transformDown {
case write: V1WriteCommand =>
val newQuery = prepareQuery(write, write.query)
write.withNewChildren(newQuery :: Nil)
}
} else {
plan
}
}

private def prepareQuery(write: V1WriteCommand, query: LogicalPlan): LogicalPlan = {
val requiredOrdering = write.requiredOrdering
val outputOrdering = query.outputOrdering
// Check if the ordering is already matched. It is needed to ensure the
// idempotency of the rule.
val orderingMatched = if (requiredOrdering.length > outputOrdering.length) {
false
} else {
requiredOrdering.zip(outputOrdering).forall {
case (requiredOrder, outputOrder) => requiredOrder.semanticEquals(outputOrder)
}
}
if (orderingMatched) {
query
} else {
Sort(requiredOrdering, global = false, query)
}
}
}

object V1WritesUtils {

def getWriterBucketSpec(
bucketSpec: Option[BucketSpec],
dataColumns: Seq[Attribute],
options: Map[String, String]): Option[WriterBucketSpec] = {
bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)

if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
"true") {
// Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
// Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
// columns is negative. See Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))

// The bucket file name prefix is following Hive, Presto and Trino conversion, so this
// makes sure Hive bucketed table written by Spark, can be read by other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
// expression, so that we can guarantee the data distribution is same between shuffle and
// bucketed data source, which enables us to only shuffle one side when join a bucketed
// table and a normal one.
val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
.partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
}

def getBucketSortColumns(
bucketSpec: Option[BucketSpec],
dataColumns: Seq[Attribute]): Seq[Attribute] = {
bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}
}

def getSortOrder(
outputColumns: Seq[Attribute],
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
options: Map[String, String]): Seq[SortOrder] = {
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = outputColumns.filterNot(partitionSet.contains)
val writerBucketSpec = V1WritesUtils.getWriterBucketSpec(bucketSpec, dataColumns, options)
val sortColumns = V1WritesUtils.getBucketSortColumns(bucketSpec, dataColumns)

if (SQLConf.get.maxConcurrentOutputFileWriters > 0 && sortColumns.isEmpty) {
// Do not insert logical sort when concurrent output writers are enabled.
Seq.empty
} else {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
// Note we do not need to convert empty string partition columns to null when sorting the
// columns since null and empty string values will be next to each other.
(partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns)
.map(SortOrder(_, Ascending))
}
}
}
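To see the rule end to end, a sketch (not part of the change; table and view names are made up) that inspects the plan of a partitioned CTAS. With the flag on, the optimized plan should contain a non-global Sort on the partition column injected by `V1Writes`:

```scala
spark.conf.set("spark.sql.optimizer.plannedWrite.enabled", "true")
spark.range(100).selectExpr("id", "id % 4 AS p").createOrReplaceTempView("src")

// EXPLAIN EXTENDED prints the optimized logical plan; the CTAS query should now
// show a non-global Sort on the partition column added by the V1Writes rule.
spark.sql(
  """EXPLAIN EXTENDED
    |CREATE TABLE planned_write_demo
    |USING parquet
    |PARTITIONED BY (p)
    |AS SELECT id, p FROM src""".stripMargin).show(truncate = false)
```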