Skip to content

Commit 43f809f

Browse files
committed
[SPARK-23390][SQL] Register task completion listeners first in ParquetFileFormat
1 parent f217d7d commit 43f809f

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -395,16 +395,19 @@ class ParquetFileFormat
395395
ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
396396
}
397397
val taskContext = Option(TaskContext.get())
398-
val parquetReader = if (enableVectorizedReader) {
398+
val iter = if (enableVectorizedReader) {
399399
val vectorizedReader = new VectorizedParquetRecordReader(
400400
convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity)
401+
val recordReaderIterator = new RecordReaderIterator(vectorizedReader)
402+
// Register a task completion listener before `initialization`.
403+
taskContext.foreach(_.addTaskCompletionListener(_ => recordReaderIterator.close()))
401404
vectorizedReader.initialize(split, hadoopAttemptContext)
402405
logDebug(s"Appending $partitionSchema ${file.partitionValues}")
403406
vectorizedReader.initBatch(partitionSchema, file.partitionValues)
404407
if (returningBatch) {
405408
vectorizedReader.enableReturningBatches()
406409
}
407-
vectorizedReader
410+
recordReaderIterator
408411
} else {
409412
logDebug(s"Falling back to parquet-mr")
410413
// ParquetRecordReader returns UnsafeRow
@@ -414,16 +417,16 @@ class ParquetFileFormat
414417
} else {
415418
new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz))
416419
}
420+
val recordReaderIterator = new RecordReaderIterator(reader)
421+
// Register a task completion listener before `initialization`.
422+
taskContext.foreach(_.addTaskCompletionListener(_ => recordReaderIterator.close()))
417423
reader.initialize(split, hadoopAttemptContext)
418-
reader
424+
recordReaderIterator
419425
}
420426

421-
val iter = new RecordReaderIterator(parquetReader)
422-
taskContext.foreach(_.addTaskCompletionListener(_ => iter.close()))
423427

424428
// UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
425-
if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] &&
426-
enableVectorizedReader) {
429+
if (enableVectorizedReader) {
427430
iter.asInstanceOf[Iterator[InternalRow]]
428431
} else {
429432
val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes

0 commit comments

Comments (0)