
Commit 4588f82

Author: Ilya Ganelin
Added logic to close existing outputWriters if data being processed has been sorted
1 parent 7e355c4 commit 4588f82

File tree: 1 file changed (+30, -7 lines)


sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 30 additions & 7 deletions
@@ -245,6 +245,10 @@ private[sql] case class InsertIntoHadoopFsRelation(
       // we don't end up outputting the same data twice
       val writtenRows: mutable.HashSet[InternalRow] = new HashSet[InternalRow]

+      // Flag to track whether data has been sorted in which case it's safe to close previously
+      // used outputWriters
+      var sorted: Boolean = false
+
       // If anything below fails, we should abort the task.
       try {
         writerContainer.executorSideSetup(taskContext)
@@ -263,15 +267,14 @@
       // Sort the data by partition so that it's possible to use a single outputWriter at a
       // time to process the incoming data
       def sortRows(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+        // TODO Sort the data by partition key
        throw new NotImplementedException()
       }

       // When outputting rows, we may need to interrupt the file write to sort the underlying data
       // (SPARK-8890) to avoid running out of memory due to creating too many outputWriters. Thus,
       // we extract this functionality into its own function that can be called with updated
       // underlying data.
-      // TODO Add tracking of whether data has been sorted in outputWriterForRow, so that
-      // previously used outputWriters may be closed when complete.
       def writeRowsSafe(iterator: Iterator[InternalRow]): Unit ={
         while (iterator.hasNext) {
           val internalRow = iterator.next()
@@ -281,11 +284,11 @@
           val partitionPart = partitionProj(internalRow)
           val dataPart = dataConverter(dataProj(internalRow))

-          writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+          writerContainer.outputWriterForRow(partitionPart, sorted).write(dataPart)
           writtenRows += internalRow
         } else {
-          // TODO Sort the data by partition key
           val sortedRows: Iterator[InternalRow] = sortRows(iterator)
+          sorted = true
           writeRowsSafe(sortedRows)
         }
       }
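
Taken together, the three hunks above implement a write-or-sort loop: rows are written through a capped pool of writers, and once the cap is hit the remaining input is sorted by partition key so that a single writer at a time suffices. Since this is hard to follow as a diff, here is a minimal, self-contained Scala sketch of the same pattern under simplifying assumptions: rows are (partitionKey, value) pairs instead of InternalRows, a "writer" is just a string buffer, and every name (SortedWriteSketch, writerForRow, etc.) is illustrative rather than Spark's API. One deliberate deviation: the sketch closes the open writers only when the sorted stream moves to a new partition key, which is the effect the commit is after, whereas the patch as written clears them on every call with sorted = true.

// Minimal sketch of the pattern above (assumed names, not Spark's API).
// Rows are (partitionKey, value) pairs; a "writer" is a string buffer that
// stays registered in `openWriters` while it is open.
import scala.collection.mutable

object SortedWriteSketch {
  val maxOutputWriters = 2
  val openWriters = mutable.LinkedHashMap.empty[String, mutable.Buffer[String]]
  val output = mutable.Map.empty[String, mutable.Buffer[String]]
  var sorted = false

  // Counterpart of outputWriterForRow(row, shouldCloseWriters): once the data
  // is sorted, writers opened for earlier partitions are finished, so they can
  // all be closed before the writer for the current partition is fetched.
  def writerForRow(key: String, shouldCloseWriters: Boolean): mutable.Buffer[String] = {
    if (shouldCloseWriters && !openWriters.contains(key)) openWriters.clear()
    openWriters.getOrElseUpdate(key, output.getOrElseUpdate(key, mutable.Buffer.empty))
  }

  // Counterpart of writeRowsSafe: write until the writer limit is hit, then
  // sort the remaining rows by partition key and continue with sorted = true.
  def writeRowsSafe(rows: Iterator[(String, String)]): Unit = {
    while (rows.hasNext) {
      val (key, value) = rows.next()
      if (sorted || openWriters.contains(key) || openWriters.size < maxOutputWriters) {
        writerForRow(key, sorted) += value
      } else {
        val remaining = (Iterator(key -> value) ++ rows).toSeq.sortBy(_._1)
        sorted = true
        writeRowsSafe(remaining.iterator)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // With maxOutputWriters = 2 the limit is hit at partition "c"; the rest of
    // the input is then sorted, and at most one new writer opens at a time.
    writeRowsSafe(Iterator("b" -> "1", "a" -> "2", "c" -> "3", "a" -> "4", "d" -> "5"))
    println(output.map { case (k, v) => s"$k -> [${v.mkString(",")}]" }.mkString("; "))
  }
}
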
@@ -547,7 +550,7 @@ private[sql] class DynamicPartitionWriterContainer(
   // All output writers are created on executor side.
   @transient protected var outputWriters: java.util.HashMap[String, OutputWriter] = _

-  protected var maxOutputWriters = 50;
+  protected var maxOutputWriters = 50

   override protected def initWriters(): Unit = {
     outputWriters = new java.util.HashMap[String, OutputWriter]
@@ -602,8 +605,12 @@
     }
   }

-  // The `row` argument is supposed to only contain partition column values which have been casted
-  // to strings.
+  /**
+   * Create the outputWriter to output a given row to disk.
+   *
+   * @param row The `row` argument is supposed to only contain partition column values
+   *            which have been casted to strings.
+   */
   override def outputWriterForRow(row: InternalRow): OutputWriter = {
     val partitionPath: String = computePartitionPath(row)

@@ -620,6 +627,22 @@
     }
   }

+  /**
+   * Create the outputWriter to output a given row to disk. If dealing with sorted data, we
+   * can close previously used writers since they will no longer be necessary.
+   *
+   * @param row The `row` argument is supposed to only contain partition column values
+   *            which have been casted to strings.
+   * @param shouldCloseWriters If true, close all existing writers before creating new writers
+   */
+  def outputWriterForRow(row: InternalRow, shouldCloseWriters: Boolean): OutputWriter = {
+    if (shouldCloseWriters) {
+      clearOutputWriters()
+    }
+
+    outputWriterForRow(row)
+  }
+
   private def clearOutputWriters(): Unit = {
     if (!outputWriters.isEmpty) {
       asScalaIterator(outputWriters.values().iterator()).foreach(_.close())
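
The new two-argument overload delegates the actual cleanup to clearOutputWriters, which closes every open writer and then empties the map. For readers unfamiliar with the Java/Scala interop here, a standalone sketch of that close-all pattern follows; the WriterPool type and its methods are hypothetical, and where the real code uses the older asScalaIterator from scala.collection.JavaConversions, the sketch uses the current scala.jdk.CollectionConverters.

import java.io.Closeable
import scala.jdk.CollectionConverters._ // provides .asScala on Java collections

// Hypothetical pool illustrating the clearOutputWriters pattern: close every
// open writer, then drop the references so fresh writers can be created.
class WriterPool[W <: Closeable] {
  private val writers = new java.util.HashMap[String, W]

  def add(key: String, writer: W): Unit = writers.put(key, writer)

  def clearWriters(): Unit = {
    if (!writers.isEmpty) {
      writers.values().iterator().asScala.foreach(_.close())
      writers.clear()
    }
  }
}
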
