@@ -239,50 +239,48 @@ private[sql] class DefaultWriterContainer(
   extends BaseWriterContainer(relation, job, isAppend) {

   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      executorSideSetup(taskContext)
-      val configuration = taskAttemptContext.getConfiguration
-      configuration.set("spark.sql.sources.output.path", outputPath)
-      var writer = newOutputWriter(getWorkPath)
-      writer.initConverter(dataSchema)
+    executorSideSetup(taskContext)
+    val configuration = taskAttemptContext.getConfiguration
+    configuration.set("spark.sql.sources.output.path", outputPath)
+    var writer = newOutputWriter(getWorkPath)
+    writer.initConverter(dataSchema)

-      // If anything below fails, we should abort the task.
-      try {
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          while (iterator.hasNext) {
-            val internalRow = iterator.next()
-            writer.writeInternal(internalRow)
-          }
-          commitTask()
-        }(catchBlock = abortTask())
-      } catch {
-        case t: Throwable =>
-          throw new SparkException("Task failed while writing rows", t)
-      }
+    // If anything below fails, we should abort the task.
+    try {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        while (iterator.hasNext) {
+          val internalRow = iterator.next()
+          writer.writeInternal(internalRow)
+        }
+        commitTask()
+      }(catchBlock = abortTask())
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
+    }

-      def commitTask(): Unit = {
-        try {
-          if (writer != null) {
-            writer.close()
-            writer = null
-          }
-          super.commitTask()
-        } catch {
-          case cause: Throwable =>
-            // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
-            // will cause `abortTask()` to be invoked.
-            throw new RuntimeException("Failed to commit task", cause)
+    def commitTask(): Unit = {
+      try {
+        if (writer != null) {
+          writer.close()
+          writer = null
         }
+        super.commitTask()
+      } catch {
+        case cause: Throwable =>
+          // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
+          // will cause `abortTask()` to be invoked.
+          throw new RuntimeException("Failed to commit task", cause)
       }
+    }

-      def abortTask(): Unit = {
-        try {
-          if (writer != null) {
-            writer.close()
-          }
-        } finally {
-          super.abortTask()
+    def abortTask(): Unit = {
+      try {
+        if (writer != null) {
+          writer.close()
         }
+      } finally {
+        super.abortTask()
       }
     }
   }
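The hunk above boils down to a commit/abort protocol around a single OutputWriter: write every row, close the writer before committing, and close it again (then abort) if anything throws. The standalone sketch below mimics that shape outside of Spark; `SingleWriterSketch`, `RowWriter`, and `tryWithFailureCallback` are hypothetical stand-ins for Spark's OutputWriter and Utils.tryWithSafeFinallyAndFailureCallbacks, not real Spark APIs.

object SingleWriterSketch {
  import scala.util.control.NonFatal

  // Hypothetical stand-in for OutputWriter: accepts rows and must be closed.
  trait RowWriter {
    def write(row: String): Unit
    def close(): Unit
  }

  // Simplified analogue of Utils.tryWithSafeFinallyAndFailureCallbacks: run `block`,
  // call `onFailure` if it throws, then rethrow the original error.
  def tryWithFailureCallback[T](block: => T)(onFailure: => Unit): T = {
    try block
    catch {
      case NonFatal(t) =>
        try onFailure catch { case NonFatal(_) => () } // never mask the original failure
        throw t
    }
  }

  def writeRows(rows: Iterator[String], newWriter: () => RowWriter): Unit = {
    var writer: RowWriter = newWriter()

    def commitTask(): Unit = {
      if (writer != null) { // close before committing so the output is fully flushed
        writer.close()
        writer = null
      }
      println("task committed")
    }

    def abortTask(): Unit = {
      try {
        if (writer != null) writer.close()
      } finally {
        println("task aborted")
      }
    }

    tryWithFailureCallback {
      rows.foreach(writer.write) // one writer handles every row of the task
      commitTask()
    }(onFailure = abortTask())
  }
}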
@@ -365,87 +363,84 @@ private[sql] class DynamicPartitionWriterContainer(
   }

   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      executorSideSetup(taskContext)
-
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
+    executorSideSetup(taskContext)
+
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })
+
+    // Returns the data columns to be written given an input row
+    val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+
+    // Returns the partition path given a partition key.
+    val getPartitionString =
+      UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
+
+    // Sorts the data before write, so that we only need one writer at the same time.
+    // TODO: inject a local sort operator in planning.
+    val sorter = new UnsafeKVExternalSorter(
+      sortingKeySchema,
+      StructType.fromAttributes(dataColumns),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+    while (iterator.hasNext) {
+      val currentRow = iterator.next()
+      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+    }
+    logInfo(s"Sorting complete. Writing out partition files one at a time.")

-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
-
-      // Returns the partition path given a partition key.
-      val getPartitionString =
-        UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      // TODO: inject a local sort operator in planning.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-      while (iterator.hasNext) {
-        val currentRow = iterator.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+      identity
+    } else {
+      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+      })
+    }

-      val sortedIterator = sorter.sortedIterator()
+    val sortedIterator = sorter.sortedIterator()

-      // If anything below fails, we should abort the task.
-      var currentWriter: OutputWriter = null
-      try {
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          var currentKey: UnsafeRow = null
-          while (sortedIterator.next()) {
-            val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-            if (currentKey != nextKey) {
-              if (currentWriter != null) {
-                currentWriter.close()
-                currentWriter = null
-              }
-              currentKey = nextKey.copy()
-              logDebug(s"Writing partition: $currentKey")
-
-              currentWriter = newOutputWriter(currentKey, getPartitionString)
+    // If anything below fails, we should abort the task.
+    var currentWriter: OutputWriter = null
+    try {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        var currentKey: UnsafeRow = null
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+              currentWriter = null
             }
-            currentWriter.writeInternal(sortedIterator.getValue)
-          }
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")

-          commitTask()
-        }(catchBlock = {
-          if (currentWriter != null) {
-            currentWriter.close()
+            currentWriter = newOutputWriter(currentKey, getPartitionString)
           }
-          abortTask()
-        })
-      } catch {
-        case t: Throwable =>
-          throw new SparkException("Task failed while writing rows", t)
-      }
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+        if (currentWriter != null) {
+          currentWriter.close()
+          currentWriter = null
+        }
+
+        commitTask()
+      }(catchBlock = {
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
+        abortTask()
+      })
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
     }
   }
 }
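The second hunk's key idea survives the re-indentation: rows are first sorted by (partition, bucket id, sort columns), so the task never needs more than one open writer; whenever the key changes, the previous partition's writer is closed and a fresh one is opened. A minimal sketch of that key-change loop follows, using an in-memory sort in place of UnsafeKVExternalSorter and a hypothetical PartitionWriter in place of OutputWriter.

object PartitionedWriteSketch {
  // Hypothetical writer for a single partition's output file.
  final class PartitionWriter(partition: String) {
    def write(value: String): Unit = println(s"[$partition] $value")
    def close(): Unit = println(s"[$partition] closed")
  }

  def writePartitioned(rows: Seq[(String, String)]): Unit = {
    // Sort by partition key so each partition's rows are contiguous; the real code
    // uses UnsafeKVExternalSorter so the sort can spill to disk.
    val sorted = rows.sortBy { case (key, _) => key }

    var currentKey: String = null
    var currentWriter: PartitionWriter = null
    try {
      for ((key, value) <- sorted) {
        if (key != currentKey) {
          // Key changed: the previous partition is complete, so close its writer
          // and open exactly one writer for the new partition.
          if (currentWriter != null) currentWriter.close()
          currentKey = key
          currentWriter = new PartitionWriter(key)
        }
        currentWriter.write(value)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }

  def main(args: Array[String]): Unit = {
    // Three rows across two partitions: only one writer is ever open at a time.
    writePartitioned(Seq("p=2" -> "b", "p=1" -> "a", "p=1" -> "c"))
  }
}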