Skip to content

Commit 247ddad

Browse files
committed
address comments
1 parent e523245 commit 247ddad

File tree

1 file changed

+8
-20
lines changed

1 file changed

+8
-20
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -349,21 +349,6 @@ private[sql] class DynamicPartitionWriterContainer(
349349
}
350350
}
351351

352-
private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
353-
val bucketIdIndex = partitionColumns.length
354-
if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
355-
false
356-
} else {
357-
var i = partitionColumns.length - 1
358-
while (i >= 0) {
359-
val dt = partitionColumns(i).dataType
360-
if (key1.get(i, dt) != key2.get(i, dt)) return false
361-
i -= 1
362-
}
363-
true
364-
}
365-
}
366-
367352
/**
368353
* Open and returns a new OutputWriter given a partition key and optional bucket id.
369354
* If bucket id is specified, we will append it to the end of the file name, but before the
@@ -426,22 +411,25 @@ private[sql] class DynamicPartitionWriterContainer(
426411

427412
logInfo(s"Sorting complete. Writing out partition files one at a time.")
428413

429-
val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
430-
(key1, key2) => key1 != key2
414+
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
415+
identity
431416
} else {
432-
(key1, key2) => key1 == null || !sameBucket(key1, key2)
417+
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
418+
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
419+
})
433420
}
434421

435422
val sortedIterator = sorter.sortedIterator()
436423
var currentKey: UnsafeRow = null
437424
var currentWriter: OutputWriter = null
438425
try {
439426
while (sortedIterator.next()) {
440-
if (needNewWriter(currentKey, sortedIterator.getKey)) {
427+
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
428+
if (currentKey != nextKey) {
441429
if (currentWriter != null) {
442430
currentWriter.close()
443431
}
444-
currentKey = sortedIterator.getKey.copy()
432+
currentKey = nextKey.copy()
445433
logDebug(s"Writing partition: $currentKey")
446434

447435
currentWriter = newOutputWriter(currentKey, getPartitionString)

0 commit comments

Comments
 (0)