diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ParallelUnionRDD.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ParallelUnionRDD.scala
new file mode 100644
index 0000000000000..dacc6fa546e17
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ParallelUnionRDD.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.util.concurrent.Callable
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, SparkContext}
+import org.apache.spark.rdd.{RDD, UnionPartition, UnionRDD}
+import org.apache.spark.util.ThreadUtils
+
+private[hive] object ParallelUnionRDD {
+  lazy val executorService = ThreadUtils.newDaemonFixedThreadPool(16, "ParallelUnionRDD")
+}
+
+private[hive] class ParallelUnionRDD[T: ClassTag](
+    sc: SparkContext,
+    rdds: Seq[RDD[T]]) extends UnionRDD[T](sc, rdds) {
+
+  override def getPartitions: Array[Partition] = {
+    // Calc partitions field for each RDD in parallel.
+    val rddPartitions = rdds.map { rdd =>
+      (rdd, ParallelUnionRDD.executorService.submit(new Callable[Array[Partition]] {
+        override def call(): Array[Partition] = rdd.partitions
+      }))
+    }.map { case (r, f) => (r, f.get()) }
+
+    val array = new Array[Partition](rddPartitions.map(_._2.length).sum)
+    var pos = 0
+    for (((rdd, partitions), rddIndex) <- rddPartitions.zipWithIndex; split <- partitions) {
+      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
+      pos += 1
+    }
+    array
+  }
+
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index fd465e80a87e5..8cb5dd4d9d09c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -246,7 +246,7 @@ class HadoopTableReader(
     if (hivePartitionRDDs.size == 0) {
       new EmptyRDD[InternalRow](sc.sparkContext)
     } else {
-      new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
+      new ParallelUnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
     }
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index b0c0dcbe5c25c..a67c22747efd0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -89,4 +89,24 @@ class HiveTableScanSuite extends HiveComparisonTest {
     assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
     assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
   }
+
test("Spark-11517: calc partitions in parallel") { + val partitionNum = 500 + val partitionTable = "combine" + sql("set hive.exec.dynamic.partition.mode=nonstrict") + val df = (1 to 500).map { i => (i, i)}.toDF("a", "b").coalesce(500) + df.registerTempTable("temp") + sql(s"""create table $partitionTable (a int, b string) + |partitioned by (c int) + |stored as orc""".stripMargin) + sql( + s"""insert into table $partitionTable partition(c) + |select a, b, (b % $partitionNum) as c from temp""".stripMargin) + + // Ensure that the result is the same as the original + assert( + sql( s"""select * from $partitionTable order by a""").collect().map(_.toString()).deep + == (1 to 500).map{i => s"[$i,$i,${i % partitionNum}]"}.toArray.deep + ) + } }