diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 805cd9fe1f63..d9f0f184f490 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -205,13 +205,22 @@ class HadoopRDD[K, V](
     array
   }
 
+  // Find a function that will return the FileSystem bytes read by this thread. Do this before
+  // creating RecordReader, because RecordReader's constructor might read some bytes
+  protected def getBytesReadCallback(split: HadoopPartition): Option[() => Long] = {
+    split.inputSplit.value match {
+      case _: FileSplit | _: CombineFileSplit =>
+        SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+      case _ => None
+    }
+  }
+
   override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
     val iter = new NextIterator[(K, V)] {
 
       val split = theSplit.asInstanceOf[HadoopPartition]
       logInfo("Input split: " + split.inputSplit)
       val jobConf = getJobConf()
 
-      // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD
 
       val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
@@ -223,24 +232,17 @@ class HadoopRDD[K, V](
         case _ => SqlNewHadoopRDDState.unsetInputFileName()
       }
 
-      // Find a function that will return the FileSystem bytes read by this thread. Do this before
-      // creating RecordReader, because RecordReader's constructor might read some bytes
-      val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
-        case _: FileSplit | _: CombineFileSplit =>
-          SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
-        case _ => None
-      }
+      val bytesReadCallback = getBytesReadCallback(split)
 
       // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
       // If we do a coalesce, however, we are likely to compute multiple partitions in the same
      // task and in the same thread, in which case we need to avoid override values written by
      // previous partitions (SPARK-13071).
      def updateBytesRead(): Unit = {
-        getBytesReadCallback.foreach { getBytesRead =>
+        bytesReadCallback.foreach { getBytesRead =>
           inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
         }
       }
-
       var reader: RecordReader[K, V] = null
       val inputFormat = getInputFormat(jobConf)
       HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
@@ -285,7 +287,7 @@ class HadoopRDD[K, V](
           } finally {
             reader = null
           }
-          if (getBytesReadCallback.isDefined) {
+          if (bytesReadCallback.isDefined) {
             updateBytesRead()
           } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
                      split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 61a7b9935af0..72c498a6e251 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -507,6 +507,10 @@ private[spark] object SQLConf {
       " method",
     isPublic = false)
 
+  val MAPPER_SPLIT_COMBINE_SIZE = intConf(
+    "spark.sql.mapper.splitCombineSize",
+    defaultValue = Some(-1),
+    isPublic = true)
 
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -574,6 +578,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
   private[spark] def subexpressionEliminationEnabled: Boolean =
     getConf(SUBEXPRESSION_ELIMINATION_ENABLED)
 
+  private[spark] def mapperSplitCombineSize: Int = getConf(MAPPER_SPLIT_COMBINE_SIZE)
+
   private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
   private[spark] def defaultSizeInBytes: Long =
diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java
new file mode 100644
index 000000000000..662f827b0ce1
--- /dev/null
+++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.mapred;
+
+import com.clearspring.analytics.util.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableFactories;
+import org.apache.hadoop.mapred.*;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.*;
+
+public class CombineSplitInputFormat<K, V> implements InputFormat<K, V> {
+
+  private InputFormat<K, V> delegate;
+  private long splitCombineSize = 0;
+
+  public CombineSplitInputFormat(InputFormat<K, V> inputformat, long sSize) {
+    this.delegate = inputformat;
+    this.splitCombineSize = sSize;
+  }
+
+  private CombineSplit createCombineSplit(
+      long totalLen,
+      Collection<String> locations,
+      List<InputSplit> combineSplitBuffer) {
+    return new CombineSplit(combineSplitBuffer.toArray(new InputSplit[0]),
+      totalLen, locations.toArray(new String[0]));
+  }
+
+  @Override
+  public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
+    InputSplit[] splits = delegate.getSplits(job, numSplits);
+    Map<String, List<InputSplit>> nodeToSplits = Maps.newHashMap();
+    Set<InputSplit> splitsSet = Sets.newHashSet();
+    for (InputSplit split: splits) {
+      for (String node: split.getLocations()) {
+        if (!nodeToSplits.containsKey(node)) {
+          nodeToSplits.put(node, new ArrayList<InputSplit>());
+        }
+        nodeToSplits.get(node).add(split);
+      }
+      splitsSet.add(split);
+    }
+    // Iterate over the nodes so that the combined splits are evenly distributed.
+    // Ideally, the splits within one combination should all live on the same node.
+    List<InputSplit> combineSparkSplits = Lists.newArrayList();
+    List<InputSplit> combinedSplitBuffer = Lists.newArrayList();
+    long accumulatedSplitSize = 0L;
+    for (Map.Entry<String, List<InputSplit>> entry: nodeToSplits.entrySet()) {
+      String node = entry.getKey();
+      List<InputSplit> splitsPerNode = entry.getValue();
+      for (InputSplit split: splitsPerNode) {
+        // skip this split if it has already been combined
+        if (!splitsSet.contains(split)) {
+          continue;
+        } else {
+          accumulatedSplitSize += split.getLength();
+          combinedSplitBuffer.add(split);
+          splitsSet.remove(split);
+        }
+        if (splitCombineSize > 0 && accumulatedSplitSize >= splitCombineSize) {
+          // TODO: optimize this by providing the second/third preference locations
+          combineSparkSplits.add(createCombineSplit(
+            accumulatedSplitSize, Collections.singleton(node), combinedSplitBuffer));
+          accumulatedSplitSize = 0;
+          combinedSplitBuffer.clear();
+        }
+      }
+      // populate the remaining splits into one combined split
+      if (!combinedSplitBuffer.isEmpty()) {
+        long remainLen = 0;
+        for (InputSplit s: combinedSplitBuffer) {
+          remainLen += s.getLength();
+        }
+        combineSparkSplits.add(createCombineSplit(
+          remainLen, Collections.singleton(node), combinedSplitBuffer));
+        accumulatedSplitSize = 0;
+        combinedSplitBuffer.clear();
+      }
+    }
+    return combineSparkSplits.toArray(new InputSplit[0]);
+  }
+
+  @Override
+  public RecordReader<K, V> getRecordReader(final InputSplit split,
+      final JobConf jobConf, final Reporter reporter) throws IOException {
+    return new RecordReader<K, V>() {
+      protected int idx = 0;
+      protected long progressedBytes = 0;
+      protected RecordReader<K, V> curReader;
+      protected CombineSplit combineSplit;
+      {
+        combineSplit = (CombineSplit) split;
+        initNextRecordReader();
+      }
+
+      @Override
+      public boolean next(K key, V value) throws IOException {
+        while ((curReader == null) || !curReader.next(key, value)) {
+          if (!initNextRecordReader()) {
+            return false;
+          }
+        }
+        return true;
+      }
+
+      public K createKey() {
+        return curReader.createKey();
+      }
+
+      public V createValue() {
+        return curReader.createValue();
+      }
+
+      /**
+       * Return the amount of data processed so far.
+       */
+      public long getPos() throws IOException {
+        return progressedBytes + curReader.getPos();
+      }
+
+      public void close() throws IOException {
+        if (curReader != null) {
+          curReader.close();
+          curReader = null;
+        }
+      }
+
+      /**
+       * Return progress based on the amount of data processed so far.
+       */
+      public float getProgress() throws IOException {
+        return Math.min(1.0f, progressedBytes / (float) (split.getLength()));
+      }
+
+      /**
+       * Get the record reader for the next split in this CombineSplit.
+       */
+      protected boolean initNextRecordReader() throws IOException {
+
+        if (curReader != null) {
+          curReader.close();
+          curReader = null;
+          if (idx > 0) {
+            progressedBytes += combineSplit.getSplit(idx - 1).getLength(); // done processing so far
+          }
+        }
+
+        // if all splits have been processed, nothing more to do.
+        if (idx == combineSplit.getSplitNum()) {
+          return false;
+        }
+
+        // get a record reader for the idx-th split
+        try {
+          curReader = delegate.getRecordReader(combineSplit.getSplit(idx), jobConf, Reporter.NULL);
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+        idx++;
+        return true;
+      }
+    };
+  }
+
+  public static class CombineSplit implements InputSplit {
+    private InputSplit[] splits;
+    private long totalLen;
+    private String[] locations;
+
+    public CombineSplit() {
+    }
+
+    public CombineSplit(InputSplit[] ss, long totalLen, String[] locations) {
+      splits = ss;
+      this.totalLen = totalLen;
+      this.locations = locations;
+    }
+
+    public InputSplit getSplit(int idx) {
+      return splits[idx];
+    }
+
+    public int getSplitNum() {
+      return splits.length;
+    }
+
+    @Override
+    public long getLength() {
+      return totalLen;
+    }
+
+    @Override
+    public String[] getLocations() throws IOException {
+      return locations;
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      // We only combine splits within a single table partition,
+      // so all of the splits should be of the same class.
+      out.writeUTF(splits[0].getClass().getCanonicalName());
+      out.writeLong(totalLen);
+      out.writeInt(locations.length);
+      for (String location : locations) {
+        out.writeUTF(location);
+      }
+      out.writeInt(splits.length);
+      for (InputSplit split : splits) {
+        split.write(out);
+      }
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      String className = in.readUTF();
+      this.totalLen = in.readLong();
+      this.locations = new String[in.readInt()];
+      for (int i = 0; i < locations.length; i++) {
+        locations[i] = in.readUTF();
+      }
+      splits = new InputSplit[in.readInt()];
+      Class<? extends Writable> clazz = null;
+      try {
+        clazz = (Class<? extends Writable>) Class.forName(className);
+      } catch (ClassNotFoundException e) {
+        throw new RuntimeException(e);
+      }
+      for (int i = 0; i < splits.length; i++) {
+        Writable value = WritableFactories.newInstance(clazz, null);
+        value.readFields(in);
+        splits[i] = (InputSplit) value;
+      }
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala
new file mode 100644
index 000000000000..5fe4b51006a6
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.mapred.lib.CombineFileSplit
+
+import org.apache.spark.SparkContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.rdd.{HadoopPartition, HadoopRDD}
+import org.apache.spark.sql.hive.mapred.CombineSplitInputFormat
+import org.apache.spark.sql.hive.mapred.CombineSplitInputFormat.CombineSplit
+import org.apache.spark.util.SerializableConfiguration
+
+
+class HadoopRDDwithCombination[K, V](
+    sc: SparkContext,
+    broadcastedConf: Broadcast[SerializableConfiguration],
+    initLocalJobConfFuncOpt: Option[JobConf => Unit],
+    inputFormatClass: Class[_ <: InputFormat[K, V]],
+    keyClass: Class[K],
+    valueClass: Class[V],
+    minPartitions: Int,
+    splitCombineSize: Int)
+  extends HadoopRDD[K, V](
+    sc,
+    broadcastedConf,
+    initLocalJobConfFuncOpt,
+    inputFormatClass,
+    keyClass,
+    valueClass,
+    minPartitions) {
+
+  override protected def getInputFormat(conf: JobConf): InputFormat[K, V] = {
+    if (splitCombineSize < 0) {
+      super.getInputFormat(conf)
+    } else {
+      new CombineSplitInputFormat(super.getInputFormat(conf), splitCombineSize)
+    }
+  }
+
+  // Find a function that will return the FileSystem bytes read by this thread. Do this before
+  // creating RecordReader, because RecordReader's constructor might read some bytes
+  override protected def getBytesReadCallback(split: HadoopPartition): Option[() => Long] = {
+    split.inputSplit.value match {
+      case _: FileSplit | _: CombineFileSplit | _: CombineSplit =>
+        SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+      case _ => None
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index fd465e80a87e..f1b865f0e2e0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -274,16 +274,15 @@ class HadoopTableReader(
       inputFormatClass: Class[InputFormat[Writable, Writable]]): RDD[Writable] = {
 
     val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _
-
-    val rdd = new HadoopRDD(
-      sc.sparkContext,
-      _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]],
-      Some(initializeJobConfFunc),
-      inputFormatClass,
-      classOf[Writable],
-      classOf[Writable],
-      _minSplitsPerRDD)
-
+    val rdd = new HadoopRDDwithCombination(
+      sc.sparkContext,
+      _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]],
+      Some(initializeJobConfFunc),
+      inputFormatClass,
+      classOf[Writable],
+      classOf[Writable],
+      _minSplitsPerRDD,
+      sc.conf.mapperSplitCombineSize)
     // Only take the value (skip the key) because Hive works only with values.
     rdd.map(_._2)
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index b0c0dcbe5c25..a8f3776a1feb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.test.TestHive.implicits._
@@ -89,4 +88,33 @@ class HiveTableScanSuite extends HiveComparisonTest {
     assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
     assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
   }
+
+  test("SPARK-8813 Combine small splits for table scan") {
+    val partitionNum = 5
+    val partitionTable = "combine_small"
+    sql("set hive.exec.dynamic.partition.mode=nonstrict")
+    val df = (1 to 100).map { i => (i, i) }.toDF("a", "b").coalesce(100)
+    df.registerTempTable("temp")
+    sql(
+      s"""create table $partitionTable (a int, b string)
+         |partitioned by (c int)
+         |stored as orc""".stripMargin)
+    sql(
+      s"""insert into table $partitionTable partition(c)
+         |select a, b, (b % $partitionNum) as c from temp""".stripMargin)
+
+    // Check the number of RDD partitions without the combination
+    assert(sql(s"""select * from $partitionTable""").rdd.getNumPartitions == 100)
+
+    // Check the number of RDD partitions with the combination
+    sql("set spark.sql.mapper.splitCombineSize=10000")
+    assert(sql(
+      s"""select * from $partitionTable""").rdd.getNumPartitions == partitionNum)
+
+    // Ensure that the result is the same as the original after the combination
+    assert(
+      sql(s"""select * from $partitionTable order by a""").collect().map(_.toString()).deep
+        == (1 to 100).map { i => s"[$i,$i,${i % partitionNum}]" }.toArray.deep
+    )
+  }
 }
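
Usage sketch (editorial note, not part of the patch): the test above exercises the new code path through the SQL `SET` command. A minimal driver-side sketch of the same flow is shown below; `hiveContext` and `some_partitioned_orc_table` are assumed placeholder names, and the 128 MB threshold is only illustrative.

    // spark.sql.mapper.splitCombineSize defaults to -1, which leaves the existing behaviour
    // untouched: the Hive table scan produces one RDD partition per input split.
    hiveContext.sql("SET spark.sql.mapper.splitCombineSize=-1")
    val uncombined = hiveContext.sql("SELECT * FROM some_partitioned_orc_table")
    println(uncombined.rdd.getNumPartitions)

    // A positive value is a byte threshold: HadoopTableReader then builds a
    // HadoopRDDwithCombination, whose CombineSplitInputFormat packs splits located on the
    // same node into CombineSplits of roughly this size, so a table made of many small
    // files is scanned with far fewer (and larger) partitions.
    hiveContext.sql("SET spark.sql.mapper.splitCombineSize=" + (128L * 1024 * 1024))
    val combined = hiveContext.sql("SELECT * FROM some_partitioned_orc_table")
    println(combined.rdd.getNumPartitions)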