From d249e1bb9a2f687ce856ba0e427f2c7188fd6c0b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 12 Feb 2017 01:01:53 -0800 Subject: [PATCH 1/5] advoid unnecessary sort in FileFormatWriter --- .../datasources/FileFormatWriter.scala | 171 ++++++++---------- 1 file changed, 80 insertions(+), 91 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index be13cbc51a9d..048305438b14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter} -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. */ @@ -64,9 +63,10 @@ object FileFormatWriter extends Logging { val serializableHadoopConf: SerializableConfiguration, val outputWriterFactory: OutputWriterFactory, val allColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], val dataColumns: Seq[Attribute], - val bucketSpec: Option[BucketSpec], + val partitionColumns: Seq[Attribute], + val bucketColumns: Seq[Attribute], + val numBuckets: Int, val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], val maxRecordsPerFile: Long) @@ -108,8 +108,15 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) + val allColumns = queryExecution.logical.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) + val bucketColumns = bucketSpec.toSeq.flatMap { + spec => spec.bucketColumnNames.map(c => allColumns.find(_.name == c).get) + } + val sortColumns = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => allColumns.find(_.name == c).get) + } // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = @@ -120,9 +127,10 @@ object FileFormatWriter extends Logging { serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), outputWriterFactory = outputWriterFactory, allColumns = queryExecution.logical.output, - partitionColumns = partitionColumns, dataColumns = dataColumns, - bucketSpec = bucketSpec, + partitionColumns = partitionColumns, + bucketColumns = bucketColumns, + numBuckets = bucketSpec.map(_.numBuckets).getOrElse(0), path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong) @@ -134,8 +142,23 @@ object FileFormatWriter extends Logging { // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
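For readers unfamiliar with `BucketSpec`, the bucket and sort columns above are resolved purely by name lookup against the plan's output attributes. A small illustration with made-up values (the import path is the Spark 2.2-era one; this is not code from the patch):

```scala
import org.apache.spark.sql.catalyst.catalog.BucketSpec

// Hypothetical spec: 8 buckets on userId, each bucket file sorted by ts.
val spec = BucketSpec(
  numBuckets = 8,
  bucketColumnNames = Seq("userId"),
  sortColumnNames = Seq("ts"))

// The writer turns these *names* into attributes with lookups of the form
//   allColumns.find(_.name == "userId").get
// so a bucket or sort column that is not in the output fails fast with
// NoSuchElementException.
```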
committer.setupJob(job) + val bucketIdExpression = bucketSpec.map { spec => + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val requiredOrdering = (partitionColumns ++ bucketIdExpression ++ sortColumns) + .map(SortOrder(_, Ascending)) + val rdd = if (requiredOrdering == queryExecution.executedPlan.outputOrdering) { + queryExecution.toRdd + } else { + SortExec(requiredOrdering, global = false, queryExecution.executedPlan).execute() + } + try { - val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd, + val ret = sparkSession.sparkContext.runJob(rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -189,7 +212,7 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) val writeTask = - if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { + if (description.partitionColumns.isEmpty && description.numBuckets == 0) { new SingleDirectoryWriteTask(description, taskAttemptContext, committer) } else { new DynamicPartitionWriteTask(description, taskAttemptContext, committer) @@ -287,31 +310,25 @@ object FileFormatWriter extends Logging { * multiple directories (partitions) or files (bucketing). */ private class DynamicPartitionWriteTask( - description: WriteJobDescription, + desc: WriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends ExecuteWriteTask { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ - private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { - spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get) - } - - private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get) - } - - private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec => + private def bucketIdExpression: Option[Expression] = if (desc.numBuckets > 0) { // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can // guarantee the data distribution is same between shuffle and bucketed data source, which // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + Some(HashPartitioning(desc.bucketColumns, desc.numBuckets).partitionIdExpression) + } else { + None } - /** Expressions that given a partition key build a string like: col1=val/col2=val/... */ - private def partitionStringExpression: Seq[Expression] = { - description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ + private def partitionPathExpression: Seq[Expression] = { + desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => // TODO: use correct timezone for partition values. 
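The `HashPartitioning.partitionIdExpression` trick above is what keeps bucketed output files aligned with Spark's shuffle partitioning. A minimal sketch of what that expression boils down to, assuming it is still defined as a pmod of a Murmur3 hash over the partitioning expressions (illustration only, not code from this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Murmur3Hash, Pmod}

// Sketch: the bucket id of a row is essentially
// pmod(murmur3hash(bucketColumns), numBuckets), the same function the shuffle
// side uses to route a row to a reducer partition, so a bucketed table and a
// dataset shuffled on the bucket columns end up co-partitioned.
def bucketIdSketch(bucketColumns: Seq[Expression], numBuckets: Int): Expression =
  Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets))
```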
val escaped = ScalaUDF( ExternalCatalogUtils.escapePathName _, @@ -329,31 +346,41 @@ object FileFormatWriter extends Logging { * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * - * @param key vaues for fields consisting of partition keys for the current row - * @param partString a function that projects the partition values into a string + * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the + * current row. + * @param getPartitionPath a function that projects the partition values into a path string. * @param fileCounter the number of files that have been written in the past for this specific * partition. This is used to limit the max number of records written for a * single file. The value should start from 0. + * @param updatedPartitions the set of updated partition paths, we should add the new partition + * path of this writer to it. */ private def newOutputWriter( - key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = { - val partDir = - if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0)) + partColsAndBucketId: InternalRow, + getPartitionPath: UnsafeProjection, + fileCounter: Int, + updatedPartitions: mutable.Set[String]): Unit = { + val partDir = if (desc.partitionColumns.isEmpty) { + None + } else { + Option(getPartitionPath(partColsAndBucketId).getString(0)) + } + partDir.foreach(updatedPartitions.add) // If the bucket spec is defined, the bucket column is right after the partition columns - val bucketId = if (description.bucketSpec.isDefined) { - BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length)) + val bucketId = if (desc.numBuckets > 0) { + BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length)) } else { "" } // This must be in a form that matches our bucketing format. See BucketingUtils. val ext = f"$bucketId.c$fileCounter%03d" + - description.outputWriterFactory.getFileExtension(taskAttemptContext) + desc.outputWriterFactory.getFileExtension(taskAttemptContext) val customPath = partDir match { case Some(dir) => - description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) case _ => None } @@ -363,80 +390,42 @@ object FileFormatWriter extends Logging { committer.newTaskTempFile(taskAttemptContext, partDir, ext) } - currentWriter = description.outputWriterFactory.newInstance( + currentWriter = desc.outputWriterFactory.newInstance( path = path, - dataSchema = description.dataColumns.toStructType, + dataSchema = desc.dataColumns.toStructType, context = taskAttemptContext) } override def execute(iter: Iterator[InternalRow]): Set[String] = { - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val sortingExpressions: Seq[Expression] = - description.partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns) - - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. 
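The file-name suffix assembled in `newOutputWriter` is easier to see with a concrete value. A self-contained stand-in (not the real `BucketingUtils`; the zero-padding widths and the `.snappy.parquet` extension are assumptions for illustration):

```scala
// Stand-in for the suffix format used above: bucket id string, then ".c" plus a
// 3-digit file counter, then whatever extension the OutputWriterFactory reports.
def bucketFileSuffix(bucketId: Int, fileCounter: Int, formatExt: String): String =
  f"_$bucketId%05d.c$fileCounter%03d$formatExt"

bucketFileSuffix(2, 0, ".snappy.parquet")  // "_00002.c000.snappy.parquet"
```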
- case _ => StructField("bucketId", IntegerType, nullable = false) - }) + val getPartitionColsAndBucketId = UnsafeProjection.create( + desc.partitionColumns ++ bucketIdExpression, desc.allColumns) - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create( - description.dataColumns, description.allColumns) - - // Returns the partition path given a partition key. - val getPartitionStringFunc = UnsafeProjection.create( - Seq(Concat(partitionStringExpression)), description.partitionColumns) - - // Sorts the data before write, so that we only need one writer at the same time. - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(description.dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) + // Generates the partition path given the row generated by `getPartitionColsAndBucketId`. + val getPartPath = UnsafeProjection.create( + Seq(Concat(partitionPathExpression)), desc.partitionColumns) - while (iter.hasNext) { - val currentRow = iter.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) - }) - } - - val sortedIterator = sorter.sortedIterator() + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) // If anything below fails, we should abort the task. var recordsInFile: Long = 0L var fileCounter = 0 - var currentKey: UnsafeRow = null + var currentPartColsAndBucketId: UnsafeRow = null val updatedPartitions = mutable.Set[String]() - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - // See a new key - write to a new partition (new file). - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") + for (row <- iter) { + val nextPartColsAndBucketId = getPartitionColsAndBucketId(row) + if (currentPartColsAndBucketId != nextPartColsAndBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + currentPartColsAndBucketId = nextPartColsAndBucketId.copy() + logDebug(s"Writing partition: $currentPartColsAndBucketId") recordsInFile = 0 fileCounter = 0 releaseResources() - newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) - val partitionPath = getPartitionStringFunc(currentKey).getString(0) - if (partitionPath.nonEmpty) { - updatedPartitions.add(partitionPath) - } - } else if (description.maxRecordsPerFile > 0 && - recordsInFile >= description.maxRecordsPerFile) { + newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) + } else if (desc.maxRecordsPerFile > 0 && + recordsInFile >= desc.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. // Create a new file by increasing the file counter. 
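Because the input iterator is now pre-sorted by (partition columns, bucket id, sort columns), `execute()` can stream rows with at most one writer open instead of buffering everything in `UnsafeKVExternalSorter`. A plain-Scala model of that control flow (hypothetical helper names; it glosses over the file-counter bookkeeping):

```scala
// Open a new file only when the (partition, bucket) key changes or the
// per-file record limit is hit; otherwise keep appending to the current writer.
def writeSortedRows[K, R](
    rows: Iterator[(K, R)],
    maxRecordsPerFile: Long)(openNewFile: K => Unit, writeRow: R => Unit): Unit = {
  var currentKey: Option[K] = None
  var recordsInFile = 0L
  for ((key, row) <- rows) {
    val keyChanged = !currentKey.contains(key)
    if (keyChanged || (maxRecordsPerFile > 0 && recordsInFile >= maxRecordsPerFile)) {
      openNewFile(key)          // the real code also releases the previous writer
      currentKey = Some(key)
      recordsInFile = 0
    }
    writeRow(row)
    recordsInFile += 1
  }
}
```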
recordsInFile = 0 @@ -445,10 +434,10 @@ object FileFormatWriter extends Logging { s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") releaseResources() - newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) + newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) } - currentWriter.write(sortedIterator.getValue) + currentWriter.write(getOutputRow(row)) recordsInFile += 1 } releaseResources() From 66dbef00321e5358cc754e444fc4310689e1f876 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 12 Feb 2017 01:35:08 -0800 Subject: [PATCH 2/5] address comments --- .../spark/sql/execution/datasources/FileFormatWriter.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 048305438b14..6bd4f6283456 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -151,7 +151,10 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. val requiredOrdering = (partitionColumns ++ bucketIdExpression ++ sortColumns) .map(SortOrder(_, Ascending)) - val rdd = if (requiredOrdering == queryExecution.executedPlan.outputOrdering) { + val actualOrdering = queryExecution.executedPlan.outputOrdering + // We can still avoid the sort if the required ordering is [partCol] and the actual ordering + // is [partCol, anotherCol]. + val rdd = if (requiredOrdering == actualOrdering.take(requiredOrdering.length)) { queryExecution.toRdd } else { SortExec(requiredOrdering, global = false, queryExecution.executedPlan).execute() From fc591d18601e9c2c86d3df80b7ebbf43bc403d27 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Feb 2017 08:34:17 -0800 Subject: [PATCH 3/5] address comments --- .../datasources/FileFormatWriter.scala | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 6bd4f6283456..b94bd901dd13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -112,10 +112,10 @@ object FileFormatWriter extends Logging { val partitionSet = AttributeSet(partitionColumns) val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) val bucketColumns = bucketSpec.toSeq.flatMap { - spec => spec.bucketColumnNames.map(c => allColumns.find(_.name == c).get) + spec => spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) } val sortColumns = bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => allColumns.find(_.name == c).get) + spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) } // Note: prepareWrite has side effect. It sets "job". 
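The prefix comparison added in PATCH 2 is easiest to see with concrete values; a toy illustration in plain Scala (strings stand in for the real `SortOrder` expressions):

```scala
// Data already sorted by (partCol, otherCol) is also clustered by partCol alone,
// so a required ordering that is a prefix of the actual ordering needs no extra sort.
val requiredOrdering = Seq("partCol")
val actualOrdering   = Seq("partCol", "otherCol")
val canSkipSort = requiredOrdering == actualOrdering.take(requiredOrdering.length)  // true
```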
@@ -126,7 +126,7 @@ object FileFormatWriter extends Logging { uuid = UUID.randomUUID().toString, serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), outputWriterFactory = outputWriterFactory, - allColumns = queryExecution.logical.output, + allColumns = allColumns, dataColumns = dataColumns, partitionColumns = partitionColumns, bucketColumns = bucketColumns, @@ -149,15 +149,25 @@ object FileFormatWriter extends Logging { HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression } // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = (partitionColumns ++ bucketIdExpression ++ sortColumns) - .map(SortOrder(_, Ascending)) - val actualOrdering = queryExecution.executedPlan.outputOrdering - // We can still avoid the sort if the required ordering is [partCol] and the actual ordering - // is [partCol, anotherCol]. - val rdd = if (requiredOrdering == actualOrdering.take(requiredOrdering.length)) { + val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns + // the sort order doesn't matter + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + + val rdd = if (orderingMatched) { queryExecution.toRdd } else { - SortExec(requiredOrdering, global = false, queryExecution.executedPlan).execute() + SortExec( + requiredOrdering.map(SortOrder(_, Ascending)), + global = false, + child = queryExecution.executedPlan).execute() } try { @@ -345,7 +355,7 @@ object FileFormatWriter extends Logging { } /** - * Open and returns a new OutputWriter given a partition key and optional bucket id. + * Opens a new OutputWriter given a partition key and optional bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * From 728e1c8eef5d081b499f3e105a0dc13302d8e799 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Feb 2017 10:24:25 -0800 Subject: [PATCH 4/5] fix a bug --- .../datasources/FileFormatWriter.scala | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index b94bd901dd13..dd5da2a22857 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -137,40 +137,40 @@ object FileFormatWriter extends Logging { .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile) ) + val bucketIdExpression = bucketSpec.map { spec => + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } + // We should first sort by partition columns, then bucket id, and finally sorting columns. 
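The comment `// the sort order doesn't matter` is terse: the writer only needs rows with equal (partition, bucket) keys to be adjacent, so PATCH 3 compares the ordering expressions semantically and ignores ascending/descending and null ordering. The same check restated as a standalone sketch:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// True when the plan's output ordering expressions start with the required
// expressions, regardless of sort direction or null ordering.
def orderingSatisfied(required: Seq[Expression], actual: Seq[Expression]): Boolean =
  required.length <= actual.length &&
    required.zip(actual).forall { case (req, act) => req.semanticEquals(act) }
```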
+ val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns + // the sort order doesn't matter + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. committer.setupJob(job) - val bucketIdExpression = bucketSpec.map { spec => - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression - } - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns - // the sort order doesn't matter - val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) - val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { - false - } else { - requiredOrdering.zip(actualOrdering).forall { - case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) + try { + val rdd = if (orderingMatched) { + queryExecution.toRdd + } else { + SortExec( + requiredOrdering.map(SortOrder(_, Ascending)), + global = false, + child = queryExecution.executedPlan).execute() } - } - val rdd = if (orderingMatched) { - queryExecution.toRdd - } else { - SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), - global = false, - child = queryExecution.executedPlan).execute() - } - - try { val ret = sparkSession.sparkContext.runJob(rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( From 83053ef59d97abe1b9d05e0cf6184a00d261da25 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Feb 2017 10:30:03 -0800 Subject: [PATCH 5/5] simplify --- .../datasources/FileFormatWriter.scala | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index dd5da2a22857..644358493e2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -65,8 +65,7 @@ object FileFormatWriter extends Logging { val allColumns: Seq[Attribute], val dataColumns: Seq[Attribute], val partitionColumns: Seq[Attribute], - val bucketColumns: Seq[Attribute], - val numBuckets: Int, + val bucketIdExpression: Option[Expression], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], val maxRecordsPerFile: Long) @@ -111,8 +110,13 @@ object FileFormatWriter extends Logging { val allColumns = queryExecution.logical.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) - val bucketColumns = bucketSpec.toSeq.flatMap { 
- spec => spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + + val bucketIdExpression = bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression } val sortColumns = bucketSpec.toSeq.flatMap { spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) @@ -129,20 +133,13 @@ object FileFormatWriter extends Logging { allColumns = allColumns, dataColumns = dataColumns, partitionColumns = partitionColumns, - bucketColumns = bucketColumns, - numBuckets = bucketSpec.map(_.numBuckets).getOrElse(0), + bucketIdExpression = bucketIdExpression, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong) .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile) ) - val bucketIdExpression = bucketSpec.map { spec => - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression - } // We should first sort by partition columns, then bucket id, and finally sorting columns. val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter @@ -225,7 +222,7 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) val writeTask = - if (description.partitionColumns.isEmpty && description.numBuckets == 0) { + if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { new SingleDirectoryWriteTask(description, taskAttemptContext, committer) } else { new DynamicPartitionWriteTask(description, taskAttemptContext, committer) @@ -330,15 +327,6 @@ object FileFormatWriter extends Logging { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ - private def bucketIdExpression: Option[Expression] = if (desc.numBuckets > 0) { - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - Some(HashPartitioning(desc.bucketColumns, desc.numBuckets).partitionIdExpression) - } else { - None - } - /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => @@ -380,8 +368,9 @@ object FileFormatWriter extends Logging { } partDir.foreach(updatedPartitions.add) - // If the bucket spec is defined, the bucket column is right after the partition columns - val bucketId = if (desc.numBuckets > 0) { + // If the bucketId expression is defined, the bucketId column is right after the partition + // columns. 
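The layout assumption in the comment above ("the bucketId column is right after the partition columns") follows from how `getPartitionColsAndBucketId` is built. A small self-contained demonstration against the catalyst internals (all schema values are invented, and the literal `2` stands in for the real bucket id expression):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, UnsafeProjection}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val date   = AttributeReference("date", StringType)()     // partition column
val userId = AttributeReference("userId", IntegerType)()  // bucket column
val amount = AttributeReference("amount", LongType)()     // plain data column
val allColumns = Seq(date, userId, amount)

// Partition columns first, then the bucket id expression.
val proj = UnsafeProjection.create(Seq(date, Literal(2)), allColumns)
val out  = proj(InternalRow(UTF8String.fromString("2017-02-12"), 42, 100L))

// The bucket id sits at ordinal partitionColumns.length, which is why the writer
// reads it with partColsAndBucketId.getInt(desc.partitionColumns.length).
assert(out.getInt(1) == 2)
```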
+ val bucketId = if (desc.bucketIdExpression.isDefined) { BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length)) } else { "" @@ -411,7 +400,7 @@ object FileFormatWriter extends Logging { override def execute(iter: Iterator[InternalRow]): Set[String] = { val getPartitionColsAndBucketId = UnsafeProjection.create( - desc.partitionColumns ++ bucketIdExpression, desc.allColumns) + desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns) // Generates the partition path given the row generated by `getPartitionColsAndBucketId`. val getPartPath = UnsafeProjection.create(
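Finally, an end-to-end way to exercise the path this series optimizes, using only the public `DataFrameWriter` API (a sketch: `spark` is an existing `SparkSession`, and the table and column names are invented):

```scala
spark.range(0, 1000000)
  .selectExpr("id AS userId", "CAST(id % 7 AS STRING) AS date", "id * 2 AS amount")
  .write
  .partitionBy("date")
  .bucketBy(8, "userId")
  .sortBy("userId")
  .saveAsTable("events_bucketed")

// With these patches, FileFormatWriter adds a local SortExec over
// (date, bucket id, userId) only when the child plan's output ordering
// does not already satisfy it.
```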