@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.{IntegerType, StructType, StringType}
import org.apache.spark.sql.types.{StructField, IntegerType, StructType, StringType}
import org.apache.spark.util.SerializableConfiguration


@@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer(
}
}

private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
val bucketIdIndex = partitionColumns.length
if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
false
} else {
var i = partitionColumns.length - 1
while (i >= 0) {
val dt = partitionColumns(i).dataType
if (key1.get(i, dt) != key2.get(i, dt)) return false
i -= 1
}
true
}
}

private def sortBasedWrite(
sorter: UnsafeKVExternalSorter,
iterator: Iterator[InternalRow],
getSortingKey: UnsafeProjection,
getOutputRow: UnsafeProjection,
getPartitionString: UnsafeProjection,
outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}

logInfo(s"Sorting complete. Writing out partition files one at a time.")

val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
(key1, key2) => key1 != key2
} else {
(key1, key2) => key1 == null || !sameBucket(key1, key2)
}

val sortedIterator = sorter.sortedIterator()
var currentKey: UnsafeRow = null
var currentWriter: OutputWriter = null
try {
while (sortedIterator.next()) {
if (needNewWriter(currentKey, sortedIterator.getKey)) {
if (currentWriter != null) {
currentWriter.close()
}
currentKey = sortedIterator.getKey.copy()
logDebug(s"Writing partition: $currentKey")

// Either use an existing file from before, or open a new one.
currentWriter = outputWriters.remove(currentKey)
if (currentWriter == null) {
currentWriter = newOutputWriter(currentKey, getPartitionString)
}
}

currentWriter.writeInternal(sortedIterator.getValue)
}
} finally {
if (currentWriter != null) { currentWriter.close() }
}
}

/**
* Open and returns a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +374,18 @@
}

def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
executorSideSetup(taskContext)

var outputWritersCleared = false

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val getSortingKey =
UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)

val sortingKeySchema = if (bucketSpec.isEmpty) {
StructType.fromAttributes(partitionColumns)
} else { // If it's bucketed, we should also consider bucket id as part of the key.
val fields = StructType.fromAttributes(partitionColumns)
.add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
StructType(fields)
}
val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns

val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)

val sortingKeySchema = StructType(sortingExpressions.map {
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
// The sorting expressions are all `Attribute` except bucket id.
case _ => StructField("bucketId", IntegerType, nullable = false)
})

// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
@@ -461,54 +396,49 @@

// If anything below fails, we should abort the task.
try {
// If there is no sorting columns, we set sorter to null and try the hash-based writing first,
// and fill the sorter if there are too many writers and we need to fall back on sorting.
// If there are sorting columns, then we have to sort the data anyway, and no need to try the
// hash-based writing first.
var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)
// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)

while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}

logInfo(s"Sorting complete. Writing out partition files one at a time.")

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
null
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}
while (iterator.hasNext && sorter == null) {
val inputRow = iterator.next()
// When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
val currentKey = getSortingKey(inputRow)
var currentWriter = outputWriters.get(currentKey)

if (currentWriter == null) {
if (outputWriters.size < maxOpenFiles) {

val sortedIterator = sorter.sortedIterator()
var currentKey: UnsafeRow = null
var currentWriter: OutputWriter = null
try {
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
}
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")

currentWriter = newOutputWriter(currentKey, getPartitionString)
outputWriters.put(currentKey.copy(), currentWriter)
currentWriter.writeInternal(getOutputRow(inputRow))
} else {
logInfo(s"Maximum partitions reached, falling back on sorting.")
sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)
sorter.insertKV(currentKey, getOutputRow(inputRow))
}
} else {
currentWriter.writeInternal(getOutputRow(inputRow))
}
}

// If the sorter is not null that means that we reached the maxFiles above and need to finish
// using external sort, or there are sorting columns and we need to sort the whole data set.
if (sorter != null) {
sortBasedWrite(
sorter,
iterator,
getSortingKey,
getOutputRow,
getPartitionString,
outputWriters)
currentWriter.writeInternal(sortedIterator.getValue)
}
} finally {
if (currentWriter != null) { currentWriter.close() }
}

commitTask()
@@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer(
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}

def clearOutputWriters(): Unit = {
if (!outputWritersCleared) {
outputWriters.asScala.values.foreach(_.close())
outputWriters.clear()
outputWritersCleared = true
}
}

def commitTask(): Unit = {
try {
clearOutputWriters()
super.commitTask()
} catch {
case cause: Throwable =>
throw new RuntimeException("Failed to commit task", cause)
}
}

def abortTask(): Unit = {
try {
clearOutputWriters()
} finally {
super.abortTask()
}
}
}
}
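
The core of the new writeRows path above: project each input row to a sorting key built from partitionColumns ++ bucketIdExpression ++ sortColumns, feed every (key, data) pair into an UnsafeKVExternalSorter, then stream the sorted iterator and open a new OutputWriter only when the partition/bucket prefix of the key changes, so each task keeps at most one file open. Below is a minimal, self-contained sketch of that pattern; the in-memory sort, the Writer trait, and the openWriter parameter are stand-ins for the external sorter and OutputWriter machinery, not Spark APIs.

// Minimal sketch of the sort-then-write pattern in the new writeRows (assumptions:
// an in-memory sortBy stands in for UnsafeKVExternalSorter, and the Writer trait /
// openWriter parameter stand in for OutputWriter / newOutputWriter; these names
// are illustrative, not Spark APIs).
object SortBasedWriteSketch {

  // Full sorting key: partition directory, bucket id, then a sort-column value.
  final case class SortKey(partitionPath: String, bucketId: Int, sortValue: String)

  trait Writer {
    def write(row: String): Unit
    def close(): Unit
  }

  def writeRows(rows: Seq[(SortKey, String)], openWriter: (String, Int) => Writer): Unit = {
    // 1. Sort everything by (partition, bucket id, sort columns); the real code
    //    does this with an external sorter so it can spill to disk.
    val sorted = rows.sortBy { case (key, _) => (key.partitionPath, key.bucketId, key.sortValue) }

    // 2. Stream the sorted rows; a new writer is opened only when the
    //    (partition, bucket id) prefix changes, so at most one file is open.
    var currentKey: (String, Int) = null
    var currentWriter: Writer = null
    try {
      sorted.foreach { case (key, row) =>
        val nextKey = (key.partitionPath, key.bucketId)
        if (nextKey != currentKey) {
          if (currentWriter != null) currentWriter.close()
          currentKey = nextKey
          currentWriter = openWriter(key.partitionPath, key.bucketId)
        }
        currentWriter.write(row)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }
}

The prefix comparison plays the role of getBucketingKey in the real code: when sort columns are present, only the leading partition and bucket-id fields of the sorting key decide whether a new file must be opened.
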
The remaining hunks touch the data source provider interfaces in org.apache.spark.sql.sources:
@@ -162,7 +162,6 @@ trait HadoopFsRelationProvider {
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation

// TODO: expose bucket API to users.
private[sql] def createRelation(
sqlContext: SQLContext,
paths: Array[String],
@@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable {
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter

// TODO: expose bucket API to users.
private[sql] def newInstance(
path: String,
bucketId: Option[Int],
Expand Down Expand Up @@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql](

private var _partitionSpec: PartitionSpec = _

// TODO: expose bucket API to users.
private[sql] def bucketSpec: Option[BucketSpec] = None

private class FileStatusCache {
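
The interface hunks above make the bucket API explicit: OutputWriterFactory.newInstance takes a bucketId: Option[Int] and HadoopFsRelation exposes a bucketSpec. The point of a bucket id is that rows agreeing on the bucketing columns always land in the same bucket-numbered file within a partition. The sketch below shows the general shape of such a derivation, a non-negative hash of the bucketing-column values modulo the bucket count; the hashCode used here is only a stand-in, not the actual expression behind bucketIdExpression.

// Illustrative only: derive a stable bucket id from bucketing-column values.
// Seq.hashCode is a stand-in hash; the real bucketIdExpression uses Spark's own
// hash expression, which this diff does not show.
object BucketIdSketch {

  def bucketId(bucketColumnValues: Seq[Any], numBuckets: Int): Int = {
    require(numBuckets > 0, "numBuckets must be positive")
    val h = bucketColumnValues.hashCode()
    // Map the (possibly negative) hash into [0, numBuckets).
    ((h % numBuckets) + numBuckets) % numBuckets
  }

  def main(args: Array[String]): Unit = {
    // Rows that agree on the bucketing columns always map to the same bucket id,
    // and therefore to the same bucket-suffixed file within each partition directory.
    println(bucketId(Seq("user_42", 2016), numBuckets = 8))
  }
}

That id is what writeRows places between the partition columns and the sort columns in its sorting key, and what newOutputWriter appends to the output file name.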