@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter


 /** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
       val dataColumns: Seq[Attribute],
-      val bucketSpec: Option[BucketSpec],
+      val partitionColumns: Seq[Attribute],
+      val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

+    val allColumns = queryExecution.logical.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)

+    val bucketIdExpression = bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+    val sortColumns = bucketSpec.toSeq.flatMap {
+      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+    }
+
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
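For context on the bucket id computed above: in catalyst, `HashPartitioning.partitionIdExpression` expands to a positive modulus of a Murmur3 hash over the bucket columns, which is why the id written to disk matches the partition id a shuffle on the same columns would produce. A minimal sketch of that expansion, not part of this patch and assuming the catalyst expression API (`Pmod`, `Murmur3Hash`, `Literal`) of this Spark version:

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, Pmod}

// Assumed equivalent of HashPartitioning(bucketColumns, numBuckets).partitionIdExpression:
// pmod(murmur3hash(bucketColumns), numBuckets), evaluated per row at write time.
val numBuckets = 8
val bucketIdExpr = Pmod(new Murmur3Hash(Seq(Literal("some_key"))), Literal(numBuckets))

// For a literal input it can be evaluated directly; the real code binds it to each InternalRow.
println(bucketIdExpr.eval())  // a bucket id in [0, numBuckets)
```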
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
       uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
-      allColumns = queryExecution.logical.output,
-      partitionColumns = partitionColumns,
+      allColumns = allColumns,
       dataColumns = dataColumns,
-      bucketSpec = bucketSpec,
+      partitionColumns = partitionColumns,
+      bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
     )

+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    // the sort order doesn't matter
+    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(actualOrdering).forall {
+        case (requiredOrder, childOutputOrder) =>
+          requiredOrder.semanticEquals(childOutputOrder)
+      }
+    }
+
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
       // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
       committer.setupJob(job)

       try {
-        val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val rdd = if (orderingMatched) {
+          queryExecution.toRdd
+        } else {
+          SortExec(
+            requiredOrdering.map(SortOrder(_, Ascending)),
+            global = false,
+            child = queryExecution.executedPlan).execute()
+        }
+
+        val ret = sparkSession.sparkContext.runJob(rdd,
          (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
            executeTask(
              description = description,
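The `orderingMatched` check above only skips the extra `SortExec` when the child plan's output ordering already starts with the required (partition columns, bucket id, sort columns) ordering. A standalone restatement of that prefix check, using a hypothetical helper name that is not in the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// True iff `required` is a semantic prefix of `actual`; in that case the write
// path can reuse the plan's existing ordering instead of adding a local sort.
def orderingSatisfied(required: Seq[Expression], actual: Seq[Expression]): Boolean = {
  required.length <= actual.length &&
    required.zip(actual).forall { case (r, a) => r.semanticEquals(a) }
}
```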
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)

     val writeTask =
-      if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+      if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
        new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
      } else {
        new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
    * multiple directories (partitions) or files (bucketing).
    */
   private class DynamicPartitionWriteTask(
-      description: WriteJobDescription,
+      desc: WriteJobDescription,
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {

     // currentWriter is initialized whenever we see a new key
     private var currentWriter: OutputWriter = _

-    private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-
-    /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
-    private def partitionStringExpression: Seq[Expression] = {
-      description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+    /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+    private def partitionPathExpression: Seq[Expression] = {
+      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
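The `partitionPathExpression` being assembled above (it continues in the next hunk) produces a Hive-style path fragment such as `col1=val1/col2=val2`. A minimal plain-Scala sketch of the same idea, with a simplified escaping rule standing in for `ExternalCatalogUtils.escapePathName`; both the helper and its escape set are illustrative assumptions, not the patch's code:

```scala
// Sketch: build a partition path fragment from (column name, value) pairs.
// Real code escapes via ExternalCatalogUtils.escapePathName and also handles
// null/empty partition values specially; this stand-in only percent-encodes a few chars.
def partitionPath(cols: Seq[(String, String)]): String =
  cols.map { case (name, value) =>
    val escaped = value.flatMap {
      case c if "\"%'*/:=?\\".contains(c) => f"%%${c.toInt}%02X"
      case c => c.toString
    }
    s"$name=$escaped"
  }.mkString("/")

// partitionPath(Seq("year" -> "2017", "event" -> "a/b")) == "year=2017/event=a%2Fb"
```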
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
     }

     /**
-     * Open and returns a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param key vaues for fields consisting of partition keys for the current row
-     * @param partString a function that projects the partition values into a string
+     * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+     *                            current row.
+     * @param getPartitionPath a function that projects the partition values into a path string.
      * @param fileCounter the number of files that have been written in the past for this specific
      *                    partition. This is used to limit the max number of records written for a
      *                    single file. The value should start from 0.
+     * @param updatedPartitions the set of updated partition paths, we should add the new partition
+     *                          path of this writer to it.
      */
     private def newOutputWriter(
-        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
-      val partDir =
-        if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+        partColsAndBucketId: InternalRow,
+        getPartitionPath: UnsafeProjection,
+        fileCounter: Int,
+        updatedPartitions: mutable.Set[String]): Unit = {
+      val partDir = if (desc.partitionColumns.isEmpty) {
+        None
+      } else {
+        Option(getPartitionPath(partColsAndBucketId).getString(0))
+      }
+      partDir.foreach(updatedPartitions.add)

-      // If the bucket spec is defined, the bucket column is right after the partition columns
-      val bucketId = if (description.bucketSpec.isDefined) {
-        BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+      // If the bucketId expression is defined, the bucketId column is right after the partition
+      // columns.
+      val bucketId = if (desc.bucketIdExpression.isDefined) {
+        BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
       } else {
         ""
       }

       // This must be in a form that matches our bucketing format. See BucketingUtils.
       val ext = f"$bucketId.c$fileCounter%03d" +
-        description.outputWriterFactory.getFileExtension(taskAttemptContext)
+        desc.outputWriterFactory.getFileExtension(taskAttemptContext)

       val customPath = partDir match {
         case Some(dir) =>
-          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+          desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
         case _ =>
           None
       }
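To make the naming scheme concrete: the extension assembled above concatenates the bucket id string, a zero-padded per-partition file counter, and the format's own extension. A small sketch with assumed example values, taking `BucketingUtils.bucketIdToString` to pad the id to five digits (e.g. `_00002`) and a gzip Parquet writer for the extension:

```scala
// Sketch of the file-name suffix built in newOutputWriter (all values assumed).
val bucketId = "_00002"        // would come from BucketingUtils.bucketIdToString(2)
val fileCounter = 0            // first file for this partition/bucket
val formatExt = ".gz.parquet"  // would come from outputWriterFactory.getFileExtension(...)

val ext = f"$bucketId.c$fileCounter%03d" + formatExt
// ext == "_00002.c000.gz.parquet"; the committer prepends the task's part-file prefix.
```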
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }

-      currentWriter = description.outputWriterFactory.newInstance(
+      currentWriter = desc.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.dataColumns.toStructType,
+        dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
     }

     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(description.dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes,
-        SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
-      while (iter.hasNext) {
-        val currentRow = iter.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+      val getPartitionColsAndBucketId = UnsafeProjection.create(
+        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)

-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+      val getPartPath = UnsafeProjection.create(
+        Seq(Concat(partitionPathExpression)), desc.partitionColumns)

-      val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)

       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var currentKey: UnsafeRow = null
+      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          // See a new key - write to a new partition (new file).
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
+      for (row <- iter) {
+        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+          logDebug(s"Writing partition: $currentPartColsAndBucketId")

           recordsInFile = 0
           fileCounter = 0

           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
-          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
-          if (partitionPath.nonEmpty) {
-            updatedPartitions.add(partitionPath)
-          }
-        } else if (description.maxRecordsPerFile > 0 &&
-            recordsInFile >= description.maxRecordsPerFile) {
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+        } else if (desc.maxRecordsPerFile > 0 &&
+            recordsInFile >= desc.maxRecordsPerFile) {
          // Exceeded the threshold in terms of the number of records per file.
          // Create a new file by increasing the file counter.
          recordsInFile = 0
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
         }

-        currentWriter.write(sortedIterator.getValue)
+        currentWriter.write(getOutputRow(row))
         recordsInFile += 1
       }
       releaseResources()
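The net effect of these last hunks: the per-task `UnsafeKVExternalSorter` is gone, and the task now assumes its input iterator is already sorted by (partition columns, bucket id, sort columns), so a key change is the only signal needed to roll over to a new writer. A minimal plain-Scala sketch of that single-open-writer pattern, with hypothetical `Writer`/`open` stand-ins for `OutputWriter`/`newOutputWriter` (not Spark API):

```scala
// Sketch: write a sorted stream while keeping at most one writer open at a time.
trait Writer { def write(value: String): Unit; def close(): Unit }

def writeSorted(rows: Iterator[(String, String)], open: String => Writer): Unit = {
  var currentKey: String = null
  var writer: Writer = null
  for ((key, value) <- rows) {
    if (key != currentKey) {          // key change == new partition dir / bucket file
      currentKey = key
      if (writer != null) writer.close()
      writer = open(key)
    }
    writer.write(value)
  }
  if (writer != null) writer.close()
}
```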