@@ -1639,6 +1639,12 @@ object SQLConf {
"java.time.* packages are used for the same purpose.")
.booleanConf
.createWithDefault(false)

val MAX_HASH_BASED_OUTPUT_WRITERS = buildConf("spark.sql.maxHashBasedOutputWriters")
.doc("Maximum number of output writers when doing hash-based write. " +
"If writers exceeding this limit, executor will fall back to sort-based write.")
.intConf
.createWithDefault(200)
}
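
For context on how this new knob would be used: a minimal sketch, assuming a plain local SparkSession (the application name and the value 100 below are illustrative, not part of this change):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("tune-max-hash-based-output-writers") // hypothetical app name
      .master("local[*]")
      .getOrCreate()

    // Lower the limit so that executors fall back to sort-based write sooner,
    // trading a sort of the remaining rows for fewer concurrently open writers.
    spark.conf.set("spark.sql.maxHashBasedOutputWriters", 100)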

/**
@@ -2066,6 +2072,8 @@ class SQLConf extends Serializable with Logging {

def legacyTimeParserEnabled: Boolean = getConf(SQLConf.LEGACY_TIME_PARSER_ENABLED)

def maxHashBasedOutputWriters: Int = getConf(SQLConf.MAX_HASH_BASED_OUTPUT_WRITERS)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources

import java.io.FileNotFoundException

import scala.collection.mutable

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

@@ -54,7 +56,7 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
private[this] var numBytes: Long = 0L
private[this] var numRows: Long = 0L

private[this] var curFile: Option[String] = None
private[this] val allFiles = mutable.HashSet[String]()

/**
* Get the size of the file expected to have been written by a worker.
@@ -84,27 +86,26 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
}

override def newFile(filePath: String): Unit = {
statCurrentFile()
curFile = Some(filePath)
allFiles.add(filePath)
submittedFiles += 1
}

private def statCurrentFile(): Unit = {
curFile.foreach { path =>
private def statAllFiles(): Unit = {
allFiles.foreach { path =>
getFileSize(path).foreach { len =>
numBytes += len
numFiles += 1
}
curFile = None
}
allFiles.clear()
}

override def newRow(row: InternalRow): Unit = {
numRows += 1
}

override def getFinalStats(): WriteTaskStats = {
statCurrentFile()
statAllFiles()

// Reports bytesWritten and recordsWritten to the Spark output metrics.
Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics =>
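Because hash-based write keeps several output files open per task, the tracker above collects every path reported by newFile() into a set and resolves sizes once at the end, instead of stat-ing a single current file. A standalone sketch of that pattern against the Hadoop FileSystem API (the class name MultiFileSizeTracker is made up for illustration and is not part of this PR):

    import java.io.FileNotFoundException

    import scala.collection.mutable

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path

    class MultiFileSizeTracker(hadoopConf: Configuration) {
      // All file paths reported so far; with hash-based write there is no single
      // "current" file, so every path is remembered until the final stats pass.
      private val allFiles = mutable.HashSet[String]()

      def newFile(filePath: String): Unit = allFiles.add(filePath)

      // Sum the on-disk sizes of all reported files in one pass.
      def totalBytes(): Long = {
        var total = 0L
        allFiles.foreach { p =>
          val path = new Path(p)
          try {
            total += path.getFileSystem(hadoopConf).getFileStatus(path).getLen
          } catch {
            case _: FileNotFoundException => // e.g. an empty file that was never materialized
          }
        }
        total
      }
    }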
@@ -33,6 +33,11 @@ import org.apache.spark.util.SerializableConfiguration
/**
* Abstract class for writing out data in a single Spark task.
* Exceptions thrown by the implementation of this trait will automatically trigger task aborts.
*
* It first uses hash-based write to process input rows, i.e., it keeps a mapping from each
* partition/bucket to its opened output writer. When the number of output writers exceeds the
* threshold, it falls back to sort-based write, i.e., the remaining input rows are sorted so
* that each output writer can be closed on the fly, once all rows for its partition/bucket
* have been processed.
*/
abstract class FileFormatDataWriter(
description: WriteJobDescription,
@@ -51,6 +56,12 @@ abstract class FileFormatDataWriter(
protected val statsTrackers: Seq[WriteTaskStatsTracker] =
description.statsTrackers.map(_.newTaskInstance())

/**
* Indicates whether we are using sort-based write.
* Because hash-based write is tried first, the initial value is false.
*/
protected var sortBased: Boolean = false

protected def releaseResources(): Unit = {
if (currentWriter != null) {
try {
@@ -64,6 +75,14 @@
/** Writes a record */
def write(record: InternalRow): Unit

/** Returns the number of currently open output writers. */
def getNumOfOutputWriters: Int

/** Switch to sort-based write when the hash-based approach is opening too many writers. */
def switchToSortBasedWrite(): Unit = {
sortBased = true
}
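
These two hooks are what lets a caller drive the hash-based-first, sort-based-fallback flow described in the FileFormatDataWriter class doc above. A hypothetical driver loop, assuming the surrounding file's imports plus org.apache.spark.sql.catalyst.InternalRow, and with sortRemaining standing in for whatever external sorter orders the remaining rows by partition/bucket (none of these helper names are from this PR):

    // Illustrative only; this is not the executor-side code added by this change.
    def writeWithFallback(
        rows: Iterator[InternalRow],
        writer: FileFormatDataWriter,
        maxWriters: Int,
        sortRemaining: Iterator[InternalRow] => Iterator[InternalRow]): Unit = {
      // Phase 1: hash-based write, until too many writers have been opened.
      while (rows.hasNext && writer.getNumOfOutputWriters <= maxWriters) {
        writer.write(rows.next())
      }
      if (rows.hasNext) {
        // Phase 2: sort the remaining rows by partition/bucket so that each writer
        // can be closed as soon as its last row has been written.
        writer.switchToSortBasedWrite()
        sortRemaining(rows).foreach(writer.write)
      }
    }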

/**
* Returns the summary of relative information which
* includes the list of partition strings written out. The list of partitions is sent back
@@ -94,6 +113,8 @@ class EmptyDirectoryDataWriter(
committer: FileCommitProtocol
) extends FileFormatDataWriter(description, taskAttemptContext, committer) {
override def write(record: InternalRow): Unit = {}

override def getNumOfOutputWriters: Int = 0
}

/** Writes data to a single directory (used for non-dynamic-partition writes). */
@@ -138,6 +159,8 @@ class SingleDirectoryDataWriter(
statsTrackers.foreach(_.newRow(record))
recordsInFile += 1
}

override def getNumOfOutputWriters: Int = if (currentWriter != null) 1 else 0
}

/**
@@ -156,16 +179,17 @@ class DynamicPartitionDataWriter(
/** Flag saying whether or not the data to be written out is bucketed. */
private val isBucketed = description.bucketIdExpression.isDefined

/** Mapping between partition/bucket and its output writer. */
private val writerMap = mutable.HashMap[WriterIndex, WriterInfo]()

assert(isPartitioned || isBucketed,
s"""DynamicPartitionWriteTask should be used for writing out data that's either
|partitioned or bucketed. In this case neither is true.
|WriteJobDescription: $description
""".stripMargin)

private var fileCounter: Int = _
private var recordsInFile: Long = _
private var currentPartionValues: Option[UnsafeRow] = None
private var currentBucketId: Option[Int] = None
private var currentWriterIndex: WriterIndex = WriterIndex(None, None)
private var currentWriterInfo: WriterInfo = WriterInfo(null, 0, 0)

/** Extracts the partition values out of an input row. */
private lazy val getPartitionValues: InternalRow => UnsafeRow = {
@@ -211,10 +235,10 @@
* belong to
* @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
*/
private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = {
recordsInFile = 0
releaseResources()

private def newOutputWriter(
partitionValues: Option[UnsafeRow],
bucketId: Option[Int],
fileCounter: Int): OutputWriter = {
val partDir = partitionValues.map(getPartitionPath(_))
partDir.foreach(updatedPartitions.add)

@@ -233,46 +257,98 @@
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}

currentWriter = description.outputWriterFactory.newInstance(
statsTrackers.foreach(_.newFile(currentPath))

description.outputWriterFactory.newInstance(
path = currentPath,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
}

statsTrackers.foreach(_.newFile(currentPath))
private def releaseCurrentWriterResource(): Unit = {
if (currentWriterInfo.writer != null) {
try {
currentWriterInfo.writer.close()
} finally {
currentWriterInfo.writer = null
}
writerMap -= currentWriterIndex
}
}

override def releaseResources(): Unit = {
writerMap.values.foreach(writerInfo => {
if (writerInfo.writer != null) {
try {
writerInfo.writer.close()
} finally {
writerInfo.writer = null
}
}
})
}

override def write(record: InternalRow): Unit = {
val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None

if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
// See a new partition or bucket - write to a new partition dir (or a new bucket file).
if (isPartitioned && currentPartionValues != nextPartitionValues) {
currentPartionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartionValues.get))
}
if (isBucketed) {
currentBucketId = nextBucketId
statsTrackers.foreach(_.newBucket(currentBucketId.get))
if (currentWriterIndex.partitionValues != nextPartitionValues ||
currentWriterIndex.bucketId != nextBucketId) {
if (sortBased) {
// In sort-based write the current output writer can be closed now,
// because no more rows will be written to it.
releaseCurrentWriterResource()
}

fileCounter = 0
newOutputWriter(currentPartionValues, currentBucketId)
val nextWriterIndex = WriterIndex(nextPartitionValues.map(_.copy()), nextBucketId)
if (writerMap.contains(nextWriterIndex)) {
// Re-use the existing output writer.
currentWriterInfo = writerMap(nextWriterIndex)
} else {
// See a new partition or bucket - write to a new partition dir (or a new bucket file):
// create a new output writer and add a mapping from the partition/bucket to it.
if (isPartitioned &&
currentWriterIndex.partitionValues != nextWriterIndex.partitionValues) {
statsTrackers.foreach(_.newPartition(nextWriterIndex.partitionValues.get))
}
if (isBucketed && currentWriterIndex.bucketId != nextWriterIndex.bucketId) {
statsTrackers.foreach(_.newBucket(nextWriterIndex.bucketId.get))
}
val newWriter = newOutputWriter(
nextWriterIndex.partitionValues,
nextWriterIndex.bucketId,
0)
currentWriterInfo = WriterInfo(newWriter, 0, 0)
writerMap(nextWriterIndex) = currentWriterInfo
}
currentWriterIndex = nextWriterIndex
} else if (description.maxRecordsPerFile > 0 &&
recordsInFile >= description.maxRecordsPerFile) {
currentWriterInfo.recordsInFile >= description.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

newOutputWriter(currentPartionValues, currentBucketId)
// Create a new file and output writer by increasing the file counter.
currentWriterInfo.fileCounter += 1
assert(currentWriterInfo.fileCounter < MAX_FILE_COUNTER,
s"File counter $currentWriterInfo.fileCounter is beyond max value $MAX_FILE_COUNTER")

currentWriterInfo.writer = newOutputWriter(
currentWriterIndex.partitionValues,
currentWriterIndex.bucketId,
currentWriterInfo.fileCounter)
}
val outputRow = getOutputRow(record)
currentWriter.write(outputRow)
currentWriterInfo.writer.write(outputRow)
statsTrackers.foreach(_.newRow(outputRow))
recordsInFile += 1
currentWriterInfo.recordsInFile += 1
}

override def getNumOfOutputWriters: Int = writerMap.size

/** Wrapper class for the partition values and bucket id used to index an output writer. */
private case class WriterIndex(partitionValues: Option[UnsafeRow], bucketId: Option[Int])

/** Wrapper class for output writer bookkeeping information. */
private case class WriterInfo(
var writer: OutputWriter, var recordsInFile: Long, var fileCounter: Int)
}
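
One detail worth calling out in the writerMap keying above: WriterIndex stores a copy of the partition values because the incoming UnsafeRow buffer is reused across input rows, and the map relies on structural equality to re-use writers during the hash-based phase. A tiny self-contained sketch of the same keying idea with plain types (Key and Info are hypothetical names, not from this PR):

    import scala.collection.mutable

    // Key open writers by (partition, bucket); values carry mutable per-writer bookkeeping.
    case class Key(partitionValues: Option[String], bucketId: Option[Int])
    final class Info(var recordsInFile: Long = 0L, var fileCounter: Int = 0)

    val writers = mutable.HashMap[Key, Info]()

    // Hash-based phase: the same partition/bucket re-uses the same entry.
    def recordRow(partition: Option[String], bucket: Option[Int]): Unit = {
      val info = writers.getOrElseUpdate(Key(partition, bucket), new Info)
      info.recordsInFile += 1
    }

    recordRow(Some("date=2020-01-01"), None)
    recordRow(Some("date=2020-01-01"), None) // same key, so the same Info instance is updated
    assert(writers.size == 1 && writers(Key(Some("date=2020-01-01"), None)).recordsInFile == 2L)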

/** A shared job description for all the write tasks. */
@@ -288,7 +364,10 @@ class WriteJobDescription(
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long,
val timeZoneId: String,
val statsTrackers: Seq[WriteJobStatsTracker])
val statsTrackers: Seq[WriteJobStatsTracker],
val maxHashBasedOutputWriters: Int,
val enableRadixSort: Boolean,
val sortOrderWithPartitionsAndBuckets: Seq[SortOrder])
extends Serializable {

assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),