From c40526db27d5a8693b667a5b6b801fe8cf1300b4 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 13 Oct 2015 14:37:33 +0800 Subject: [PATCH 1/7] combine split by specific size --- .../org/apache/spark/rdd/HadoopRDD.scala | 21 +-- .../scala/org/apache/spark/sql/SQLConf.scala | 6 + .../spark/sql/hive/mapred/CombineSplit.java | 95 +++++++++++++ .../hive/mapred/CombineSplitInputFormat.java | 109 +++++++++++++++ .../hive/mapred/CombineSplitRecordReader.java | 128 ++++++++++++++++++ .../spark/sql/hive/HadoopCombineRDD.scala | 68 ++++++++++ .../apache/spark/sql/hive/TableReader.scala | 28 ++-- 7 files changed, 438 insertions(+), 17 deletions(-) create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java create mode 100644 sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 805cd9fe1f638..4aa08834df9d5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -205,13 +205,23 @@ class HadoopRDD[K, V]( array } + protected def registMetricsReadCallback(split: HadoopPartition) = { + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + getBytesReadCallback + } + override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new NextIterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) @@ -223,13 +233,7 @@ class HadoopRDD[K, V]( case _ => SqlNewHadoopRDDState.unsetInputFileName() } - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } + val getBytesReadCallback = registMetricsReadCallback(split) // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. 
// If we do a coalesce, however, we are likely to compute multiple partitions in the same @@ -240,7 +244,6 @@ class HadoopRDD[K, V]( inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 61a7b9935af0b..72c498a6e2515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -507,6 +507,10 @@ private[spark] object SQLConf { " method", isPublic = false) + val MAPPER_SPLIT_COMBINE_SIZE = intConf( + "spark.sql.mapper.splitCombineSize", + defaultValue = Some(-1), + isPublic = true) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -574,6 +578,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon private[spark] def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) + private[spark] def mapperSplitCombineSize: Int = getConf(MAPPER_SPLIT_COMBINE_SIZE) + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java new file mode 100644 index 0000000000000..07c33b9de6b20 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.mapred; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableFactories; +import org.apache.hadoop.mapred.InputSplit; + +public class CombineSplit implements InputSplit { + private InputSplit[] splits; + private long totalLen; + private String[] locations; + + public CombineSplit() { + } + + public CombineSplit(InputSplit[] ss, long totalLen, String[] locations) { + splits = ss; + this.totalLen = totalLen; + this.locations = locations; + } + + public InputSplit getSplit(int idx) { + return splits[idx]; + } + + public int getSplitNum() { + return splits.length; + } + + @Override + public long getLength() { + return totalLen; + } + + @Override + public String[] getLocations() throws IOException { + return locations; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeLong(totalLen); + out.writeInt(locations.length); + for (String location : locations) { + out.writeUTF(location); + } + out.writeInt(splits.length); + out.writeUTF(splits[0].getClass().getCanonicalName()); + for (InputSplit split : splits) { + split.write(out); + } + } + + @Override + public void readFields(DataInput in) throws IOException { + this.totalLen = in.readLong(); + this.locations = new String[in.readInt()]; + for (int i = 0; i < locations.length; i++) { + locations[i] = in.readUTF(); + } + splits = new InputSplit[in.readInt()]; + String className = in.readUTF(); + Class clazz = null; + try { + clazz = (Class) Class.forName(className); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + for (int i = 0; i < splits.length; i++) { + Writable value = WritableFactories.newInstance(clazz, null); + value.readFields(in); + splits[i] = (InputSplit) value; + } + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java new file mode 100644 index 0000000000000..d33579e489fd1 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.mapred; + +import java.io.IOException; +import java.util.*; + +import com.clearspring.analytics.util.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.hadoop.mapred.*; + +public class CombineSplitInputFormat implements InputFormat { + + private InputFormat inputformat; + private long splitSize = 0; + + public CombineSplitInputFormat(InputFormat inputformat, long splitSize) { + this.inputformat = inputformat; + this.splitSize = splitSize; + } + + /** + * Create a single split from the list of blocks specified in validBlocks + * Add this new split into splitList. + */ + private void addCreatedSplit(List splitList, + long totalLen, + Collection locations, + List validSplits) { + CombineSplit combineSparkSplit = + new CombineSplit(validSplits.toArray(new InputSplit[0]), + totalLen, locations.toArray(new String[0])); + splitList.add(combineSparkSplit); + } + + @Override + public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { + InputSplit[] splits = inputformat.getSplits(job, numSplits); + // populate nodeToSplits and splitsSet + Map> nodeToSplits = Maps.newHashMap(); + Set splitsSet = Sets.newHashSet(); + for (InputSplit split: splits) { + for (String node: split.getLocations()) { + if (!nodeToSplits.containsKey(node)) { + nodeToSplits.put(node, new ArrayList()); + } + nodeToSplits.get(node).add(split); + } + splitsSet.add(split); + } + // Iterate the nodes to combine in order to evenly distributing the splits + List combineSparkSplits = Lists.newArrayList(); + List oneCombinedSplits = Lists.newArrayList(); + long currentSplitSize = 0L; + for (Map.Entry> entry: nodeToSplits.entrySet()) { + String node = entry.getKey(); + List splitsPerNode = entry.getValue(); + for (InputSplit split: splitsPerNode) { + if (splitSize != 0 && currentSplitSize > splitSize) { + addCreatedSplit(combineSparkSplits, + currentSplitSize, Collections.singleton(node), oneCombinedSplits); + currentSplitSize = 0; + oneCombinedSplits.clear(); + } + // this split has been combined + if (!splitsSet.contains(split)) { + continue; + } else { + currentSplitSize += split.getLength(); + oneCombinedSplits.add(split); + splitsSet.remove(split); + } + } + // populate the remaining splits into one combined split + if (!oneCombinedSplits.isEmpty()) { + long remainLen = 0; + for (InputSplit s: oneCombinedSplits) { + remainLen += s.getLength(); + } + addCreatedSplit(combineSparkSplits, + remainLen, Collections.singleton(node), oneCombinedSplits); + currentSplitSize = 0; + oneCombinedSplits.clear(); + } + } + return combineSparkSplits.toArray(new InputSplit[0]); + } + + @Override + public RecordReader getRecordReader(InputSplit split, JobConf job, Reporter reporter) throws IOException { + return new CombineSplitRecordReader(job, (CombineSplit)split, inputformat); + } +} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java new file mode 100644 index 0000000000000..d4873b1c67f96 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.mapred; + +import java.io.IOException; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.mapred.InputFormat; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.RecordReader; +import org.apache.hadoop.mapred.Reporter; + +/** + * A generic RecordReader that can hand out different recordReaders + * for each split in a {@link org.apache.spark.sql.hive.mapred.CombineSplit}. + */ +@InterfaceAudience.Public +@InterfaceStability.Stable +public class CombineSplitRecordReader implements RecordReader { + protected CombineSplit split; + protected JobConf jc; + protected FileSystem fs; + + protected int idx; + protected long progress; + protected RecordReader curReader; + + @Override + public boolean next(K key, V value) throws IOException { + while ((curReader == null) || !curReader.next(key, value)) { + if (!initNextRecordReader()) { + return false; + } + } + return true; + } + + public K createKey() { + return curReader.createKey(); + } + + public V createValue() { + return curReader.createValue(); + } + + /** + * return the amount of data processed + */ + public long getPos() throws IOException { + return progress; + } + + public void close() throws IOException { + if (curReader != null) { + curReader.close(); + curReader = null; + } + } + + /** + * return progress based on the amount of data processed so far. + */ + public float getProgress() throws IOException { + return Math.min(1.0f, progress/(float)(split.getLength())); + } + private InputFormat inputFormat; + /** + * A generic RecordReader that can hand out different recordReaders + * for each split in the CombineSplit. + */ + public CombineSplitRecordReader(JobConf job, CombineSplit split, + InputFormat inputFormat) + throws IOException { + this.split = split; + this.jc = job; + this.idx = 0; + this.curReader = null; + this.progress = 0; + this.inputFormat = inputFormat; + initNextRecordReader(); + } + + /** + * Get the record reader for the next split in this CombineSplit. + */ + protected boolean initNextRecordReader() throws IOException { + + if (curReader != null) { + curReader.close(); + curReader = null; + if (idx > 0) { + progress += split.getSplit(idx-1).getLength(); // done processing so far + } + } + + // if all splits have been processed, nothing more to do. 
+ if (idx == split.getSplitNum()) { + return false; + } + + // get a record reader for the idx-th split + try { + curReader = inputFormat.getRecordReader(split.getSplit(idx), jc, Reporter.NULL); + } catch (Exception e) { + throw new RuntimeException (e); + } + idx++; + return true; + } +} + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala new file mode 100644 index 0000000000000..9bd66cd0d333e --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapred.lib.CombineFileSplit + +import org.apache.spark.SparkContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.InputMetrics +import org.apache.spark.rdd.{HadoopPartition, HadoopRDD} +import org.apache.spark.sql.hive.mapred.{CombineSplit, CombineSplitInputFormat} +import org.apache.spark.util.SerializableConfiguration + + +class HadoopCombineRDD[K, V]( + @transient sc: SparkContext, + broadcastedConf: Broadcast[SerializableConfiguration], + initLocalJobConfFuncOpt: Option[JobConf => Unit], + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V], + minPartitions: Int, + mapperSplitSize: Int) extends HadoopRDD[K, V](sc, + broadcastedConf, + initLocalJobConfFuncOpt, + inputFormatClass, + keyClass, + valueClass, + minPartitions +) { + override protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { + val inputFormat = super.getInputFormat(conf) + new CombineSplitInputFormat(inputFormat, mapperSplitSize) + } + + override protected def registMetricsReadCallback( + split: HadoopPartition, + inputMetrics: InputMetrics) = { + // Find a function that will return the FileSystem bytes read by this thread. 
Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { + split.inputSplit.value match { + case _: FileSplit | _: CombineFileSplit | _: CombineSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + } + inputMetrics.setBytesReadCallback(bytesReadCallback) + bytesReadCallback + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index fd465e80a87e5..eb76744e3a971 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -274,15 +274,27 @@ class HadoopTableReader( inputFormatClass: Class[InputFormat[Writable, Writable]]): RDD[Writable] = { val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ + val rdd = if (sc.conf.mapperSplitCombineSize < 0) { + new HadoopRDD( + sc.sparkContext, + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], + Some(initializeJobConfFunc), + inputFormatClass, + classOf[Writable], + classOf[Writable], + _minSplitsPerRDD) - val rdd = new HadoopRDD( - sc.sparkContext, - _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], - Some(initializeJobConfFunc), - inputFormatClass, - classOf[Writable], - classOf[Writable], - _minSplitsPerRDD) + } else { + new HadoopCombineRDD( + sc.sparkContext, + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], + Some(initializeJobConfFunc), + inputFormatClass, + classOf[Writable], + classOf[Writable], + _minSplitsPerRDD, + sc.conf.mapperSplitCombineSize) + } // Only take the value (skip the key) because Hive works only with values. rdd.map(_._2) From faedb9a8338d25175894f201ee60dbf2d4e6f147 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 20 Oct 2015 09:39:59 +0800 Subject: [PATCH 2/7] comments --- .../spark/sql/hive/mapred/CombineSplit.java | 6 ++++-- .../hive/mapred/CombineSplitInputFormat.java | 17 +++++++++-------- .../hive/mapred/CombineSplitRecordReader.java | 10 +++++----- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java index 07c33b9de6b20..a6bcebd9a3479 100644 --- a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java @@ -59,13 +59,15 @@ public String[] getLocations() throws IOException { @Override public void write(DataOutput out) throws IOException { + // We only process combination within a single table partition, + // so all of the class name of the splits should be identical. 
+ out.writeUTF(splits[0].getClass().getCanonicalName()); out.writeLong(totalLen); out.writeInt(locations.length); for (String location : locations) { out.writeUTF(location); } out.writeInt(splits.length); - out.writeUTF(splits[0].getClass().getCanonicalName()); for (InputSplit split : splits) { split.write(out); } @@ -73,13 +75,13 @@ public void write(DataOutput out) throws IOException { @Override public void readFields(DataInput in) throws IOException { + String className = in.readUTF(); this.totalLen = in.readLong(); this.locations = new String[in.readInt()]; for (int i = 0; i < locations.length; i++) { locations[i] = in.readUTF(); } splits = new InputSplit[in.readInt()]; - String className = in.readUTF(); Class clazz = null; try { clazz = (Class) Class.forName(className); diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java index d33579e489fd1..7a6f74a7a4dba 100644 --- a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java @@ -36,8 +36,8 @@ public CombineSplitInputFormat(InputFormat inputformat, long splitSize) { } /** - * Create a single split from the list of blocks specified in validBlocks - * Add this new split into splitList. + * Create a combine split from the list of splits + * Add this new combine split into splitList. */ private void addCreatedSplit(List splitList, long totalLen, @@ -72,12 +72,6 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { String node = entry.getKey(); List splitsPerNode = entry.getValue(); for (InputSplit split: splitsPerNode) { - if (splitSize != 0 && currentSplitSize > splitSize) { - addCreatedSplit(combineSparkSplits, - currentSplitSize, Collections.singleton(node), oneCombinedSplits); - currentSplitSize = 0; - oneCombinedSplits.clear(); - } // this split has been combined if (!splitsSet.contains(split)) { continue; @@ -86,6 +80,13 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { oneCombinedSplits.add(split); splitsSet.remove(split); } + if (splitSize != 0 && currentSplitSize > splitSize) { + // TODO: optimize this by providing the second/third preference locations + addCreatedSplit(combineSparkSplits, + currentSplitSize, Collections.singleton(node), oneCombinedSplits); + currentSplitSize = 0; + oneCombinedSplits.clear(); + } } // populate the remaining splits into one combined split if (!oneCombinedSplits.isEmpty()) { diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java index d4873b1c67f96..5c9f70f2d7ac9 100644 --- a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java @@ -39,7 +39,7 @@ public class CombineSplitRecordReader implements RecordReader { protected FileSystem fs; protected int idx; - protected long progress; + protected long progressedBytes; protected RecordReader curReader; @Override @@ -64,7 +64,7 @@ public V createValue() { * return the amount of data processed */ public long getPos() throws IOException { - return progress; + return progressedBytes + curReader.getPos(); } public void close() throws IOException { @@ -78,7 +78,7 @@ public void close() throws IOException { 
* return progress based on the amount of data processed so far. */ public float getProgress() throws IOException { - return Math.min(1.0f, progress/(float)(split.getLength())); + return Math.min(1.0f, progressedBytes /(float)(split.getLength())); } private InputFormat inputFormat; /** @@ -92,7 +92,7 @@ public CombineSplitRecordReader(JobConf job, CombineSplit split, this.jc = job; this.idx = 0; this.curReader = null; - this.progress = 0; + this.progressedBytes = 0; this.inputFormat = inputFormat; initNextRecordReader(); } @@ -106,7 +106,7 @@ protected boolean initNextRecordReader() throws IOException { curReader.close(); curReader = null; if (idx > 0) { - progress += split.getSplit(idx-1).getLength(); // done processing so far + progressedBytes += split.getSplit(idx-1).getLength(); // done processing so far } } From 051931b5fb34c3a592f1cd200dcaf1d3279bff0f Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Mon, 26 Oct 2015 09:07:29 +0800 Subject: [PATCH 3/7] refactor name --- .../hive/mapred/CombineSplitInputFormat.java | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java index 7a6f74a7a4dba..c7bfec84a4606 100644 --- a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java @@ -39,14 +39,14 @@ public CombineSplitInputFormat(InputFormat inputformat, long splitSize) { * Create a combine split from the list of splits * Add this new combine split into splitList. */ - private void addCreatedSplit(List splitList, + private void addCreatedSplit(List combineSplits, long totalLen, Collection locations, - List validSplits) { + List combineSplitBuffer) { CombineSplit combineSparkSplit = - new CombineSplit(validSplits.toArray(new InputSplit[0]), + new CombineSplit(combineSplitBuffer.toArray(new InputSplit[0]), totalLen, locations.toArray(new String[0])); - splitList.add(combineSparkSplit); + combineSplits.add(combineSparkSplit); } @Override @@ -66,7 +66,7 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { } // Iterate the nodes to combine in order to evenly distributing the splits List combineSparkSplits = Lists.newArrayList(); - List oneCombinedSplits = Lists.newArrayList(); + List combinedSplitBuffer = Lists.newArrayList(); long currentSplitSize = 0L; for (Map.Entry> entry: nodeToSplits.entrySet()) { String node = entry.getKey(); @@ -77,27 +77,27 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { continue; } else { currentSplitSize += split.getLength(); - oneCombinedSplits.add(split); + combinedSplitBuffer.add(split); splitsSet.remove(split); } if (splitSize != 0 && currentSplitSize > splitSize) { // TODO: optimize this by providing the second/third preference locations addCreatedSplit(combineSparkSplits, - currentSplitSize, Collections.singleton(node), oneCombinedSplits); + currentSplitSize, Collections.singleton(node), combinedSplitBuffer); currentSplitSize = 0; - oneCombinedSplits.clear(); + combinedSplitBuffer.clear(); } } // populate the remaining splits into one combined split - if (!oneCombinedSplits.isEmpty()) { + if (!combinedSplitBuffer.isEmpty()) { long remainLen = 0; - for (InputSplit s: oneCombinedSplits) { + for (InputSplit s: combinedSplitBuffer) { remainLen += s.getLength(); } 
addCreatedSplit(combineSparkSplits, - remainLen, Collections.singleton(node), oneCombinedSplits); + remainLen, Collections.singleton(node), combinedSplitBuffer); currentSplitSize = 0; - oneCombinedSplits.clear(); + combinedSplitBuffer.clear(); } } return combineSparkSplits.toArray(new InputSplit[0]); From 5b2496c88c9be5888f0341f85dfedb04ab26fcc6 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 23 Feb 2016 15:28:41 +0800 Subject: [PATCH 4/7] add unit test --- .../hive/execution/HiveTableScanSuite.scala | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index b0c0dcbe5c25c..c5699da28b1fc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -89,4 +90,32 @@ class HiveTableScanSuite extends HiveComparisonTest { assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } + + test("combine small files") { + def getPartitions[T](sc: SparkContext, rdd: RDD[T]): Array[Long] = { + sc.runJob(rdd, (x: Iterator[T]) => 1) + } + val partitionNum = 5 + val partitionTable = "combine_small" + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("set spark.sql.mapper.splitCombineSize=1000000") + val df = (1 to 100).map { i => (i, i) }.toDF("a", "b").repartition(100) + df.registerTempTable("temp") + sql( + s"""create table $partitionTable (a int, b string) + |partitioned by (c int) + |stored as orc""".stripMargin) + sql( + s"""insert into table $partitionTable partition(c) + |select a, b, (b % $partitionNum) as c from temp""".stripMargin) + val result = sql( s"""select * from $partitionTable order by a""") + val rddPartitions = getPartitions(TestHive.sparkContext, result.rdd) + // Ensure that there are only have 4 RDD partitions after combination + assert(rddPartitions.length == 5) + // Ensure that the result is the same as the original after combination + assert( + result.collect().map(_.toString()).deep + == (1 to 100).map{i => s"[$i,$i,${i % partitionNum}]"}.toArray.deep + ) + } } From 0ab38b30b83b918c4f23c6e274ed435067df1dde Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 23 Feb 2016 15:56:38 +0800 Subject: [PATCH 5/7] fix rebase --- .../spark/sql/hive/HadoopCombineRDD.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala index 9bd66cd0d333e..be6b3b465e9c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala @@ -50,19 +50,14 @@ class HadoopCombineRDD[K, V]( new CombineSplitInputFormat(inputFormat, mapperSplitSize) } - override protected def registMetricsReadCallback( - split: HadoopPartition, - 
inputMetrics: InputMetrics) = { + override protected def registMetricsReadCallback(split: HadoopPartition) = { // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.inputSplit.value match { - case _: FileSplit | _: CombineFileSplit | _: CombineSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } + val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None } - inputMetrics.setBytesReadCallback(bytesReadCallback) - bytesReadCallback + getBytesReadCallback } } From 701700b10078d48d8e9f0342e623e94243642de3 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 24 Feb 2016 10:08:33 +0800 Subject: [PATCH 6/7] refactor code and unit test useless import refactor --- .../org/apache/spark/rdd/HadoopRDD.scala | 15 +- .../spark/sql/hive/mapred/CombineSplit.java | 97 -------- .../hive/mapred/CombineSplitInputFormat.java | 210 +++++++++++++++--- .../hive/mapred/CombineSplitRecordReader.java | 128 ----------- ...D.scala => HadoopRDDwithCombination.scala} | 20 +- .../apache/spark/sql/hive/TableReader.scala | 15 +- .../hive/execution/HiveTableScanSuite.scala | 27 ++- 7 files changed, 211 insertions(+), 301 deletions(-) delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java delete mode 100644 sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{HadoopCombineRDD.scala => HadoopRDDwithCombination.scala} (79%) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 4aa08834df9d5..d9f0f184f4908 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -205,15 +205,14 @@ class HadoopRDD[K, V]( array } - protected def registMetricsReadCallback(split: HadoopPartition) = { - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + protected def getBytesReadCallback(split: HadoopPartition): Option[() => Long] = { + split.inputSplit.value match { case _: FileSplit | _: CombineFileSplit => SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - getBytesReadCallback } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -233,14 +232,14 @@ class HadoopRDD[K, V]( case _ => SqlNewHadoopRDDState.unsetInputFileName() } - val getBytesReadCallback = registMetricsReadCallback(split) + val bytesReadCallback = getBytesReadCallback(split) // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). 
def updateBytesRead(): Unit = { - getBytesReadCallback.foreach { getBytesRead => + bytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } @@ -288,7 +287,7 @@ class HadoopRDD[K, V]( } finally { reader = null } - if (getBytesReadCallback.isDefined) { + if (bytesReadCallback.isDefined) { updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit] || split.inputSplit.value.isInstanceOf[CombineFileSplit]) { diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java deleted file mode 100644 index a6bcebd9a3479..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplit.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.mapred; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; - -import org.apache.hadoop.io.Writable; -import org.apache.hadoop.io.WritableFactories; -import org.apache.hadoop.mapred.InputSplit; - -public class CombineSplit implements InputSplit { - private InputSplit[] splits; - private long totalLen; - private String[] locations; - - public CombineSplit() { - } - - public CombineSplit(InputSplit[] ss, long totalLen, String[] locations) { - splits = ss; - this.totalLen = totalLen; - this.locations = locations; - } - - public InputSplit getSplit(int idx) { - return splits[idx]; - } - - public int getSplitNum() { - return splits.length; - } - - @Override - public long getLength() { - return totalLen; - } - - @Override - public String[] getLocations() throws IOException { - return locations; - } - - @Override - public void write(DataOutput out) throws IOException { - // We only process combination within a single table partition, - // so all of the class name of the splits should be identical. 
- out.writeUTF(splits[0].getClass().getCanonicalName()); - out.writeLong(totalLen); - out.writeInt(locations.length); - for (String location : locations) { - out.writeUTF(location); - } - out.writeInt(splits.length); - for (InputSplit split : splits) { - split.write(out); - } - } - - @Override - public void readFields(DataInput in) throws IOException { - String className = in.readUTF(); - this.totalLen = in.readLong(); - this.locations = new String[in.readInt()]; - for (int i = 0; i < locations.length; i++) { - locations[i] = in.readUTF(); - } - splits = new InputSplit[in.readInt()]; - Class clazz = null; - try { - clazz = (Class) Class.forName(className); - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - for (int i = 0; i < splits.length; i++) { - Writable value = WritableFactories.newInstance(clazz, null); - value.readFields(in); - splits[i] = (InputSplit) value; - } - } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java index c7bfec84a4606..662f827b0ce1b 100644 --- a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java +++ b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitInputFormat.java @@ -17,42 +17,39 @@ package org.apache.spark.sql.hive.mapred; -import java.io.IOException; -import java.util.*; - import com.clearspring.analytics.util.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableFactories; import org.apache.hadoop.mapred.*; +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.*; + public class CombineSplitInputFormat implements InputFormat { - private InputFormat inputformat; - private long splitSize = 0; + private InputFormat delegate; + private long splitCombineSize = 0; - public CombineSplitInputFormat(InputFormat inputformat, long splitSize) { - this.inputformat = inputformat; - this.splitSize = splitSize; + public CombineSplitInputFormat(InputFormat inputformat, long sSize) { + this.delegate = inputformat; + this.splitCombineSize = sSize; } - /** - * Create a combine split from the list of splits - * Add this new combine split into splitList. 
- */ - private void addCreatedSplit(List combineSplits, - long totalLen, - Collection locations, - List combineSplitBuffer) { - CombineSplit combineSparkSplit = - new CombineSplit(combineSplitBuffer.toArray(new InputSplit[0]), + private CombineSplit createCombineSplit( + long totalLen, + Collection locations, + List combineSplitBuffer) { + return new CombineSplit(combineSplitBuffer.toArray(new InputSplit[0]), totalLen, locations.toArray(new String[0])); - combineSplits.add(combineSparkSplit); } @Override public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { - InputSplit[] splits = inputformat.getSplits(job, numSplits); - // populate nodeToSplits and splitsSet + InputSplit[] splits = delegate.getSplits(job, numSplits); Map> nodeToSplits = Maps.newHashMap(); Set splitsSet = Sets.newHashSet(); for (InputSplit split: splits) { @@ -65,9 +62,10 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { splitsSet.add(split); } // Iterate the nodes to combine in order to evenly distributing the splits + // Ideally splits within the same combination should be in the same node List combineSparkSplits = Lists.newArrayList(); List combinedSplitBuffer = Lists.newArrayList(); - long currentSplitSize = 0L; + long accumulatedSplitSize = 0L; for (Map.Entry> entry: nodeToSplits.entrySet()) { String node = entry.getKey(); List splitsPerNode = entry.getValue(); @@ -76,15 +74,15 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { if (!splitsSet.contains(split)) { continue; } else { - currentSplitSize += split.getLength(); + accumulatedSplitSize += split.getLength(); combinedSplitBuffer.add(split); splitsSet.remove(split); } - if (splitSize != 0 && currentSplitSize > splitSize) { + if (splitCombineSize > 0 && accumulatedSplitSize >= splitCombineSize) { // TODO: optimize this by providing the second/third preference locations - addCreatedSplit(combineSparkSplits, - currentSplitSize, Collections.singleton(node), combinedSplitBuffer); - currentSplitSize = 0; + combineSparkSplits.add(createCombineSplit( + accumulatedSplitSize, Collections.singleton(node), combinedSplitBuffer)); + accumulatedSplitSize = 0; combinedSplitBuffer.clear(); } } @@ -94,9 +92,9 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { for (InputSplit s: combinedSplitBuffer) { remainLen += s.getLength(); } - addCreatedSplit(combineSparkSplits, - remainLen, Collections.singleton(node), combinedSplitBuffer); - currentSplitSize = 0; + combineSparkSplits.add(createCombineSplit( + remainLen, Collections.singleton(node), combinedSplitBuffer)); + accumulatedSplitSize = 0; combinedSplitBuffer.clear(); } } @@ -104,7 +102,155 @@ public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { } @Override - public RecordReader getRecordReader(InputSplit split, JobConf job, Reporter reporter) throws IOException { - return new CombineSplitRecordReader(job, (CombineSplit)split, inputformat); + public RecordReader getRecordReader(final InputSplit split, + final JobConf jobConf, final Reporter reporter) throws IOException { + return new RecordReader() { + protected int idx = 0; + protected long progressedBytes = 0; + protected RecordReader curReader; + protected CombineSplit combineSplit; + { + combineSplit = (CombineSplit)split; + initNextRecordReader(); + } + + @Override + public boolean next(K key, V value) throws IOException { + while ((curReader == null) || !curReader.next(key, value)) { + if (!initNextRecordReader()) { + return false; + 
} + } + return true; + } + + public K createKey() { + return curReader.createKey(); + } + + public V createValue() { + return curReader.createValue(); + } + + /** + * return the amount of data processed + */ + public long getPos() throws IOException { + return progressedBytes + curReader.getPos(); + } + + public void close() throws IOException { + if (curReader != null) { + curReader.close(); + curReader = null; + } + } + + /** + * return progress based on the amount of data processed so far. + */ + public float getProgress() throws IOException { + return Math.min(1.0f, progressedBytes /(float)(split.getLength())); + } + + /** + * Get the record reader for the next split in this CombineSplit. + */ + protected boolean initNextRecordReader() throws IOException { + + if (curReader != null) { + curReader.close(); + curReader = null; + if (idx > 0) { + progressedBytes += combineSplit.getSplit(idx-1).getLength(); // done processing so far + } + } + + // if all splits have been processed, nothing more to do. + if (idx == combineSplit.getSplitNum()) { + return false; + } + + // get a record reader for the idx-th split + try { + curReader = delegate.getRecordReader(combineSplit.getSplit(idx), jobConf, Reporter.NULL); + } catch (Exception e) { + throw new RuntimeException (e); + } + idx++; + return true; + } + }; + } + + public static class CombineSplit implements InputSplit { + private InputSplit[] splits; + private long totalLen; + private String[] locations; + + public CombineSplit() { + } + + public CombineSplit(InputSplit[] ss, long totalLen, String[] locations) { + splits = ss; + this.totalLen = totalLen; + this.locations = locations; + } + + public InputSplit getSplit(int idx) { + return splits[idx]; + } + + public int getSplitNum() { + return splits.length; + } + + @Override + public long getLength() { + return totalLen; + } + + @Override + public String[] getLocations() throws IOException { + return locations; + } + + @Override + public void write(DataOutput out) throws IOException { + // We only process combination within a single table partition, + // so all of the class name of the splits should be identical. + out.writeUTF(splits[0].getClass().getCanonicalName()); + out.writeLong(totalLen); + out.writeInt(locations.length); + for (String location : locations) { + out.writeUTF(location); + } + out.writeInt(splits.length); + for (InputSplit split : splits) { + split.write(out); + } + } + + @Override + public void readFields(DataInput in) throws IOException { + String className = in.readUTF(); + this.totalLen = in.readLong(); + this.locations = new String[in.readInt()]; + for (int i = 0; i < locations.length; i++) { + locations[i] = in.readUTF(); + } + splits = new InputSplit[in.readInt()]; + Class clazz = null; + try { + clazz = (Class) Class.forName(className); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + for (int i = 0; i < splits.length; i++) { + Writable value = WritableFactories.newInstance(clazz, null); + value.readFields(in); + splits[i] = (InputSplit) value; + } + } } } diff --git a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java b/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java deleted file mode 100644 index 5c9f70f2d7ac9..0000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/hive/mapred/CombineSplitRecordReader.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.mapred; - -import java.io.IOException; - -import org.apache.hadoop.classification.InterfaceAudience; -import org.apache.hadoop.classification.InterfaceStability; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.mapred.InputFormat; -import org.apache.hadoop.mapred.JobConf; -import org.apache.hadoop.mapred.RecordReader; -import org.apache.hadoop.mapred.Reporter; - -/** - * A generic RecordReader that can hand out different recordReaders - * for each split in a {@link org.apache.spark.sql.hive.mapred.CombineSplit}. - */ -@InterfaceAudience.Public -@InterfaceStability.Stable -public class CombineSplitRecordReader implements RecordReader { - protected CombineSplit split; - protected JobConf jc; - protected FileSystem fs; - - protected int idx; - protected long progressedBytes; - protected RecordReader curReader; - - @Override - public boolean next(K key, V value) throws IOException { - while ((curReader == null) || !curReader.next(key, value)) { - if (!initNextRecordReader()) { - return false; - } - } - return true; - } - - public K createKey() { - return curReader.createKey(); - } - - public V createValue() { - return curReader.createValue(); - } - - /** - * return the amount of data processed - */ - public long getPos() throws IOException { - return progressedBytes + curReader.getPos(); - } - - public void close() throws IOException { - if (curReader != null) { - curReader.close(); - curReader = null; - } - } - - /** - * return progress based on the amount of data processed so far. - */ - public float getProgress() throws IOException { - return Math.min(1.0f, progressedBytes /(float)(split.getLength())); - } - private InputFormat inputFormat; - /** - * A generic RecordReader that can hand out different recordReaders - * for each split in the CombineSplit. - */ - public CombineSplitRecordReader(JobConf job, CombineSplit split, - InputFormat inputFormat) - throws IOException { - this.split = split; - this.jc = job; - this.idx = 0; - this.curReader = null; - this.progressedBytes = 0; - this.inputFormat = inputFormat; - initNextRecordReader(); - } - - /** - * Get the record reader for the next split in this CombineSplit. - */ - protected boolean initNextRecordReader() throws IOException { - - if (curReader != null) { - curReader.close(); - curReader = null; - if (idx > 0) { - progressedBytes += split.getSplit(idx-1).getLength(); // done processing so far - } - } - - // if all splits have been processed, nothing more to do. 
- if (idx == split.getSplitNum()) { - return false; - } - - // get a record reader for the idx-th split - try { - curReader = inputFormat.getRecordReader(split.getSplit(idx), jc, Reporter.NULL); - } catch (Exception e) { - throw new RuntimeException (e); - } - idx++; - return true; - } -} - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala similarity index 79% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala index be6b3b465e9c5..ba5e2d7f097c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopCombineRDD.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala @@ -23,13 +23,13 @@ import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.spark.SparkContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.{HadoopPartition, HadoopRDD} -import org.apache.spark.sql.hive.mapred.{CombineSplit, CombineSplitInputFormat} +import org.apache.spark.sql.hive.mapred.CombineSplitInputFormat +import org.apache.spark.sql.hive.mapred.CombineSplitInputFormat.CombineSplit import org.apache.spark.util.SerializableConfiguration -class HadoopCombineRDD[K, V]( +class HadoopRDDwithCombination[K, V]( @transient sc: SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], @@ -37,7 +37,7 @@ class HadoopCombineRDD[K, V]( keyClass: Class[K], valueClass: Class[V], minPartitions: Int, - mapperSplitSize: Int) extends HadoopRDD[K, V](sc, + splitCombineSize: Int) extends HadoopRDD[K, V](sc, broadcastedConf, initLocalJobConfFuncOpt, inputFormatClass, @@ -45,16 +45,20 @@ class HadoopCombineRDD[K, V]( valueClass, minPartitions ) { + override protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - val inputFormat = super.getInputFormat(conf) - new CombineSplitInputFormat(inputFormat, mapperSplitSize) + if (splitCombineSize < 0) { + super.getInputFormat(conf) + } else { + new CombineSplitInputFormat(super.getInputFormat(conf), splitCombineSize) + } } - override protected def registMetricsReadCallback(split: HadoopPartition) = { + override protected def getBytesReadCallback(split: HadoopPartition) = { // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { - case _: FileSplit | _: CombineFileSplit => + case _: FileSplit | _: CombineFileSplit | _: CombineSplit => SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index eb76744e3a971..f1b865f0e2e0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -274,18 +274,7 @@ class HadoopTableReader( inputFormatClass: Class[InputFormat[Writable, Writable]]): RDD[Writable] = { val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ - val rdd = if (sc.conf.mapperSplitCombineSize < 0) { - new HadoopRDD( - sc.sparkContext, - _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], - Some(initializeJobConfFunc), - inputFormatClass, - classOf[Writable], - classOf[Writable], - _minSplitsPerRDD) - - } else { - new HadoopCombineRDD( + val rdd = new HadoopRDDwithCombination( sc.sparkContext, _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), @@ -294,8 +283,6 @@ class HadoopTableReader( classOf[Writable], _minSplitsPerRDD, sc.conf.mapperSplitCombineSize) - } - // Only take the value (skip the key) because Hive works only with values. rdd.map(_._2) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index c5699da28b1fc..a8f3776a1febc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -91,15 +89,11 @@ class HiveTableScanSuite extends HiveComparisonTest { assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } - test("combine small files") { - def getPartitions[T](sc: SparkContext, rdd: RDD[T]): Array[Long] = { - sc.runJob(rdd, (x: Iterator[T]) => 1) - } + test("Spark-8813 Combine small splits for table scan") { val partitionNum = 5 val partitionTable = "combine_small" sql("set hive.exec.dynamic.partition.mode=nonstrict") - sql("set spark.sql.mapper.splitCombineSize=1000000") - val df = (1 to 100).map { i => (i, i) }.toDF("a", "b").repartition(100) + val df = (1 to 100).map { i => (i, i) }.toDF("a", "b").coalesce(100) df.registerTempTable("temp") sql( s"""create table $partitionTable (a int, b string) @@ -108,13 +102,18 @@ class HiveTableScanSuite extends HiveComparisonTest { sql( s"""insert into table $partitionTable partition(c) |select a, b, (b % $partitionNum) as c from temp""".stripMargin) - val result = sql( s"""select * from $partitionTable order by a""") - val rddPartitions = getPartitions(TestHive.sparkContext, result.rdd) - // Ensure that there are only have 4 RDD partitions after combination - assert(rddPartitions.length == 5) - // Ensure that the result is the same as the original after combination + + // 
Check the num of RDD partition without the combination + assert(sql( s"""select * from $partitionTable""").rdd.getNumPartitions == 100) + + // Check the num of RDD partitions with the combination + sql("set spark.sql.mapper.splitCombineSize=10000") + assert(sql( + s"""select * from $partitionTable""").rdd.getNumPartitions == partitionNum) + + // Ensure that the result is the same as the original after the combination assert( - result.collect().map(_.toString()).deep + sql( s"""select * from $partitionTable order by a""").collect().map(_.toString()).deep == (1 to 100).map{i => s"[$i,$i,${i % partitionNum}]"}.toArray.deep ) } From 085ce5feca2294f81f9ec7a5660635be13c70a4a Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 24 Feb 2016 12:24:45 +0800 Subject: [PATCH 7/7] remove transient --- .../org/apache/spark/sql/hive/HadoopRDDwithCombination.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala index ba5e2d7f097c2..5fe4b51006a6b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HadoopRDDwithCombination.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.SerializableConfiguration class HadoopRDDwithCombination[K, V]( - @transient sc: SparkContext, + sc: SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]],
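
For reference, a minimal usage sketch of the feature added by this series. It assumes a HiveContext bound to `sqlContext` (as in HiveTableScanSuite) and a hypothetical partitioned Hive table `some_partitioned_table` made up of many small files; the 128 MB threshold below is only an example value, not a recommendation from the patch:

    // Illustrative sketch only -- not part of the patch series.
    // A negative value (the default, -1) leaves the underlying InputFormat's splits untouched;
    // a positive value wraps it in CombineSplitInputFormat, packing co-located splits together
    // until roughly this many bytes per task.
    sqlContext.sql("set spark.sql.mapper.splitCombineSize=134217728")
    val scanned = sqlContext.sql("select * from some_partitioned_table")
    // Expect far fewer RDD partitions than input files once combining is enabled;
    // reset the conf to -1 to compare against the uncombined scan.
    println(scanned.rdd.getNumPartitions)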