Commit 2ba3ff0

[SPARK-10216][SQL] Revert "[] Avoid creating empty files during overwrit…
This reverts commit 8d05a7a from #12855, which seems to have caused regressions when working with empty DataFrames.

Author: Michael Armbrust <[email protected]>

Closes #13181 from marmbrus/revert12855.
1 parent dfa61f7 commit 2ba3ff0
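
A minimal, hypothetical sketch of the kind of job the message refers to: overwriting a target path with a DataFrame that has no rows. It is illustrative only; the output path, local master, and Parquet format are assumptions, not details taken from the original report.

    // Hypothetical repro sketch: overwrite a directory using an empty DataFrame.
    // Under the reverted change (#12855), tasks whose input iterator was empty
    // skipped writer setup entirely, which reportedly caused problems for jobs
    // of this shape.
    import org.apache.spark.sql.{SaveMode, SparkSession}

    object EmptyOverwriteSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("empty-overwrite-sketch")
          .master("local[2]")                      // assumption: local run for illustration
          .getOrCreate()

        // An empty DataFrame with a concrete schema.
        val empty = spark.range(0).selectExpr("id", "cast(id as string) as value")

        // Overwrite an (illustrative) target directory with zero rows.
        empty.write
          .mode(SaveMode.Overwrite)
          .parquet("/tmp/empty-overwrite-sketch")

        // Reading the result back should still yield a usable, empty DataFrame.
        println(spark.read.parquet("/tmp/empty-overwrite-sketch").count())

        spark.stop()
      }
    }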

File tree
4 files changed: +126, -182 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
Lines changed: 108 additions & 113 deletions

@@ -239,50 +239,48 @@ private[sql] class DefaultWriterContainer(
   extends BaseWriterContainer(relation, job, isAppend) {

   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      executorSideSetup(taskContext)
-      val configuration = taskAttemptContext.getConfiguration
-      configuration.set("spark.sql.sources.output.path", outputPath)
-      var writer = newOutputWriter(getWorkPath)
-      writer.initConverter(dataSchema)
+    executorSideSetup(taskContext)
+    val configuration = taskAttemptContext.getConfiguration
+    configuration.set("spark.sql.sources.output.path", outputPath)
+    var writer = newOutputWriter(getWorkPath)
+    writer.initConverter(dataSchema)

-      // If anything below fails, we should abort the task.
-      try {
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          while (iterator.hasNext) {
-            val internalRow = iterator.next()
-            writer.writeInternal(internalRow)
-          }
-          commitTask()
-        }(catchBlock = abortTask())
-      } catch {
-        case t: Throwable =>
-          throw new SparkException("Task failed while writing rows", t)
-      }
+    // If anything below fails, we should abort the task.
+    try {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        while (iterator.hasNext) {
+          val internalRow = iterator.next()
+          writer.writeInternal(internalRow)
+        }
+        commitTask()
+      }(catchBlock = abortTask())
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
+    }

-      def commitTask(): Unit = {
-        try {
-          if (writer != null) {
-            writer.close()
-            writer = null
-          }
-          super.commitTask()
-        } catch {
-          case cause: Throwable =>
-            // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
-            // will cause `abortTask()` to be invoked.
-            throw new RuntimeException("Failed to commit task", cause)
+    def commitTask(): Unit = {
+      try {
+        if (writer != null) {
+          writer.close()
+          writer = null
         }
+        super.commitTask()
+      } catch {
+        case cause: Throwable =>
+          // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
+          // will cause `abortTask()` to be invoked.
+          throw new RuntimeException("Failed to commit task", cause)
       }
+    }

-      def abortTask(): Unit = {
-        try {
-          if (writer != null) {
-            writer.close()
-          }
-        } finally {
-          super.abortTask()
+    def abortTask(): Unit = {
+      try {
+        if (writer != null) {
+          writer.close()
         }
+      } finally {
+        super.abortTask()
       }
     }
   }

@@ -365,87 +363,84 @@ private[sql] class DynamicPartitionWriterContainer(
   }

   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      executorSideSetup(taskContext)
-
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
+    executorSideSetup(taskContext)
+
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })
+
+    // Returns the data columns to be written given an input row
+    val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+
+    // Returns the partition path given a partition key.
+    val getPartitionString =
+      UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
+
+    // Sorts the data before write, so that we only need one writer at the same time.
+    // TODO: inject a local sort operator in planning.
+    val sorter = new UnsafeKVExternalSorter(
+      sortingKeySchema,
+      StructType.fromAttributes(dataColumns),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+    while (iterator.hasNext) {
+      val currentRow = iterator.next()
+      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+    }
+    logInfo(s"Sorting complete. Writing out partition files one at a time.")

-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
-
-      // Returns the partition path given a partition key.
-      val getPartitionString =
-        UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      // TODO: inject a local sort operator in planning.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-      while (iterator.hasNext) {
-        val currentRow = iterator.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+      identity
+    } else {
+      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+      })
+    }

-      val sortedIterator = sorter.sortedIterator()
+    val sortedIterator = sorter.sortedIterator()

-      // If anything below fails, we should abort the task.
-      var currentWriter: OutputWriter = null
-      try {
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          var currentKey: UnsafeRow = null
-          while (sortedIterator.next()) {
-            val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-            if (currentKey != nextKey) {
-              if (currentWriter != null) {
-                currentWriter.close()
-                currentWriter = null
-              }
-              currentKey = nextKey.copy()
-              logDebug(s"Writing partition: $currentKey")
-
-              currentWriter = newOutputWriter(currentKey, getPartitionString)
+    // If anything below fails, we should abort the task.
+    var currentWriter: OutputWriter = null
+    try {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        var currentKey: UnsafeRow = null
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+              currentWriter = null
             }
-            currentWriter.writeInternal(sortedIterator.getValue)
-          }
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")

-          commitTask()
-        }(catchBlock = {
-          if (currentWriter != null) {
-            currentWriter.close()
+            currentWriter = newOutputWriter(currentKey, getPartitionString)
           }
-          abortTask()
-        })
-      } catch {
-        case t: Throwable =>
-          throw new SparkException("Task failed while writing rows", t)
-      }
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+        if (currentWriter != null) {
+          currentWriter.close()
+          currentWriter = null
+        }
+
+        commitTask()
+      }(catchBlock = {
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
+        abortTask()
+      })
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
     }
   }
 }
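
Both hunks in this file make the same change: the revert drops the `if (iterator.hasNext)` guard that #12855 had wrapped around task setup, so a writer is opened and the task committed even when a partition delivers no rows. That restores the pre-SPARK-10216 behavior in which empty partitions may leave empty output files behind. The sketch below is a simplified, self-contained illustration of that difference; ToyWriter, writeTaskAlways, and writeTaskGuarded are invented names and do not correspond to Spark's internal API.

    import java.io.{File, PrintWriter}
    import java.nio.file.Files

    object GuardedWriteSketch {
      // A stand-in for an output writer: opening it creates the file, even if no row is written.
      final class ToyWriter(path: File) {
        private val out = new PrintWriter(path)
        def write(row: String): Unit = out.println(row)
        def close(): Unit = out.close()
      }

      // Post-revert shape: always set up a writer, so an empty iterator still produces a file.
      def writeTaskAlways(rows: Iterator[String], path: File): Unit = {
        val writer = new ToyWriter(path)
        try rows.foreach(writer.write) finally writer.close()
      }

      // Shape of the reverted change: skip setup when there are no rows, producing no file.
      def writeTaskGuarded(rows: Iterator[String], path: File): Unit = {
        if (rows.hasNext) {
          val writer = new ToyWriter(path)
          try rows.foreach(writer.write) finally writer.close()
        }
      }

      def main(args: Array[String]): Unit = {
        val dir = Files.createTempDirectory("write-sketch").toFile
        val always = new File(dir, "always.txt")
        val guarded = new File(dir, "guarded.txt")
        writeTaskAlways(Iterator.empty, always)
        writeTaskGuarded(Iterator.empty, guarded)
        // Prints: always.txt exists = true, guarded.txt exists = false
        println(s"always.txt exists = ${always.exists()}, guarded.txt exists = ${guarded.exists()}")
      }
    }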

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
Lines changed: 11 additions & 13 deletions

@@ -178,21 +178,19 @@ private[hive] class SparkHiveWriterContainer(

   // this function is executed on executor side
   def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
-      executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
+    val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
+    executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)

-      iterator.foreach { row =>
-        var i = 0
-        while (i < fieldOIs.length) {
-          outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
-          i += 1
-        }
-        writer.write(serializer.serialize(outputData, standardOI))
-      }
-
-      close()
+    iterator.foreach { row =>
+      var i = 0
+      while (i < fieldOIs.length) {
+        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
+        i += 1
+      }
+      writer.write(serializer.serialize(outputData, standardOI))
     }
+
+    close()
   }
 }

sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
Lines changed: 6 additions & 35 deletions

@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive

 import java.io.File

+import org.apache.hadoop.hive.conf.HiveConf
 import org.scalatest.BeforeAndAfter

 import org.apache.spark.SparkException
-import org.apache.spark.sql._
+import org.apache.spark.sql.{QueryTest, _}
 import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
 import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

@@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef

     sql(
       s"""
-        |CREATE TABLE table_with_partition(c1 string)
-        |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
-        |location '${tmpDir.toURI.toString}'
-      """.stripMargin)
+        |CREATE TABLE table_with_partition(c1 string)
+        |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
+        |location '${tmpDir.toURI.toString}'
+      """.stripMargin)
     sql(
       """
       |INSERT OVERWRITE TABLE table_with_partition

@@ -216,35 +216,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
     sql("DROP TABLE hiveTableWithStructValue")
   }

-  test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") {
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
-      val testDataset = hiveContext.sparkContext.parallelize(
-        (1 to 2).map(i => TestData(i, i.toString))).toDF()
-      testDataset.createOrReplaceTempView("testDataset")
-
-      val tmpDir = Utils.createTempDir()
-      sql(
-        s"""
-          |CREATE TABLE table1(key int,value string)
-          |location '${tmpDir.toURI.toString}'
-        """.stripMargin)
-      sql(
-        """
-          |INSERT OVERWRITE TABLE table1
-          |SELECT count(key), value FROM testDataset GROUP BY value
-        """.stripMargin)
-
-      val overwrittenFiles = tmpDir.listFiles()
-        .filter(f => f.isFile && !f.getName.endsWith(".crc"))
-        .sortBy(_.getName)
-      val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
-
-      assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
-
-      sql("DROP TABLE table1")
-    }
-  }
-
   test("Reject partitioning that does not match table") {
     withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
       sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")

sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
Lines changed: 1 addition & 21 deletions

@@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.DataSourceScanExec
-import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem}
+import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils

@@ -879,26 +879,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
       }
     }
   }
-
-  test("SPARK-10216: Avoid empty files during overwriting with group by query") {
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
-      withTempPath { path =>
-        val df = spark.range(0, 5)
-        val groupedDF = df.groupBy("id").count()
-        groupedDF.write
-          .format(dataSourceName)
-          .mode(SaveMode.Overwrite)
-          .save(path.getCanonicalPath)
-
-        val overwrittenFiles = path.listFiles()
-          .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
-          .sortBy(_.getName)
-        val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
-
-        assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
-      }
-    }
-  }
 }

 // This class is used to test SPARK-8578. We should not use any custom output committer when
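
Both removed SPARK-10216 tests (in InsertIntoHiveTableSuite and HadoopFsRelationTest above) rely on the same assertion pattern: list the output directory, ignore hidden and bookkeeping files, and require that none of the remaining data files is zero bytes long. A standalone restatement of that check is sketched below; the output path is illustrative, and nonHiddenDataFiles is an invented helper, not part of either test suite.

    import java.io.File

    object EmptyFileCheckSketch {
      // Data files in a directory, excluding hidden (dot-prefixed) and bookkeeping (underscore-prefixed) files.
      def nonHiddenDataFiles(dir: File): Seq[File] =
        Option(dir.listFiles()).getOrElse(Array.empty[File])
          .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
          .sortBy(_.getName)
          .toSeq

      def main(args: Array[String]): Unit = {
        val outputDir = new File("/tmp/spark-output")   // illustrative path
        val files = nonHiddenDataFiles(outputDir)
        val nonEmpty = files.filter(_.length > 0)
        // The removed tests asserted that these two collections are identical,
        // i.e. that the write produced no zero-byte data files.
        assert(files == nonEmpty, s"empty output files: ${files.diff(nonEmpty).mkString(", ")}")
      }
    }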
