@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter


/** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
val serializableHadoopConf: SerializableConfiguration,
val outputWriterFactory: OutputWriterFactory,
val allColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
val bucketSpec: Option[BucketSpec],
val partitionColumns: Seq[Attribute],
val bucketIdExpression: Option[Expression],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

val allColumns = queryExecution.logical.output
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
Member: If we rewrite it to val dataColumns = allColumns.filterNot(partitionColumns.contains), we do not need partitionSet.

Contributor Author: It's so minor, I'll fix it in my next PR.

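To make the suggestion above concrete, here is a minimal sketch of the two forms side by side, with stand-in attributes (the real allColumns and partitionColumns are the values defined in write() above):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.types.{IntegerType, StringType}

// Stand-ins for the attributes in scope: two data columns plus one partition column.
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", StringType)()
val p = AttributeReference("p", StringType)()
val allColumns = Seq(a, b, p)
val partitionColumns = Seq(p)

// Form used in this PR: build an AttributeSet and filter against it.
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = allColumns.filterNot(partitionSet.contains)         // Seq(a, b)

// Reviewer's suggested simplification (deferred to a follow-up PR): no intermediate set.
val dataColumnsAlt = allColumns.filterNot(partitionColumns.contains)  // Seq(a, b)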

val bucketIdExpression = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee that the data distribution is the same between shuffle and bucketed data source,
// which enables us to only shuffle one side when joining a bucketed table with a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}
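
For orientation, a hypothetical write that exercises all three pieces assembled here: partitionBy supplies partitionColumns, while bucketBy/sortBy produce the BucketSpec whose bucket and sort column names are resolved just above. Column and table names are placeholders, and an active SparkSession named spark is assumed:

spark.range(100)
  .selectExpr("id", "id % 10 AS p", "id % 7 AS b", "id % 3 AS s")
  .write
  .partitionBy("p")      // becomes partitionColumns
  .bucketBy(4, "b")      // numBuckets and bucketColumnNames -> bucketIdExpression
  .sortBy("s")           // sortColumnNames -> sortColumns
  .saveAsTable("t")      // placeholder table name; bucketing requires saveAsTable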

// Note: prepareWrite has a side effect: it sets "job".
val outputWriterFactory =
fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
uuid = UUID.randomUUID().toString,
serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
outputWriterFactory = outputWriterFactory,
allColumns = queryExecution.logical.output,
partitionColumns = partitionColumns,
allColumns = allColumns,
dataColumns = dataColumns,
bucketSpec = bucketSpec,
partitionColumns = partitionColumns,
bucketIdExpression = bucketIdExpression,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
)
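
The maxRecordsPerFile value resolved above can be supplied per write through the options map, or fall back to the session-wide SQL conf. A sketch assuming a DataFrame df, a SparkSession spark, and a placeholder output path:

// Per-write override: this is the "maxRecordsPerFile" key read from `options` above.
df.write
  .option("maxRecordsPerFile", 1000000L)
  .parquet("/tmp/out")   // placeholder path

// Session-wide fallback, i.e. sparkSession.sessionState.conf.maxRecordsPerFile:
spark.conf.set("spark.sql.files.maxRecordsPerFile", 500000L)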

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// The sort direction doesn't matter here; only the ordering expressions are compared.
val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
Reviewer: @cloud-fan would it be possible to use the logical plan rather than the executedPlan? If the optimizer decides the data is already sorted according to the logical plan, the executedPlan won't include the fields.

Contributor Author: That would be great, but may need some refactoring.

val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
requiredOrdering.zip(actualOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
Member: Because bucketIdExpression is HashPartitioning, this will never match, right?

Contributor Author: It's HashPartitioning(...).partitionIdExpression, which returns Pmod(new Murmur3Hash(expressions), Literal(numPartitions)), so it may match.

}
}
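
To illustrate the exchange above: partitionIdExpression is an ordinary Catalyst expression, so a semantically equal expression in the child plan's outputOrdering makes this check pass and the extra sort below is skipped. A minimal sketch with a stand-in bucket column and bucket count:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, Murmur3Hash, Pmod}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

val bucketCol = AttributeReference("b", IntegerType)()   // stand-in bucket column
val numBuckets = 4

// What the writer asks for (see bucketIdExpression above) ...
val required = HashPartitioning(Seq(bucketCol), numBuckets).partitionIdExpression
// ... is just pmod(murmur3hash(b), 4), so an equivalent expression coming from the child's
// outputOrdering makes the semanticEquals check succeed.
val fromChild = Pmod(new Murmur3Hash(Seq(bucketCol)), Literal(numBuckets))
assert(required.semanticEquals(fromChild))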

SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job; any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

try {
val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
val rdd = if (orderingMatched) {
queryExecution.toRdd
} else {
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
global = false,
child = queryExecution.executedPlan).execute()
}

val ret = sparkSession.sparkContext.runJob(rdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
committer.setupTask(taskAttemptContext)

val writeTask =
if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
} else {
new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
* multiple directories (partitions) or files (bucketing).
*/
private class DynamicPartitionWriteTask(
description: WriteJobDescription,
desc: WriteJobDescription,
Member: SingleDirectoryWriteTask is still using description. Change both or keep it unchanged?

Contributor Author: I'd like to change both to make it consistent.

taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends ExecuteWriteTask {

// currentWriter is initialized whenever we see a new key
private var currentWriter: OutputWriter = _

private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
}

private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
}

private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
}

/** Expressions that given a partition key build a string like: col1=val/col2=val/... */
private def partitionStringExpression: Seq[Expression] = {
description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
/** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
private def partitionPathExpression: Seq[Expression] = {
desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
// TODO: use correct timezone for partition values.
val escaped = ScalaUDF(
ExternalCatalogUtils.escapePathName _,
@@ -325,35 +343,46 @@
}
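
A plain-Scala illustration (not Catalyst expressions, and with a simplified stand-in for ExternalCatalogUtils.escapePathName) of the path fragment that the Concat of these expressions produces:

// Simplified stand-in for ExternalCatalogUtils.escapePathName: real escaping covers more
// characters; '/' is enough to show why escaping is needed inside a directory name.
def escapePathName(s: String): String = s.replace("/", "%2F")

val partitionValues = Seq("year" -> "2017", "event" -> "a/b")   // hypothetical columns/values
val partDir = partitionValues
  .map { case (col, value) => s"$col=${escapePathName(value)}" }
  .mkString("/")
// partDir == "year=2017/event=a%2Fb"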

/**
* Open and returns a new OutputWriter given a partition key and optional bucket id.
* Opens a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
Contributor: Nit for the previous line, "Open and returns a ...": this method does not return anything.

* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
*
* @param key vaues for fields consisting of partition keys for the current row
* @param partString a function that projects the partition values into a string
* @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
* current row.
* @param getPartitionPath a function that projects the partition values into a path string.
* @param fileCounter the number of files that have been written in the past for this specific
* partition. This is used to limit the max number of records written for a
* single file. The value should start from 0.
* @param updatedPartitions the set of updated partition paths, we should add the new partition
* path of this writer to it.
*/
private def newOutputWriter(
key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
val partDir =
if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
partColsAndBucketId: InternalRow,
getPartitionPath: UnsafeProjection,
fileCounter: Int,
updatedPartitions: mutable.Set[String]): Unit = {
val partDir = if (desc.partitionColumns.isEmpty) {
None
} else {
Option(getPartitionPath(partColsAndBucketId).getString(0))
}
partDir.foreach(updatedPartitions.add)

// If the bucket spec is defined, the bucket column is right after the partition columns
val bucketId = if (description.bucketSpec.isDefined) {
BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
// If the bucketId expression is defined, the bucketId column is right after the partition
// columns.
val bucketId = if (desc.bucketIdExpression.isDefined) {
BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
} else {
""
}

// This must be in a form that matches our bucketing format. See BucketingUtils.
val ext = f"$bucketId.c$fileCounter%03d" +
description.outputWriterFactory.getFileExtension(taskAttemptContext)
desc.outputWriterFactory.getFileExtension(taskAttemptContext)

val customPath = partDir match {
case Some(dir) =>
description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
case _ =>
None
}
@@ -363,80 +392,42 @@
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}

currentWriter = description.outputWriterFactory.newInstance(
currentWriter = desc.outputWriterFactory.newInstance(
path = path,
dataSchema = description.dataColumns.toStructType,
dataSchema = desc.dataColumns.toStructType,
context = taskAttemptContext)
}
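
A quick trace of the extension string assembled in newOutputWriter, with illustrative values for the bucket id, file counter, and writer extension:

// BucketingUtils.bucketIdToString pads the bucket id, e.g. 2 -> "_00002"; shown here as a
// literal to keep the sketch standalone.
val bucketId = "_00002"
val fileCounter = 0
val writerExtension = ".gz.parquet"  // what outputWriterFactory.getFileExtension might return
val ext = f"$bucketId.c$fileCounter%03d" + writerExtension
// ext == "_00002.c000.gz.parquet"; the committer then prepends its part-file prefix
// (e.g. part-r-00009-<uuid>, as in the Scaladoc example above) in newTaskTempFile.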

override def execute(iter: Iterator[InternalRow]): Set[String] = {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val sortingExpressions: Seq[Expression] =
description.partitionColumns ++ bucketIdExpression ++ sortColumns
val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)

val sortingKeySchema = StructType(sortingExpressions.map {
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
// The sorting expressions are all `Attribute` except bucket id.
case _ => StructField("bucketId", IntegerType, nullable = false)
})

// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(
description.dataColumns, description.allColumns)

// Returns the partition path given a partition key.
val getPartitionStringFunc = UnsafeProjection.create(
Seq(Concat(partitionStringExpression)), description.partitionColumns)

// Sorts the data before write, so that we only need one writer at the same time.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(description.dataColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes,
SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))

while (iter.hasNext) {
val currentRow = iter.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
val getPartitionColsAndBucketId = UnsafeProjection.create(
desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}
// Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
val getPartPath = UnsafeProjection.create(
Seq(Concat(partitionPathExpression)), desc.partitionColumns)

val sortedIterator = sorter.sortedIterator()
// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)

// If anything below fails, we should abort the task.
var recordsInFile: Long = 0L
var fileCounter = 0
var currentKey: UnsafeRow = null
var currentPartColsAndBucketId: UnsafeRow = null
val updatedPartitions = mutable.Set[String]()
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
// See a new key - write to a new partition (new file).
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")
for (row <- iter) {
val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
Member: getPartitionColsAndBucketId is an unsafe projection, so nextPartColsAndBucketId is a new unsafe row. Do we still need a copy when assigning it to currentPartColsAndBucketId? Previously we needed a copy because getBucketingKey could be an identity function, so nextKey could be the same unsafe row.

Contributor Author: If you take a look at GenerateUnsafeProjection, it actually reuses the same row instance, so we need to copy.

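A sketch of the row-reuse behavior described above: a generated UnsafeProjection writes into a single reused UnsafeRow, so a reference held across loop iterations must be detached with copy() (the bound reference and input rows below are stand-ins):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
import org.apache.spark.sql.types.IntegerType

// One int column; project it with a generated UnsafeProjection.
val proj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, nullable = false)))
val row1 = InternalRow(1)
val row2 = InternalRow(2)

val first = proj(row1)            // the projection's single, reused output buffer
val second = proj(row2)           // same buffer, now overwritten with row2's value
assert(first eq second)           // same instance: `first` silently changed underneath us
val snapshot = proj(row1).copy()  // copy() detaches a stable row, as the loop below does
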
if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
// See a new partition or bucket - write to a new partition dir (or a new bucket file).
currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
logDebug(s"Writing partition: $currentPartColsAndBucketId")

recordsInFile = 0
fileCounter = 0

releaseResources()
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
val partitionPath = getPartitionStringFunc(currentKey).getString(0)
if (partitionPath.nonEmpty) {
updatedPartitions.add(partitionPath)
}
} else if (description.maxRecordsPerFile > 0 &&
recordsInFile >= description.maxRecordsPerFile) {
newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
} else if (desc.maxRecordsPerFile > 0 &&
recordsInFile >= desc.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
recordsInFile = 0
@@ -445,10 +436,10 @@
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

releaseResources()
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
}

currentWriter.write(sortedIterator.getValue)
currentWriter.write(getOutputRow(row))
recordsInFile += 1
}
releaseResources()