@@ -239,48 +239,50 @@ private[sql] class DefaultWriterContainer(
extends BaseWriterContainer(relation, job, isAppend) {

def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
executorSideSetup(taskContext)
val configuration = taskAttemptContext.getConfiguration
configuration.set("spark.sql.sources.output.path", outputPath)
var writer = newOutputWriter(getWorkPath)
writer.initConverter(dataSchema)

// If anything below fails, we should abort the task.
try {
Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iterator.hasNext) {
val internalRow = iterator.next()
writer.writeInternal(internalRow)
}
commitTask()
}(catchBlock = abortTask())
} catch {
case t: Throwable =>
throw new SparkException("Task failed while writing rows", t)
}
if (iterator.hasNext) {
Member Author: Simply added an `iterator.hasNext` check, so the writer is only created when the partition actually has rows. (A minimal sketch of this guard pattern follows the updated code below.)

executorSideSetup(taskContext)
val configuration = taskAttemptContext.getConfiguration
configuration.set("spark.sql.sources.output.path", outputPath)
var writer = newOutputWriter(getWorkPath)
writer.initConverter(dataSchema)

def commitTask(): Unit = {
// If anything below fails, we should abort the task.
try {
if (writer != null) {
writer.close()
writer = null
}
super.commitTask()
Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iterator.hasNext) {
val internalRow = iterator.next()
writer.writeInternal(internalRow)
}
commitTask()
}(catchBlock = abortTask())
} catch {
case cause: Throwable =>
// This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
// will cause `abortTask()` to be invoked.
throw new RuntimeException("Failed to commit task", cause)
case t: Throwable =>
throw new SparkException("Task failed while writing rows", t)
}
}

def abortTask(): Unit = {
try {
if (writer != null) {
writer.close()
def commitTask(): Unit = {
try {
if (writer != null) {
writer.close()
writer = null
}
super.commitTask()
} catch {
case cause: Throwable =>
// This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
// will cause `abortTask()` to be invoked.
throw new RuntimeException("Failed to commit task", cause)
}
}

def abortTask(): Unit = {
try {
if (writer != null) {
writer.close()
}
} finally {
super.abortTask()
}
} finally {
super.abortTask()
}
}
}
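The shape of the change above, independent of Spark's writer-container internals, is a guard that defers all side effects (task setup, writer creation, file allocation) until the first row is known to exist. The sketch below only illustrates that pattern; `newWriter` and `writeRow` are hypothetical stand-ins, not Spark APIs.

```scala
// Minimal sketch of the iterator.hasNext guard: no writer (and therefore no
// output file) is ever created for an empty partition.
def writePartition[T](rows: Iterator[T], newWriter: () => java.io.Closeable)
                     (writeRow: (java.io.Closeable, T) => Unit): Unit = {
  if (rows.hasNext) {                  // skip setup entirely when the partition is empty
    val writer = newWriter()
    try {
      rows.foreach(row => writeRow(writer, row))
    } finally {
      writer.close()                   // always release the handle once it was opened
    }
  }
}
```

Because `Iterator.hasNext` does not consume elements, the guard is safe to evaluate before the normal read loop; the loop that follows still sees every row.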
@@ -363,84 +365,87 @@ private[sql] class DynamicPartitionWriterContainer(
}

def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
executorSideSetup(taskContext)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)

val sortingKeySchema = StructType(sortingExpressions.map {
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
// The sorting expressions are all `Attribute` except bucket id.
case _ => StructField("bucketId", IntegerType, nullable = false)
})

// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)

// Returns the partition path given a partition key.
val getPartitionString =
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)

// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)

while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
logInfo(s"Sorting complete. Writing out partition files one at a time.")

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
if (iterator.hasNext) {
Member Author: Here as well, simply added an `iterator.hasNext` check (the same guard pattern sketched above); a sketch of the sorted single-writer loop this container relies on follows the diff below.

executorSideSetup(taskContext)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val sortingExpressions: Seq[Expression] =
partitionColumns ++ bucketIdExpression ++ sortColumns
val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)

val sortingKeySchema = StructType(sortingExpressions.map {
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
// The sorting expressions are all `Attribute` except bucket id.
case _ => StructField("bucketId", IntegerType, nullable = false)
})
}

val sortedIterator = sorter.sortedIterator()
// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)

// Returns the partition path given a partition key.
val getPartitionString =
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)

// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)

while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
logInfo(s"Sorting complete. Writing out partition files one at a time.")

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}

// If anything below fails, we should abort the task.
var currentWriter: OutputWriter = null
try {
Utils.tryWithSafeFinallyAndFailureCallbacks {
var currentKey: UnsafeRow = null
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")
val sortedIterator = sorter.sortedIterator()

currentWriter = newOutputWriter(currentKey, getPartitionString)
// If anything below fails, we should abort the task.
var currentWriter: OutputWriter = null
try {
Utils.tryWithSafeFinallyAndFailureCallbacks {
var currentKey: UnsafeRow = null
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")

currentWriter = newOutputWriter(currentKey, getPartitionString)
}
currentWriter.writeInternal(sortedIterator.getValue)
}
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
currentWriter.writeInternal(sortedIterator.getValue)
}
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}

commitTask()
}(catchBlock = {
if (currentWriter != null) {
currentWriter.close()
}
abortTask()
})
} catch {
case t: Throwable =>
throw new SparkException("Task failed while writing rows", t)
commitTask()
}(catchBlock = {
if (currentWriter != null) {
currentWriter.close()
}
abortTask()
})
} catch {
case t: Throwable =>
throw new SparkException("Task failed while writing rows", t)
}
}
}
}
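Besides the added guard, the dynamic-partition path keeps its existing strategy: rows are first sorted by (partition columns, bucket id, sort columns) with `UnsafeKVExternalSorter`, so at most one `OutputWriter` needs to be open at a time and a new one is opened only when the key changes. Below is a hedged, generic sketch of that key-change loop; `PartitionWriter` and `openWriter` are hypothetical stand-ins for Spark's `OutputWriter` and `newOutputWriter`.

```scala
// Sketch of "sort first, then keep at most one writer open" (hypothetical types).
trait PartitionWriter[V] {
  def write(value: V): Unit
  def close(): Unit
}

def writeSorted[K, V](sortedRows: Iterator[(K, V)],
                      openWriter: K => PartitionWriter[V]): Unit = {
  var currentKey: Option[K] = None
  var currentWriter: PartitionWriter[V] = null
  try {
    for ((key, value) <- sortedRows) {
      if (!currentKey.contains(key)) {
        // Key changed: close the previous partition's writer before opening the next,
        // so only one file handle is live at any moment.
        if (currentWriter != null) {
          currentWriter.close()
          currentWriter = null
        }
        currentKey = Some(key)
        currentWriter = openWriter(key)
      }
      currentWriter.write(value)
    }
  } finally {
    if (currentWriter != null) currentWriter.close()
  }
}
```

An empty `sortedRows` iterator opens nothing, which is why wrapping the whole body in `iterator.hasNext` (so the sorter itself is never built for empty partitions) is enough to avoid empty files here as well.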
@@ -178,19 +178,21 @@ private[hive] class SparkHiveWriterContainer(

// this function is executed on executor side
def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)

iterator.foreach { row =>
var i = 0
while (i < fieldOIs.length) {
outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
i += 1
if (iterator.hasNext) {
val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)

iterator.foreach { row =>
var i = 0
while (i < fieldOIs.length) {
outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
i += 1
}
writer.write(serializer.serialize(outputData, standardOI))
}
writer.write(serializer.serialize(outputData, standardOI))
}

close()
close()
}
}
}

@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive

import java.io.File

import org.apache.hadoop.hive.conf.HiveConf
import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkException
import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef

sql(
s"""
|CREATE TABLE table_with_partition(c1 string)
|PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
|location '${tmpDir.toURI.toString}'
""".stripMargin)
|CREATE TABLE table_with_partition(c1 string)
|PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
|location '${tmpDir.toURI.toString}'
""".stripMargin)
sql(
"""
|INSERT OVERWRITE TABLE table_with_partition
@@ -216,6 +216,35 @@
sql("DROP TABLE hiveTableWithStructValue")
}

test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
val testDataset = hiveContext.sparkContext.parallelize(
(1 to 2).map(i => TestData(i, i.toString))).toDF()
testDataset.createOrReplaceTempView("testDataset")

val tmpDir = Utils.createTempDir()
sql(
s"""
|CREATE TABLE table1(key int,value string)
|location '${tmpDir.toURI.toString}'
""".stripMargin)
sql(
"""
|INSERT OVERWRITE TABLE table1
|SELECT count(key), value FROM testDataset GROUP BY value
""".stripMargin)

val overwrittenFiles = tmpDir.listFiles()
.filter(f => f.isFile && !f.getName.endsWith(".crc"))
.sortBy(_.getName)
val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)

assert(overwrittenFiles === overwrittenFilesWithoutEmpty)

sql("DROP TABLE table1")
}
}

test("Reject partitioning that does not match table") {
withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
@@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -879,6 +879,26 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
}
}
}

test("SPARK-10216: Avoid empty files during overwriting with group by query") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
withTempPath { path =>
val df = spark.range(0, 5)
val groupedDF = df.groupBy("id").count()
groupedDF.write
.format(dataSourceName)
.mode(SaveMode.Overwrite)
.save(path.getCanonicalPath)

val overwrittenFiles = path.listFiles()
.filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
.sortBy(_.getName)
val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)

assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
}
}
}
}
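Both SPARK-10216 tests in this PR share the same assertion shape: list the output directory, drop metadata files, and require that every remaining data file is non-empty. If one wanted to factor that out, a sketch of such a helper might look like this (the name `assertNoEmptyFiles` is hypothetical and not part of the PR):

```scala
import java.io.File

// Fails if any data file (ignoring hidden, underscore-prefixed, and .crc files) is empty.
def assertNoEmptyFiles(outputDir: File): Unit = {
  val dataFiles = outputDir.listFiles().filter { f =>
    f.isFile &&
      !f.getName.startsWith(".") &&
      !f.getName.startsWith("_") &&
      !f.getName.endsWith(".crc")
  }
  val emptyFiles = dataFiles.filter(_.length == 0)
  assert(emptyFiles.isEmpty,
    s"Found empty output files: ${emptyFiles.map(_.getName).mkString(", ")}")
}
```

Such a helper would replace the `overwrittenFiles === overwrittenFilesWithoutEmpty` comparison in each test with a single call, e.g. `assertNoEmptyFiles(tmpDir)`.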

// This class is used to test SPARK-8578. We should not use any custom output committer when