
Commit 776b8f1

[SPARK-19563][SQL] avoid unnecessary sort in FileFormatWriter
## What changes were proposed in this pull request?

In `FileFormatWriter`, we sort the input rows by partition columns, bucket id, and sort columns when the data is written out partitioned or bucketed. However, if the data is already sorted, we sort it again, which is unnecessary. This PR removes the sorting logic from `FileFormatWriter` and uses `SortExec` instead; no `SortExec` is added if the data is already sorted.

## How was this patch tested?

I did a micro benchmark manually:
```
val df = spark.range(10000000).select($"id", $"id" % 10 as "part").sort("part")
spark.time(df.write.partitionBy("part").parquet("/tmp/test"))
```
The result was about 6.4 seconds before this PR and about 5.7 seconds afterwards.

close apache#16724

Author: Wenchen Fan <[email protected]>

Closes apache#16898 from cloud-fan/writer.
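As a companion to the benchmark above, a minimal, self-contained sketch of the two write paths being compared (the output paths and row count are illustrative, not from the original patch):

```scala
import org.apache.spark.sql.SparkSession

object WriterSortSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("writer-sort-sketch").getOrCreate()
    import spark.implicits._

    val df = spark.range(1000000).select($"id", ($"id" % 10).as("part"))

    // Unsorted input: the writer's required ordering (the partition column "part") is not
    // satisfied, so a local SortExec is added in front of the write.
    df.write.mode("overwrite").partitionBy("part").parquet("/tmp/unsorted_write")

    // Pre-sorted input: the plan's output ordering already starts with "part", so the
    // extra sort is skipped.
    df.sort("part").write.mode("overwrite").partitionBy("part").parquet("/tmp/sorted_write")

    spark.stop()
  }
}
```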
1 parent 65fe902 commit 776b8f1

1 file changed (+90, -99 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 90 additions & 99 deletions
```diff
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 
 /** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
       val dataColumns: Seq[Attribute],
-      val bucketSpec: Option[BucketSpec],
+      val partitionColumns: Seq[Attribute],
+      val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
```
```diff
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
+    val allColumns = queryExecution.logical.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
 
+    val bucketIdExpression = bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+    val sortColumns = bucketSpec.toSeq.flatMap {
+      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+    }
+
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
```
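The comment in the hunk above notes that the bucket id reuses `HashPartitioning.partitionIdExpression`, i.e. the same hash-and-modulo formula a shuffle uses to assign rows to partitions. A rough illustration using public SQL functions (an assumption here is that the public `hash()` function matches the internal Murmur3 hash used by `HashPartitioning`; the column and bucket count are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{hash, lit, pmod}

object BucketIdSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("bucket-id-sketch").getOrCreate()
    import spark.implicits._

    // The bucket id expression built above behaves like pmod(hash(bucketColumns), numBuckets),
    // the same formula a hash shuffle uses, which is why only one side of a join with a
    // bucketed table needs to be shuffled.
    val numBuckets = 8
    spark.range(100)
      .withColumn("bucket_id", pmod(hash($"id"), lit(numBuckets)))
      .show(5)

    spark.stop()
  }
}
```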
```diff
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
       uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
-      allColumns = queryExecution.logical.output,
-      partitionColumns = partitionColumns,
+      allColumns = allColumns,
       dataColumns = dataColumns,
-      bucketSpec = bucketSpec,
+      partitionColumns = partitionColumns,
+      bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
     )
 
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    // the sort order doesn't matter
+    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(actualOrdering).forall {
+        case (requiredOrder, childOutputOrder) =>
+          requiredOrder.semanticEquals(childOutputOrder)
+      }
+    }
+
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
       // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
       committer.setupJob(job)
 
       try {
-        val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val rdd = if (orderingMatched) {
+          queryExecution.toRdd
+        } else {
+          SortExec(
+            requiredOrdering.map(SortOrder(_, Ascending)),
+            global = false,
+            child = queryExecution.executedPlan).execute()
+        }
+
+        val ret = sparkSession.sparkContext.runJob(rdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
```
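The `orderingMatched` check added above is a prefix comparison: the required ordering (partition columns, then bucket id, then bucket sort columns) must be a semantic prefix of the plan's existing output ordering. A small, self-contained sketch of the same idea (a hypothetical helper, not Spark code; the real check compares expressions with `semanticEquals`):

```scala
// Hypothetical helper with the same shape as the check above: the required ordering
// must be a prefix of the plan's existing output ordering, element by element.
def orderingSatisfied[A](required: Seq[A], actual: Seq[A])(same: (A, A) => Boolean): Boolean =
  required.length <= actual.length &&
    required.zip(actual).forall { case (r, a) => same(r, a) }

// A plan sorted by (part, bucketId, ts) satisfies a requirement of (part, bucketId) ...
assert(orderingSatisfied(Seq("part", "bucketId"), Seq("part", "bucketId", "ts"))(_ == _))
// ... but an unsorted plan does not, so a local SortExec is inserted before the write.
assert(!orderingSatisfied(Seq("part"), Seq.empty[String])(_ == _))
```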
```diff
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)
 
     val writeTask =
-      if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+      if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
         new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
       } else {
         new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
    * multiple directories (partitions) or files (bucketing).
    */
   private class DynamicPartitionWriteTask(
-      description: WriteJobDescription,
+      desc: WriteJobDescription,
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
     // currentWriter is initialized whenever we see a new key
     private var currentWriter: OutputWriter = _
 
-    private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-
-    /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
-    private def partitionStringExpression: Seq[Expression] = {
-      description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+    /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+    private def partitionPathExpression: Seq[Expression] = {
+      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
```
```diff
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
     }
 
     /**
-     * Open and returns a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param key vaues for fields consisting of partition keys for the current row
-     * @param partString a function that projects the partition values into a string
+     * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+     *                            current row.
+     * @param getPartitionPath a function that projects the partition values into a path string.
      * @param fileCounter the number of files that have been written in the past for this specific
      *                    partition. This is used to limit the max number of records written for a
      *                    single file. The value should start from 0.
+     * @param updatedPartitions the set of updated partition paths, we should add the new partition
+     *                          path of this writer to it.
      */
     private def newOutputWriter(
-        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
-      val partDir =
-        if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+        partColsAndBucketId: InternalRow,
+        getPartitionPath: UnsafeProjection,
+        fileCounter: Int,
+        updatedPartitions: mutable.Set[String]): Unit = {
+      val partDir = if (desc.partitionColumns.isEmpty) {
+        None
+      } else {
+        Option(getPartitionPath(partColsAndBucketId).getString(0))
+      }
+      partDir.foreach(updatedPartitions.add)
 
-      // If the bucket spec is defined, the bucket column is right after the partition columns
-      val bucketId = if (description.bucketSpec.isDefined) {
-        BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+      // If the bucketId expression is defined, the bucketId column is right after the partition
+      // columns.
+      val bucketId = if (desc.bucketIdExpression.isDefined) {
+        BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
       } else {
         ""
       }
 
       // This must be in a form that matches our bucketing format. See BucketingUtils.
       val ext = f"$bucketId.c$fileCounter%03d" +
-        description.outputWriterFactory.getFileExtension(taskAttemptContext)
+        desc.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
-          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+          desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
         case _ =>
           None
       }
```
```diff
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
 
-      currentWriter = description.outputWriterFactory.newInstance(
+      currentWriter = desc.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.dataColumns.toStructType,
+        dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(description.dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes,
-        SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
-      while (iter.hasNext) {
-        val currentRow = iter.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+      val getPartitionColsAndBucketId = UnsafeProjection.create(
+        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+      val getPartPath = UnsafeProjection.create(
+        Seq(Concat(partitionPathExpression)), desc.partitionColumns)
 
-      val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
 
       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var currentKey: UnsafeRow = null
+      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          // See a new key - write to a new partition (new file).
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
+      for (row <- iter) {
+        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+          logDebug(s"Writing partition: $currentPartColsAndBucketId")
 
           recordsInFile = 0
           fileCounter = 0
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
-          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
-          if (partitionPath.nonEmpty) {
-            updatedPartitions.add(partitionPath)
-          }
-        } else if (description.maxRecordsPerFile > 0 &&
-            recordsInFile >= description.maxRecordsPerFile) {
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+        } else if (desc.maxRecordsPerFile > 0 &&
+            recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
           // Create a new file by increasing the file counter.
           recordsInFile = 0
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
         }
 
-        currentWriter.write(sortedIterator.getValue)
+        currentWriter.write(getOutputRow(row))
         recordsInFile += 1
       }
       releaseResources()
```
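The rewritten `execute` above relies on the input iterator already being sorted by (partition columns, bucket id): a writer is opened only when that key changes, so at most one writer is open at a time. A generic sketch of that pattern (a hypothetical helper, not Spark code):

```scala
// Hypothetical sketch of the "one writer at a time over sorted input" pattern used above.
def writeSorted[K, V](rows: Iterator[(K, V)])(openWriter: K => V => Unit): Unit = {
  var currentKey: Option[K] = None
  var write: V => Unit = _ => ()
  for ((key, value) <- rows) {
    if (!currentKey.contains(key)) { // key change: new partition directory / bucket file
      currentKey = Some(key)
      write = openWriter(key)
    }
    write(value)
  }
}

// Because the rows arrive grouped by key, each key opens exactly one writer.
writeSorted(Iterator(("part=a", 1), ("part=a", 2), ("part=b", 3))) { key =>
  value => println(s"$key -> $value")
}
```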
