Commit f253fef

cloud-fan authored and rxin committed
[SPARK-12539][FOLLOW-UP] always sort in partitioning writer
address comments in #10498, especially #10498 (comment)

Author: Wenchen Fan <[email protected]>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <[email protected]>

Closes #10638 from cloud-fan/bucket-write.
1 parent f13c7f8 commit f253fef
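
In short: the previous writer tried hash-based writing first, keeping a HashMap of open OutputWriters and falling back to an external sort once maxOpenFiles was exceeded; this follow-up always sorts the rows by (partition columns, bucket id, sort columns) before writing, so at most one OutputWriter is open at a time. Below is a minimal, self-contained sketch of that sort-then-write pattern in plain Scala. The Row and Writer types and the in-memory sortBy are illustrative stand-ins only, not Spark's InternalRow / OutputWriter / UnsafeKVExternalSorter machinery (which sorts rows in binary format and can spill to disk).

object SortThenWriteSketch {
  // Illustrative stand-ins for Spark's InternalRow and OutputWriter.
  final case class Row(partitionKey: String, value: String)

  final class Writer(partitionKey: String) {
    def write(row: Row): Unit = println(s"[$partitionKey] ${row.value}")
    def close(): Unit = println(s"closed writer for $partitionKey")
  }

  def writeRows(rows: Iterator[Row]): Unit = {
    // Always sort by the partitioning key first (Spark uses a spillable external sorter here).
    val sorted = rows.toSeq.sortBy(_.partitionKey).iterator

    var currentKey: String = null
    var currentWriter: Writer = null
    try {
      sorted.foreach { row =>
        if (row.partitionKey != currentKey) {
          // Key changed: close the previous writer and open a new one,
          // so only one writer is ever open at a time.
          if (currentWriter != null) currentWriter.close()
          currentKey = row.partitionKey
          currentWriter = new Writer(currentKey)
        }
        currentWriter.write(row)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }

  def main(args: Array[String]): Unit = {
    writeRows(Iterator(Row("b", "2"), Row("a", "1"), Row("a", "3")))
  }
}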

File tree

2 files changed: +48 -147 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala

Lines changed: 48 additions & 144 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 
@@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer(
     }
   }
 
-  private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
-    val bucketIdIndex = partitionColumns.length
-    if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
-      false
-    } else {
-      var i = partitionColumns.length - 1
-      while (i >= 0) {
-        val dt = partitionColumns(i).dataType
-        if (key1.get(i, dt) != key2.get(i, dt)) return false
-        i -= 1
-      }
-      true
-    }
-  }
-
-  private def sortBasedWrite(
-      sorter: UnsafeKVExternalSorter,
-      iterator: Iterator[InternalRow],
-      getSortingKey: UnsafeProjection,
-      getOutputRow: UnsafeProjection,
-      getPartitionString: UnsafeProjection,
-      outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
-    while (iterator.hasNext) {
-      val currentRow = iterator.next()
-      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-    }
-
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-    val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
-      (key1, key2) => key1 != key2
-    } else {
-      (key1, key2) => key1 == null || !sameBucket(key1, key2)
-    }
-
-    val sortedIterator = sorter.sortedIterator()
-    var currentKey: UnsafeRow = null
-    var currentWriter: OutputWriter = null
-    try {
-      while (sortedIterator.next()) {
-        if (needNewWriter(currentKey, sortedIterator.getKey)) {
-          if (currentWriter != null) {
-            currentWriter.close()
-          }
-          currentKey = sortedIterator.getKey.copy()
-          logDebug(s"Writing partition: $currentKey")
-
-          // Either use an existing file from before, or open a new one.
-          currentWriter = outputWriters.remove(currentKey)
-          if (currentWriter == null) {
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
-          }
-        }
-
-        currentWriter.writeInternal(sortedIterator.getValue)
-      }
-    } finally {
-      if (currentWriter != null) { currentWriter.close() }
-    }
-  }
-
   /**
    * Open and returns a new OutputWriter given a partition key and optional bucket id.
    * If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +374,18 @@ private[sql] class DynamicPartitionWriterContainer(
   }
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
     executorSideSetup(taskContext)
 
-    var outputWritersCleared = false
-
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val getSortingKey =
-      UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
-
-    val sortingKeySchema = if (bucketSpec.isEmpty) {
-      StructType.fromAttributes(partitionColumns)
-    } else { // If it's bucketed, we should also consider bucket id as part of the key.
-      val fields = StructType.fromAttributes(partitionColumns)
-        .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
-      StructType(fields)
-    }
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })
 
     // Returns the data columns to be written given an input row
     val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
@@ -461,54 +396,49 @@ private[sql] class DynamicPartitionWriterContainer(
 
     // If anything below fails, we should abort the task.
     try {
-      // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
-      // and fill the sorter if there are too many writers and we need to fall back on sorting.
-      // If there are sorting columns, then we have to sort the data anyway, and no need to try the
-      // hash-based writing first.
-      var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
-        new UnsafeKVExternalSorter(
-          sortingKeySchema,
-          StructType.fromAttributes(dataColumns),
-          SparkEnv.get.blockManager,
-          TaskContext.get().taskMemoryManager().pageSizeBytes)
+      // Sorts the data before write, so that we only need one writer at the same time.
+      // TODO: inject a local sort operator in planning.
+      val sorter = new UnsafeKVExternalSorter(
+        sortingKeySchema,
+        StructType.fromAttributes(dataColumns),
+        SparkEnv.get.blockManager,
+        TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+      while (iterator.hasNext) {
+        val currentRow = iterator.next()
+        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+      }
+
+      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+        identity
       } else {
-        null
+        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+        })
       }
-      while (iterator.hasNext && sorter == null) {
-        val inputRow = iterator.next()
-        // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
-        val currentKey = getSortingKey(inputRow)
-        var currentWriter = outputWriters.get(currentKey)
-
-        if (currentWriter == null) {
-          if (outputWriters.size < maxOpenFiles) {
+
+      val sortedIterator = sorter.sortedIterator()
+      var currentKey: UnsafeRow = null
+      var currentWriter: OutputWriter = null
+      try {
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+            }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
             currentWriter = newOutputWriter(currentKey, getPartitionString)
-            outputWriters.put(currentKey.copy(), currentWriter)
-            currentWriter.writeInternal(getOutputRow(inputRow))
-          } else {
-            logInfo(s"Maximum partitions reached, falling back on sorting.")
-            sorter = new UnsafeKVExternalSorter(
-              sortingKeySchema,
-              StructType.fromAttributes(dataColumns),
-              SparkEnv.get.blockManager,
-              TaskContext.get().taskMemoryManager().pageSizeBytes)
-            sorter.insertKV(currentKey, getOutputRow(inputRow))
          }
-        } else {
-          currentWriter.writeInternal(getOutputRow(inputRow))
-        }
-      }
 
-      // If the sorter is not null that means that we reached the maxFiles above and need to finish
-      // using external sort, or there are sorting columns and we need to sort the whole data set.
-      if (sorter != null) {
-        sortBasedWrite(
-          sorter,
-          iterator,
-          getSortingKey,
-          getOutputRow,
-          getPartitionString,
-          outputWriters)
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+      } finally {
+        if (currentWriter != null) { currentWriter.close() }
       }
 
       commitTask()
@@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer(
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
-
-    def clearOutputWriters(): Unit = {
-      if (!outputWritersCleared) {
-        outputWriters.asScala.values.foreach(_.close())
-        outputWriters.clear()
-        outputWritersCleared = true
-      }
-    }
-
-    def commitTask(): Unit = {
-      try {
-        clearOutputWriters()
-        super.commitTask()
-      } catch {
-        case cause: Throwable =>
-          throw new RuntimeException("Failed to commit task", cause)
-      }
-    }
-
-    def abortTask(): Unit = {
-      try {
-        clearOutputWriters()
-      } finally {
-        super.abortTask()
-      }
-    }
   }
 }

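One detail of the new writeRows worth spelling out: rows are sorted by partitionColumns ++ bucketIdExpression ++ sortColumns, but getBucketingKey drops the trailing sort columns, so a new writer is opened only when the (partition values, bucket id) prefix of the key changes; the sort columns merely order rows inside each output file. A rough stand-alone sketch of that prefix grouping, using a hypothetical Key case class instead of Spark's UnsafeRow projections:

object PrefixKeyGroupingSketch {
  // Hypothetical key shape for illustration: partition value, bucket id, sort column.
  final case class Key(partition: String, bucketId: Int, sortCol: Long)

  def main(args: Array[String]): Unit = {
    val keys = Seq(
      Key("p=1", 0, 30L), Key("p=1", 0, 10L), Key("p=1", 1, 5L), Key("p=2", 0, 7L))

    // Full sort uses all three parts, mirroring partitionColumns ++ bucketIdExpression ++ sortColumns.
    val sorted = keys.sortBy(k => (k.partition, k.bucketId, k.sortCol))

    // The "bucketing key" keeps only the prefix, mirroring sortingExpressions.dropRight(sortColumns.length).
    var current: (String, Int) = null
    sorted.foreach { k =>
      val prefix = (k.partition, k.bucketId)
      if (prefix != current) {
        current = prefix
        println(s"open new writer for $prefix")
      }
      println(s"  write row with sortCol=${k.sortCol}")
    }
  }
}

Running the sketch opens one writer per (partition, bucketId) pair, while rows within each group come out ordered by sortCol.
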
sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 0 additions & 3 deletions
@@ -162,7 +162,6 @@ trait HadoopFsRelationProvider {
       partitionColumns: Option[StructType],
       parameters: Map[String, String]): HadoopFsRelation
 
-  // TODO: expose bucket API to users.
   private[sql] def createRelation(
       sqlContext: SQLContext,
       paths: Array[String],
@@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable {
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter
 
-  // TODO: expose bucket API to users.
   private[sql] def newInstance(
       path: String,
       bucketId: Option[Int],
@@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql](
 
   private var _partitionSpec: PartitionSpec = _
 
-  // TODO: expose bucket API to users.
   private[sql] def bucketSpec: Option[BucketSpec] = None
 
   private class FileStatusCache {
