[SPARK-19563][SQL] avoid unnecessary sort in FileFormatWriter #16898
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

 /** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
       val dataColumns: Seq[Attribute],
-      val bucketSpec: Option[BucketSpec],
+      val partitionColumns: Seq[Attribute],
+      val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

+    val allColumns = queryExecution.logical.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)

+    val bucketIdExpression = bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+    val sortColumns = bucketSpec.toSeq.flatMap {
+      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+    }

     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
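For context on the bucket id expression above: `HashPartitioning.partitionIdExpression` conceptually computes a non-negative modulo of a hash over the bucket columns, so a shuffle that hashes the same columns assigns rows to the same buckets as the bucketed files. A minimal, self-contained sketch of that idea; the hash function and names here are stand-ins, not Spark's actual Murmur3 expression:

```scala
object BucketIdSketch {
  // Non-negative modulo, mirroring the semantics of Catalyst's Pmod for ints.
  def pmod(a: Int, n: Int): Int = {
    val r = a % n
    if (r < 0) r + n else r
  }

  // Hypothetical hash stand-in; Spark hashes the raw column values with a Murmur3 expression.
  def hashBucketColumns(values: Seq[Any]): Int =
    values.foldLeft(42)((acc, v) => 31 * acc + v.hashCode)

  // Bucket id = pmod(hash(bucket columns), numBuckets).
  def bucketId(bucketColumnValues: Seq[Any], numBuckets: Int): Int =
    pmod(hashBucketColumns(bucketColumnValues), numBuckets)

  def main(args: Array[String]): Unit = {
    // Rows that agree on the bucket columns always land in the same bucket.
    println(bucketId(Seq("user-1", 2017), numBuckets = 8))
    println(bucketId(Seq("user-1", 2017), numBuckets = 8))
  }
}
```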
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
       uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
-      allColumns = queryExecution.logical.output,
-      partitionColumns = partitionColumns,
+      allColumns = allColumns,
       dataColumns = dataColumns,
-      bucketSpec = bucketSpec,
+      partitionColumns = partitionColumns,
+      bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
     )

+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    // the sort order doesn't matter
+    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
Review comment: @cloud-fan would it be possible to use the logical plan rather than the executedPlan? If the optimizer decides the data is already sorted according to the logical plan, the executedPlan won't include the fields.

Author reply (Contributor): That would be great, but may need some refactoring.
+    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(actualOrdering).forall {
+        case (requiredOrder, childOutputOrder) =>
+          requiredOrder.semanticEquals(childOutputOrder)
+      }
+    }

     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
       // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
       committer.setupJob(job)

       try {
-        val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val rdd = if (orderingMatched) {
+          queryExecution.toRdd
+        } else {
+          SortExec(
+            requiredOrdering.map(SortOrder(_, Ascending)),
+            global = false,
+            child = queryExecution.executedPlan).execute()
+        }
+
+        val ret = sparkSession.sparkContext.runJob(rdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
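The skip-the-sort decision above is essentially a prefix check: the writer's required ordering (partition columns, then bucket id, then sort columns) must be a semantic prefix of the child plan's output ordering; otherwise a local `SortExec` is added. A self-contained sketch of that check, using plain strings in place of Catalyst expressions and plain equality in place of `semanticEquals`:

```scala
object OrderingMatchSketch {
  // Stand-in for Catalyst's `semanticEquals`; here simply string equality.
  private def semanticallyEqual(a: String, b: String): Boolean = a == b

  /** Returns true when `required` is a prefix of `actual`, so the extra sort can be skipped. */
  def orderingMatched(required: Seq[String], actual: Seq[String]): Boolean =
    if (required.length > actual.length) {
      false
    } else {
      required.zip(actual).forall { case (r, a) => semanticallyEqual(r, a) }
    }

  def main(args: Array[String]): Unit = {
    // Data already sorted by (p, bucketId, s): the required ordering (p, bucketId) is satisfied.
    println(orderingMatched(Seq("p", "bucketId"), Seq("p", "bucketId", "s"))) // true
    // Data sorted only by (s): a local sort is needed before writing.
    println(orderingMatched(Seq("p", "bucketId"), Seq("s")))                  // false
  }
}
```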
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)

     val writeTask =
-      if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+      if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
         new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
       } else {
         new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
    * multiple directories (partitions) or files (bucketing).
    */
   private class DynamicPartitionWriteTask(
-      description: WriteJobDescription,
+      desc: WriteJobDescription,
Author reply (Contributor): I'd like to change both to make it consistent.
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {

     // currentWriter is initialized whenever we see a new key
     private var currentWriter: OutputWriter = _

-    private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-
-    /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
-    private def partitionStringExpression: Seq[Expression] = {
-      description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+    /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+    private def partitionPathExpression: Seq[Expression] = {
+      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
     }

     /**
-     * Open and returns a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param key vaues for fields consisting of partition keys for the current row
-     * @param partString a function that projects the partition values into a string
+     * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+     *                            current row.
+     * @param getPartitionPath a function that projects the partition values into a path string.
      * @param fileCounter the number of files that have been written in the past for this specific
      *                    partition. This is used to limit the max number of records written for a
      *                    single file. The value should start from 0.
+     * @param updatedPartitions the set of updated partition paths, we should add the new partition
+     *                          path of this writer to it.
      */
     private def newOutputWriter(
-        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
-      val partDir =
-        if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+        partColsAndBucketId: InternalRow,
+        getPartitionPath: UnsafeProjection,
+        fileCounter: Int,
+        updatedPartitions: mutable.Set[String]): Unit = {
+      val partDir = if (desc.partitionColumns.isEmpty) {
+        None
+      } else {
+        Option(getPartitionPath(partColsAndBucketId).getString(0))
+      }
+      partDir.foreach(updatedPartitions.add)

-      // If the bucket spec is defined, the bucket column is right after the partition columns
-      val bucketId = if (description.bucketSpec.isDefined) {
-        BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+      // If the bucketId expression is defined, the bucketId column is right after the partition
+      // columns.
+      val bucketId = if (desc.bucketIdExpression.isDefined) {
+        BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
       } else {
         ""
       }

       // This must be in a form that matches our bucketing format. See BucketingUtils.
       val ext = f"$bucketId.c$fileCounter%03d" +
-        description.outputWriterFactory.getFileExtension(taskAttemptContext)
+        desc.outputWriterFactory.getFileExtension(taskAttemptContext)

       val customPath = partDir match {
         case Some(dir) =>
-          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+          desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
         case _ =>
           None
       }
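As the doc comment above describes, the optional bucket id is appended to the file name before the format extension, together with a per-partition file counter. A small sketch of how such a name can be assembled; the five-digit padding and the helper names are assumptions for illustration, not necessarily Spark's exact `BucketingUtils` output:

```scala
object FileNameSketch {
  // Assumed padding: five-digit, zero-padded bucket id such as "_00002".
  def bucketIdToString(id: Int): String = f"_$id%05d"

  // "<prefix><bucketSuffix>.c<fileCounter padded to 3 digits><format-specific extension>"
  def fileName(prefix: String, bucketId: Option[Int], fileCounter: Int, formatExt: String): String = {
    val bucketSuffix = bucketId.map(bucketIdToString).getOrElse("")
    f"$prefix$bucketSuffix.c$fileCounter%03d$formatExt"
  }

  def main(args: Array[String]): Unit = {
    // e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677_00002.c000.gz.parquet
    println(fileName("part-r-00009-ea518ad4-455a-4431-b471-d24e03814677", Some(2), 0, ".gz.parquet"))
  }
}
```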
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }

-      currentWriter = description.outputWriterFactory.newInstance(
+      currentWriter = desc.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.dataColumns.toStructType,
+        dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
     }

     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(description.dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes,
-        SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
-      while (iter.hasNext) {
-        val currentRow = iter.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+      val getPartitionColsAndBucketId = UnsafeProjection.create(
+        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)

-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+      val getPartPath = UnsafeProjection.create(
+        Seq(Concat(partitionPathExpression)), desc.partitionColumns)

-      val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)

       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var currentKey: UnsafeRow = null
+      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          // See a new key - write to a new partition (new file).
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
+      for (row <- iter) {
+        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+
+        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+          logDebug(s"Writing partition: $currentPartColsAndBucketId")

           recordsInFile = 0
           fileCounter = 0

           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
-          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
-          if (partitionPath.nonEmpty) {
-            updatedPartitions.add(partitionPath)
-          }
-        } else if (description.maxRecordsPerFile > 0 &&
-          recordsInFile >= description.maxRecordsPerFile) {
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+        } else if (desc.maxRecordsPerFile > 0 &&
+          recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
           // Create a new file by increasing the file counter.
           recordsInFile = 0
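The rewritten loop relies on the input already arriving sorted (or at least clustered) by partition and bucket key, so at most one writer needs to be open at a time and the per-task external sorter is no longer needed. A self-contained sketch of that single-pass pattern, with plain strings standing in for the partition/bucket key and a stub in place of the real OutputWriter:

```scala
import scala.collection.mutable

object SortedWriteSketch {
  // Stub standing in for an OutputWriter; it just records what it is asked to write.
  final class StubWriter(val key: String) {
    def write(record: String): Unit = println(s"[$key] $record")
    def close(): Unit = println(s"[$key] closed")
  }

  /** Writes sorted (key, record) pairs, opening a new writer only when the key changes. */
  def writeSorted(sorted: Iterator[(String, String)]): Set[String] = {
    var currentKey: String = null
    var currentWriter: StubWriter = null
    val seenKeys = mutable.Set[String]()

    for ((key, record) <- sorted) {
      if (key != currentKey) {
        // New partition/bucket key: close the previous writer and open a fresh one.
        if (currentWriter != null) currentWriter.close()
        currentKey = key
        seenKeys += key
        currentWriter = new StubWriter(key)
      }
      currentWriter.write(record)
    }
    if (currentWriter != null) currentWriter.close()
    seenKeys.toSet
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator("p=1" -> "a", "p=1" -> "b", "p=2" -> "c")
    println(writeSorted(rows)) // Set(p=1, p=2)
  }
}
```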
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
         }

-        currentWriter.write(sortedIterator.getValue)
+        currentWriter.write(getOutputRow(row))
         recordsInFile += 1
       }
       releaseResources()
Review comment: If we rewrite it to `val dataColumns = allColumns.filterNot(partitionColumns.contains)`, we do not need `partitionSet`.

Author reply: it's so minor, I'll fix it in my next PR
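The suggested simplification, sketched with plain Scala lists in place of Catalyst attributes (assuming element equality behaves the same as the `AttributeSet` membership test used in the PR):

```scala
object FilterNotSketch {
  def main(args: Array[String]): Unit = {
    val allColumns = Seq("year", "month", "name", "amount")
    val partitionColumns = Seq("year", "month")

    // Current form: build a set first, then filter against it.
    val partitionSet = partitionColumns.toSet
    val dataColumnsViaSet = allColumns.filterNot(partitionSet.contains)

    // Suggested form: filter directly against the partition column sequence.
    val dataColumns = allColumns.filterNot(partitionColumns.contains)

    println(dataColumnsViaSet) // List(name, amount)
    println(dataColumns)       // List(name, amount)
  }
}
```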