
Commit af37bdd

HyukjinKwon authored and marmbrus committed
[SPARK-10216][SQL] Avoid creating empty files during overwriting with group by query
## What changes were proposed in this pull request?

Currently, an `INSERT INTO` backed by a `GROUP BY` query tries to create at least 200 files (the default value of `spark.sql.shuffle.partitions`), which results in lots of empty files.

This PR avoids creating empty files when overwriting into a Hive table or into the internal data sources with a group-by query. It checks whether the given partition actually contains data and creates/writes a file only when it does.

## How was this patch tested?

Unit tests in `InsertIntoHiveTableSuite` and `HadoopFsRelationTest`.

Closes #8411

Author: hyukjinkwon <[email protected]>
Author: Keuntae Park <[email protected]>

Closes #12855 from HyukjinKwon/pr/8411.

(cherry picked from commit 8d05a7a)
Signed-off-by: Michael Armbrust <[email protected]>
1 parent adc1c26 · commit af37bdd
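The scenario described above can be reproduced with a short self-contained job. The sketch below is illustrative only: the object, table/view names, and the output path are hypothetical and not part of this commit. With the default `spark.sql.shuffle.partitions` of 200, the `GROUP BY` below runs 200 reduce tasks even though only two groups exist, and before this patch every empty task still left an empty output file in the target directory.

```scala
// Illustrative reproduction sketch (hypothetical names and paths; not part of the commit).
import org.apache.spark.sql.SparkSession

object Spark10216Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SPARK-10216 repro")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Only two distinct groups exist, so at most two shuffle partitions carry rows.
    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
    df.createOrReplaceTempView("src")

    // With spark.sql.shuffle.partitions = 200 (the default), this write used to leave
    // many zero-byte part files before the fix; afterwards only non-empty partitions
    // produce files.
    spark.sql("SELECT value, count(key) AS cnt FROM src GROUP BY value")
      .write
      .mode("overwrite")
      .parquet("/tmp/spark-10216-output")  // hypothetical output location

    spark.stop()
  }
}
```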

File tree

4 files changed: 182 additions, 126 deletions


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala

Lines changed: 113 additions & 108 deletions
```diff
@@ -239,48 +239,50 @@ private[sql] class DefaultWriterContainer(
   extends BaseWriterContainer(relation, job, isAppend) {
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    executorSideSetup(taskContext)
-    val configuration = taskAttemptContext.getConfiguration
-    configuration.set("spark.sql.sources.output.path", outputPath)
-    var writer = newOutputWriter(getWorkPath)
-    writer.initConverter(dataSchema)
-
-    // If anything below fails, we should abort the task.
-    try {
-      Utils.tryWithSafeFinallyAndFailureCallbacks {
-        while (iterator.hasNext) {
-          val internalRow = iterator.next()
-          writer.writeInternal(internalRow)
-        }
-        commitTask()
-      }(catchBlock = abortTask())
-    } catch {
-      case t: Throwable =>
-        throw new SparkException("Task failed while writing rows", t)
-    }
+    if (iterator.hasNext) {
+      executorSideSetup(taskContext)
+      val configuration = taskAttemptContext.getConfiguration
+      configuration.set("spark.sql.sources.output.path", outputPath)
+      var writer = newOutputWriter(getWorkPath)
+      writer.initConverter(dataSchema)
 
-    def commitTask(): Unit = {
+      // If anything below fails, we should abort the task.
       try {
-        if (writer != null) {
-          writer.close()
-          writer = null
-        }
-        super.commitTask()
+        Utils.tryWithSafeFinallyAndFailureCallbacks {
+          while (iterator.hasNext) {
+            val internalRow = iterator.next()
+            writer.writeInternal(internalRow)
+          }
+          commitTask()
+        }(catchBlock = abortTask())
       } catch {
-        case cause: Throwable =>
-          // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
-          // will cause `abortTask()` to be invoked.
-          throw new RuntimeException("Failed to commit task", cause)
+        case t: Throwable =>
+          throw new SparkException("Task failed while writing rows", t)
       }
-    }
 
-    def abortTask(): Unit = {
-      try {
-        if (writer != null) {
-          writer.close()
+      def commitTask(): Unit = {
+        try {
+          if (writer != null) {
+            writer.close()
+            writer = null
+          }
+          super.commitTask()
+        } catch {
+          case cause: Throwable =>
+            // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
+            // will cause `abortTask()` to be invoked.
+            throw new RuntimeException("Failed to commit task", cause)
+        }
+      }
+
+      def abortTask(): Unit = {
+        try {
+          if (writer != null) {
+            writer.close()
+          }
+        } finally {
+          super.abortTask()
         }
-      } finally {
-        super.abortTask()
       }
     }
   }
```
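The `DefaultWriterContainer` change above boils down to a single guard: check `iterator.hasNext` before doing any setup, so an empty partition never opens a writer and therefore never leaves a zero-byte file. A minimal standalone sketch of that pattern (illustrative only; `Writer`, `openWriter`, and `writePartition` are stand-ins, not Spark's API):

```scala
// Standalone sketch of the guard added above (illustrative only; not Spark's API).
object EmptyPartitionGuard {
  trait Writer[Row] {
    def write(row: Row): Unit
    def close(): Unit
  }

  def writePartition[Row](iterator: Iterator[Row], openWriter: () => Writer[Row]): Unit = {
    // An empty partition does nothing: no writer is opened, so no empty file is created.
    if (iterator.hasNext) {
      val writer = openWriter()
      try {
        while (iterator.hasNext) {
          writer.write(iterator.next())
        }
      } finally {
        writer.close()
      }
    }
  }
}
```

The double use of `hasNext` is safe because, on a Scala `Iterator`, `hasNext` only peeks and does not consume an element.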
```diff
@@ -363,84 +365,87 @@ private[sql] class DynamicPartitionWriterContainer(
   }
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    executorSideSetup(taskContext)
-
-    // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
-    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
-
-    val sortingKeySchema = StructType(sortingExpressions.map {
-      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-      // The sorting expressions are all `Attribute` except bucket id.
-      case _ => StructField("bucketId", IntegerType, nullable = false)
-    })
-
-    // Returns the data columns to be written given an input row
-    val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
-
-    // Returns the partition path given a partition key.
-    val getPartitionString =
-      UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
-
-    // Sorts the data before write, so that we only need one writer at the same time.
-    // TODO: inject a local sort operator in planning.
-    val sorter = new UnsafeKVExternalSorter(
-      sortingKeySchema,
-      StructType.fromAttributes(dataColumns),
-      SparkEnv.get.blockManager,
-      SparkEnv.get.serializerManager,
-      TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-    while (iterator.hasNext) {
-      val currentRow = iterator.next()
-      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-    }
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-      identity
-    } else {
-      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+    if (iterator.hasNext) {
+      executorSideSetup(taskContext)
+
+      // We should first sort by partition columns, then bucket id, and finally sorting columns.
+      val sortingExpressions: Seq[Expression] =
+        partitionColumns ++ bucketIdExpression ++ sortColumns
+      val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+      val sortingKeySchema = StructType(sortingExpressions.map {
+        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+        // The sorting expressions are all `Attribute` except bucket id.
+        case _ => StructField("bucketId", IntegerType, nullable = false)
       })
-    }
 
-    val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+
+      // Returns the partition path given a partition key.
+      val getPartitionString =
+        UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
+
+      // Sorts the data before write, so that we only need one writer at the same time.
+      // TODO: inject a local sort operator in planning.
+      val sorter = new UnsafeKVExternalSorter(
+        sortingKeySchema,
+        StructType.fromAttributes(dataColumns),
+        SparkEnv.get.blockManager,
+        SparkEnv.get.serializerManager,
+        TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+      while (iterator.hasNext) {
+        val currentRow = iterator.next()
+        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+      }
+      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+        identity
+      } else {
+        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+        })
+      }
 
-    // If anything below fails, we should abort the task.
-    var currentWriter: OutputWriter = null
-    try {
-      Utils.tryWithSafeFinallyAndFailureCallbacks {
-        var currentKey: UnsafeRow = null
-        while (sortedIterator.next()) {
-          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-          if (currentKey != nextKey) {
-            if (currentWriter != null) {
-              currentWriter.close()
-              currentWriter = null
-            }
-            currentKey = nextKey.copy()
-            logDebug(s"Writing partition: $currentKey")
+      val sortedIterator = sorter.sortedIterator()
 
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
+      // If anything below fails, we should abort the task.
+      var currentWriter: OutputWriter = null
+      try {
+        Utils.tryWithSafeFinallyAndFailureCallbacks {
+          var currentKey: UnsafeRow = null
+          while (sortedIterator.next()) {
+            val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+            if (currentKey != nextKey) {
+              if (currentWriter != null) {
+                currentWriter.close()
+                currentWriter = null
+              }
+              currentKey = nextKey.copy()
+              logDebug(s"Writing partition: $currentKey")
+
+              currentWriter = newOutputWriter(currentKey, getPartitionString)
+            }
+            currentWriter.writeInternal(sortedIterator.getValue)
+          }
+          if (currentWriter != null) {
+            currentWriter.close()
+            currentWriter = null
           }
-          currentWriter.writeInternal(sortedIterator.getValue)
-        }
-        if (currentWriter != null) {
-          currentWriter.close()
-          currentWriter = null
-        }
 
-        commitTask()
-      }(catchBlock = {
-        if (currentWriter != null) {
-          currentWriter.close()
-        }
-        abortTask()
-      })
-    } catch {
-      case t: Throwable =>
-        throw new SparkException("Task failed while writing rows", t)
+          commitTask()
+        }(catchBlock = {
+          if (currentWriter != null) {
+            currentWriter.close()
+          }
+          abortTask()
+        })
+      } catch {
+        case t: Throwable =>
+          throw new SparkException("Task failed while writing rows", t)
+      }
     }
   }
 }
```
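The dynamic-partition path gets the same `iterator.hasNext` guard, wrapped around a write loop that first sorts rows by partition key and then keeps only one writer open at a time, switching files whenever the key changes. A simplified sketch of that loop (illustrative only: it sorts in memory instead of using Spark's `UnsafeKVExternalSorter`, and it reuses the hypothetical `Writer` trait from the previous sketch):

```scala
// Simplified sketch of the dynamic-partition write loop above (illustrative only;
// not Spark's API).
object DynamicPartitionSketch {
  import EmptyPartitionGuard.Writer

  def writePartitioned[K: Ordering, Row](
      rows: Iterator[(K, Row)],
      openWriterFor: K => Writer[Row]): Unit = {
    // Same guard as above: an empty task writes nothing at all.
    if (rows.hasNext) {
      // Sorting by partition key means at most one writer needs to be open at a time.
      val sorted = rows.toSeq.sortBy(_._1)
      var currentKey: Option[K] = None
      var currentWriter: Writer[Row] = null
      try {
        for ((key, row) <- sorted) {
          if (!currentKey.contains(key)) {
            if (currentWriter != null) currentWriter.close()
            currentWriter = openWriterFor(key) // switch to a new file for the new key
            currentKey = Some(key)
          }
          currentWriter.write(row)
        }
      } finally {
        if (currentWriter != null) currentWriter.close()
      }
    }
  }
}
```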

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala

Lines changed: 13 additions & 11 deletions
```diff
@@ -178,19 +178,21 @@ private[hive] class SparkHiveWriterContainer(
 
   // this function is executed on executor side
   def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
-    executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
-
-    iterator.foreach { row =>
-      var i = 0
-      while (i < fieldOIs.length) {
-        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
-        i += 1
+    if (iterator.hasNext) {
+      val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
+      executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
+
+      iterator.foreach { row =>
+        var i = 0
+        while (i < fieldOIs.length) {
+          outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
+          i += 1
+        }
+        writer.write(serializer.serialize(outputData, standardOI))
       }
-      writer.write(serializer.serialize(outputData, standardOI))
-    }
 
-    close()
+      close()
+    }
   }
 }
```
sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala

Lines changed: 35 additions & 6 deletions
```diff
@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive
 
 import java.io.File
 
-import org.apache.hadoop.hive.conf.HiveConf
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{QueryTest, _}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
 
     sql(
       s"""
-        |CREATE TABLE table_with_partition(c1 string)
-        |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
-        |location '${tmpDir.toURI.toString}'
-      """.stripMargin)
+         |CREATE TABLE table_with_partition(c1 string)
+         |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
+         |location '${tmpDir.toURI.toString}'
+       """.stripMargin)
     sql(
       """
       |INSERT OVERWRITE TABLE table_with_partition
@@ -216,6 +216,35 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
     sql("DROP TABLE hiveTableWithStructValue")
   }
 
+  test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val testDataset = hiveContext.sparkContext.parallelize(
+        (1 to 2).map(i => TestData(i, i.toString))).toDF()
+      testDataset.createOrReplaceTempView("testDataset")
+
+      val tmpDir = Utils.createTempDir()
+      sql(
+        s"""
+          |CREATE TABLE table1(key int,value string)
+          |location '${tmpDir.toURI.toString}'
+        """.stripMargin)
+      sql(
+        """
+          |INSERT OVERWRITE TABLE table1
+          |SELECT count(key), value FROM testDataset GROUP BY value
+        """.stripMargin)
+
+      val overwrittenFiles = tmpDir.listFiles()
+        .filter(f => f.isFile && !f.getName.endsWith(".crc"))
+        .sortBy(_.getName)
+      val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
+
+      assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
+
+      sql("DROP TABLE table1")
+    }
+  }
+
   test("Reject partitioning that does not match table") {
     withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
       sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
```

sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala

Lines changed: 21 additions & 1 deletion
```diff
@@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.DataSourceScanExec
-import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
@@ -879,6 +879,26 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       }
     }
   }
+
+  test("SPARK-10216: Avoid empty files during overwriting with group by query") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      withTempPath { path =>
+        val df = spark.range(0, 5)
+        val groupedDF = df.groupBy("id").count()
+        groupedDF.write
+          .format(dataSourceName)
+          .mode(SaveMode.Overwrite)
+          .save(path.getCanonicalPath)
+
+        val overwrittenFiles = path.listFiles()
+          .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
+          .sortBy(_.getName)
+        val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
+
+        assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
+      }
+    }
+  }
 }
 
 // This class is used to test SPARK-8578. We should not use any custom output committer when
```
