@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.SerializableConfiguration


@@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer(
     }
   }

-  private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
-    val bucketIdIndex = partitionColumns.length
-    if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
-      false
-    } else {
-      var i = partitionColumns.length - 1
-      while (i >= 0) {
-        val dt = partitionColumns(i).dataType
-        if (key1.get(i, dt) != key2.get(i, dt)) return false
-        i -= 1
-      }
-      true
-    }
-  }
-
-  private def sortBasedWrite(
-      sorter: UnsafeKVExternalSorter,
-      iterator: Iterator[InternalRow],
-      getSortingKey: UnsafeProjection,
-      getOutputRow: UnsafeProjection,
-      getPartitionString: UnsafeProjection,
-      outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
-    while (iterator.hasNext) {
-      val currentRow = iterator.next()
-      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-    }
-
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-    val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
-      (key1, key2) => key1 != key2
-    } else {
-      (key1, key2) => key1 == null || !sameBucket(key1, key2)
-    }
-
-    val sortedIterator = sorter.sortedIterator()
-    var currentKey: UnsafeRow = null
-    var currentWriter: OutputWriter = null
-    try {
-      while (sortedIterator.next()) {
-        if (needNewWriter(currentKey, sortedIterator.getKey)) {
-          if (currentWriter != null) {
-            currentWriter.close()
-          }
-          currentKey = sortedIterator.getKey.copy()
-          logDebug(s"Writing partition: $currentKey")
-
-          // Either use an existing file from before, or open a new one.
-          currentWriter = outputWriters.remove(currentKey)
-          if (currentWriter == null) {
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
-          }
-        }
-
-        currentWriter.writeInternal(sortedIterator.getValue)
-      }
-    } finally {
-      if (currentWriter != null) { currentWriter.close() }
-    }
-  }
-
   /**
    * Open and returns a new OutputWriter given a partition key and optional bucket id.
    * If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +374,18 @@ private[sql] class DynamicPartitionWriterContainer(
   }

   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
     executorSideSetup(taskContext)

-    var outputWritersCleared = false
-
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val getSortingKey =
-      UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
-
-    val sortingKeySchema = if (bucketSpec.isEmpty) {
-      StructType.fromAttributes(partitionColumns)
-    } else { // If it's bucketed, we should also consider bucket id as part of the key.
-      val fields = StructType.fromAttributes(partitionColumns)
-        .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
-      StructType(fields)
-    }
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })

     // Returns the data columns to be written given an input row
     val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
@@ -461,54 +396,49 @@ private[sql] class DynamicPartitionWriterContainer(

     // If anything below fails, we should abort the task.
     try {
-      // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
-      // and fill the sorter if there are too many writers and we need to fall back on sorting.
-      // If there are sorting columns, then we have to sort the data anyway, and no need to try the
-      // hash-based writing first.
-      var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
-        new UnsafeKVExternalSorter(
-          sortingKeySchema,
-          StructType.fromAttributes(dataColumns),
-          SparkEnv.get.blockManager,
-          TaskContext.get().taskMemoryManager().pageSizeBytes)
+      // Sorts the data before write, so that we only need one writer at the same time.
+      // TODO: inject a local sort operator in planning.
+      val sorter = new UnsafeKVExternalSorter(
+        sortingKeySchema,
+        StructType.fromAttributes(dataColumns),
+        SparkEnv.get.blockManager,
+        TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+      while (iterator.hasNext) {
+        val currentRow = iterator.next()
+        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+      }
+
+      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+        identity
       } else {
-        null
+        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+        })
       }
-      while (iterator.hasNext && sorter == null) {
-        val inputRow = iterator.next()
-        // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
-        val currentKey = getSortingKey(inputRow)
-        var currentWriter = outputWriters.get(currentKey)
-
-        if (currentWriter == null) {
-          if (outputWriters.size < maxOpenFiles) {
+
+      val sortedIterator = sorter.sortedIterator()
+      var currentKey: UnsafeRow = null
+      var currentWriter: OutputWriter = null
+      try {
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+            }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
             currentWriter = newOutputWriter(currentKey, getPartitionString)
-            outputWriters.put(currentKey.copy(), currentWriter)
-            currentWriter.writeInternal(getOutputRow(inputRow))
-          } else {
-            logInfo(s"Maximum partitions reached, falling back on sorting.")
-            sorter = new UnsafeKVExternalSorter(
-              sortingKeySchema,
-              StructType.fromAttributes(dataColumns),
-              SparkEnv.get.blockManager,
-              TaskContext.get().taskMemoryManager().pageSizeBytes)
-            sorter.insertKV(currentKey, getOutputRow(inputRow))
           }
-        } else {
-          currentWriter.writeInternal(getOutputRow(inputRow))
-        }
-      }

-      // If the sorter is not null that means that we reached the maxFiles above and need to finish
-      // using external sort, or there are sorting columns and we need to sort the whole data set.
-      if (sorter != null) {
-        sortBasedWrite(
-          sorter,
-          iterator,
-          getSortingKey,
-          getOutputRow,
-          getPartitionString,
-          outputWriters)
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+      } finally {
+        if (currentWriter != null) { currentWriter.close() }
       }

       commitTask()
@@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer(
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
-
-    def clearOutputWriters(): Unit = {
-      if (!outputWritersCleared) {
-        outputWriters.asScala.values.foreach(_.close())
-        outputWriters.clear()
-        outputWritersCleared = true
-      }
-    }
-
-    def commitTask(): Unit = {
-      try {
-        clearOutputWriters()
-        super.commitTask()
-      } catch {
-        case cause: Throwable =>
-          throw new RuntimeException("Failed to commit task", cause)
-      }
-    }
-
-    def abortTask(): Unit = {
-      try {
-        clearOutputWriters()
-      } finally {
-        super.abortTask()
-      }
-    }
   }
 }
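Stripped of the Spark internals, the new `writeRows` path does three things: project each input row to a sorting key of (partition columns, bucket id, sort columns), sort all rows by that key with `UnsafeKVExternalSorter`, then stream the sorted rows and rotate the `OutputWriter` only when the (partition, bucket) prefix of the key changes, so at most one writer is open at a time. The sketch below is a minimal, self-contained Scala model of that pattern under simplifying assumptions: `Row`, `Writer`, and `openWriter` are hypothetical stand-ins for `InternalRow`, `OutputWriter`, and `newOutputWriter`, and an in-memory `sortBy` replaces the spillable external sorter.

```scala
// Hypothetical, simplified model of DynamicPartitionWriterContainer.writeRows.
object SortedPartitionWriteSketch {

  // Stand-in for InternalRow: one partition value, a bucket id, a sort column, payload.
  case class Row(partition: String, bucketId: Int, sortCol: Long, data: String)

  // Stand-in for OutputWriter.
  trait Writer {
    def write(data: String): Unit
    def close(): Unit
  }

  // Stand-in for newOutputWriter: logs instead of opening a file.
  def openWriter(key: (String, Int)): Writer = new Writer {
    def write(data: String): Unit = println(s"[$key] $data")
    def close(): Unit = println(s"[$key] closed")
  }

  def writeRows(rows: Iterator[Row]): Unit = {
    // 1. Sort by (partition, bucket id, sort columns). The real code spills to
    //    disk via UnsafeKVExternalSorter instead of sorting in memory.
    val sorted = rows.toSeq.sortBy(r => (r.partition, r.bucketId, r.sortCol))

    // 2. Stream the sorted rows, rotating the writer whenever the
    //    (partition, bucket) prefix of the key changes. The sort columns are
    //    deliberately excluded from this grouping key.
    var currentKey: Option[(String, Int)] = None
    var currentWriter: Writer = null
    try {
      for (row <- sorted) {
        val key = (row.partition, row.bucketId)
        if (!currentKey.contains(key)) {
          if (currentWriter != null) currentWriter.close()
          currentKey = Some(key)
          currentWriter = openWriter(key)
        }
        currentWriter.write(row.data)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }

  def main(args: Array[String]): Unit = {
    writeRows(Iterator(
      Row("date=2016-01-01", 0, 2L, "a"),
      Row("date=2016-01-02", 1, 1L, "b"),
      Row("date=2016-01-01", 0, 1L, "c")))
  }
}
```

As in the patch, the sort columns take part in the ordering but are dropped from the grouping key (what `getBucketingKey` does in the real code), so rows inside each output file come out sorted while writers are only rotated on partition/bucket boundaries.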