24 changes: 13 additions & 11 deletions core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -205,13 +205,22 @@ class HadoopRDD[K, V](
array
}

// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
protected def getBytesReadCallback(split: HadoopPartition): Option[() => Long] = {
split.inputSplit.value match {
case _: FileSplit | _: CombineFileSplit =>
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
case _ => None
}
}

override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new NextIterator[(K, V)] {

val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
val jobConf = getJobConf()

// TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD

val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
@@ -223,24 +232,17 @@ class HadoopRDD[K, V](
case _ => SqlNewHadoopRDDState.unsetInputFileName()
}

// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
case _: FileSplit | _: CombineFileSplit =>
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
case _ => None
}
val bytesReadCallback = getBytesReadCallback(split)

// For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
// If we do a coalesce, however, we are likely to compute multiple partitions in the same
// task and in the same thread, in which case we need to avoid overriding values written
// by previous partitions (SPARK-13071).
def updateBytesRead(): Unit = {
getBytesReadCallback.foreach { getBytesRead =>
bytesReadCallback.foreach { getBytesRead =>
inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
}
}

var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
@@ -285,7 +287,7 @@ class HadoopRDD[K, V](
} finally {
reader = null
}
if (getBytesReadCallback.isDefined) {
if (bytesReadCallback.isDefined) {
updateBytesRead()
} else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
6 changes: 6 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -507,6 +507,10 @@ private[spark] object SQLConf {
" method",
isPublic = false)

val MAPPER_SPLIT_COMBINE_SIZE = intConf(
"spark.sql.mapper.splitCombineSize",
defaultValue = Some(-1),
isPublic = true)

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -574,6 +578,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
private[spark] def subexpressionEliminationEnabled: Boolean =
getConf(SUBEXPRESSION_ELIMINATION_ENABLED)

private[spark] def mapperSplitCombineSize: Int = getConf(MAPPER_SPLIT_COMBINE_SIZE)

private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

private[spark] def defaultSizeInBytes: Long =
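A minimal usage sketch for the new flag (not part of this patch; the sqlContext name and values are placeholders): spark.sql.mapper.splitCombineSize defaults to -1, and a user would opt in by setting a positive size in bytes, which SQLConf.mapperSplitCombineSize then exposes to the read path.

// Hypothetical 1.x-style usage; 134217728 bytes = 128 MB.
sqlContext.setConf("spark.sql.mapper.splitCombineSize", (128 * 1024 * 1024).toString)
// Equivalent SQL form:
sqlContext.sql("SET spark.sql.mapper.splitCombineSize=134217728")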
CombineSplitInputFormat.java (new file)
@@ -0,0 +1,256 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.mapred;

import com.clearspring.analytics.util.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableFactories;
import org.apache.hadoop.mapred.*;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.*;

public class CombineSplitInputFormat<K, V> implements InputFormat<K, V> {

private InputFormat<K, V> delegate;
private long splitCombineSize = 0;

public CombineSplitInputFormat(InputFormat<K, V> inputformat, long sSize) {
this.delegate = inputformat;
this.splitCombineSize = sSize;
}

private CombineSplit createCombineSplit(
long totalLen,
Collection<String> locations,
List<InputSplit> combineSplitBuffer) {
return new CombineSplit(combineSplitBuffer.toArray(new InputSplit[0]),
totalLen, locations.toArray(new String[0]));
}

@Override
public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
InputSplit[] splits = delegate.getSplits(job, numSplits);
Map<String, List<InputSplit>> nodeToSplits = Maps.newHashMap();
Set<InputSplit> splitsSet = Sets.newHashSet();
for (InputSplit split: splits) {
for (String node: split.getLocations()) {
if (!nodeToSplits.containsKey(node)) {
nodeToSplits.put(node, new ArrayList<InputSplit>());
}
nodeToSplits.get(node).add(split);
}
splitsSet.add(split);
}
// Iterate over the nodes so that the combined splits are distributed evenly across them.
// Ideally, all splits within one combination should live on the same node.
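// Worked example (illustrative only, not from the patch): with splitCombineSize = 128 MB and
// four 64 MB splits that all report node "n1", the loop below emits two CombineSplits of
// 128 MB each, both located on "n1"; a split that was already consumed under one node is
// skipped when it reappears in another node's list.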
List<CombineSplit> combineSparkSplits = Lists.newArrayList();
List<InputSplit> combinedSplitBuffer = Lists.newArrayList();
long accumulatedSplitSize = 0L;
for (Map.Entry<String, List<InputSplit>> entry: nodeToSplits.entrySet()) {
String node = entry.getKey();
List<InputSplit> splitsPerNode = entry.getValue();
for (InputSplit split: splitsPerNode) {
// this split has already been combined under a previously visited node
if (!splitsSet.contains(split)) {
continue;
} else {
accumulatedSplitSize += split.getLength();
combinedSplitBuffer.add(split);
splitsSet.remove(split);
}
if (splitCombineSize > 0 && accumulatedSplitSize >= splitCombineSize) {
// TODO: optimize this by providing the second/third preference locations
combineSparkSplits.add(createCombineSplit(
accumulatedSplitSize, Collections.singleton(node), combinedSplitBuffer));
accumulatedSplitSize = 0;
combinedSplitBuffer.clear();
}
}
// pack this node's remaining splits into one final combined split
if (!combinedSplitBuffer.isEmpty()) {
long remainLen = 0;
for (InputSplit s: combinedSplitBuffer) {
remainLen += s.getLength();
}
combineSparkSplits.add(createCombineSplit(
remainLen, Collections.singleton(node), combinedSplitBuffer));
accumulatedSplitSize = 0;
combinedSplitBuffer.clear();
}
}
return combineSparkSplits.toArray(new InputSplit[0]);
}

@Override
public RecordReader<K, V> getRecordReader(final InputSplit split,
final JobConf jobConf, final Reporter reporter) throws IOException {
return new RecordReader<K, V>() {
protected int idx = 0;
protected long progressedBytes = 0;
protected RecordReader<K, V> curReader;
protected CombineSplit combineSplit;
{
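// Instance initializer: remember the combined split and open a reader for its first sub-split.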
combineSplit = (CombineSplit)split;
initNextRecordReader();
}

@Override
public boolean next(K key, V value) throws IOException {
while ((curReader == null) || !curReader.next(key, value)) {
if (!initNextRecordReader()) {
return false;
}
}
return true;
}

public K createKey() {
return curReader.createKey();
}

public V createValue() {
return curReader.createValue();
}

/**
* Return the number of bytes processed so far across the combined sub-splits.
*/
public long getPos() throws IOException {
return progressedBytes + curReader.getPos();
}

public void close() throws IOException {
if (curReader != null) {
curReader.close();
curReader = null;
}
}

/**
* return progress based on the amount of data processed so far.
*/
public float getProgress() throws IOException {
return Math.min(1.0f, progressedBytes / (float) split.getLength());
}

/**
* Get the record reader for the next split in this CombineSplit.
*/
protected boolean initNextRecordReader() throws IOException {

if (curReader != null) {
curReader.close();
curReader = null;
if (idx > 0) {
progressedBytes += combineSplit.getSplit(idx - 1).getLength(); // bytes from sub-splits fully processed so far
}
}

// if all splits have been processed, nothing more to do.
if (idx == combineSplit.getSplitNum()) {
return false;
}

// get a record reader for the idx-th split
try {
curReader = delegate.getRecordReader(combineSplit.getSplit(idx), jobConf, Reporter.NULL);
} catch (Exception e) {
throw new RuntimeException(e);
}
idx++;
return true;
}
};
}

public static class CombineSplit implements InputSplit {
private InputSplit[] splits;
private long totalLen;
private String[] locations;

public CombineSplit() {
}

public CombineSplit(InputSplit[] ss, long totalLen, String[] locations) {
splits = ss;
this.totalLen = totalLen;
this.locations = locations;
}

public InputSplit getSplit(int idx) {
return splits[idx];
}

public int getSplitNum() {
return splits.length;
}

@Override
public long getLength() {
return totalLen;
}

@Override
public String[] getLocations() throws IOException {
return locations;
}

@Override
public void write(DataOutput out) throws IOException {
// Splits are only combined within a single table partition,
// so all of the underlying splits are expected to be of the same class.
out.writeUTF(splits[0].getClass().getCanonicalName());
out.writeLong(totalLen);
out.writeInt(locations.length);
for (String location : locations) {
out.writeUTF(location);
}
out.writeInt(splits.length);
for (InputSplit split : splits) {
split.write(out);
}
}

@Override
public void readFields(DataInput in) throws IOException {
String className = in.readUTF();
this.totalLen = in.readLong();
this.locations = new String[in.readInt()];
for (int i = 0; i < locations.length; i++) {
locations[i] = in.readUTF();
}
splits = new InputSplit[in.readInt()];
Class<? extends Writable> clazz = null;
try {
clazz = (Class<? extends Writable>) Class.forName(className);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
for (int i = 0; i < splits.length; i++) {
Writable value = WritableFactories.newInstance(clazz, null);
value.readFields(in);
splits[i] = (InputSplit) value;
}
}
}
}
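For reference, a standalone sketch of driving the wrapper directly, assuming a mapred TextInputFormat delegate, a placeholder input path, and a 128 MB combine size (none of this wiring appears in the patch itself):

import org.apache.hadoop.mapred.{FileInputFormat, JobConf, TextInputFormat}
import org.apache.spark.sql.hive.mapred.CombineSplitInputFormat

val jobConf = new JobConf()
// Placeholder path; in the Hive read path this would come from the table/partition metadata.
FileInputFormat.setInputPaths(jobConf, "/tmp/example/table/partition")
val underlying = new TextInputFormat()
underlying.configure(jobConf) // TextInputFormat is JobConfigurable

// Coalesce the delegate's splits into roughly 128 MB combined splits.
val combined = new CombineSplitInputFormat(underlying, 128L * 1024 * 1024)
val splits = combined.getSplits(jobConf, 1) // the numSplits hint is forwarded to the delegate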