diff --git a/.rat-excludes b/.rat-excludes index 15344dfb292db..796c32a80896c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -4,6 +4,8 @@ target .classpath .mima-excludes .generated-mima-excludes +.generated-mima-class-excludes +.generated-mima-member-excludes .rat-excludes .*md derby.log diff --git a/bin/spark-class b/bin/spark-class index cfe363a71da31..60d9657c0ffcd 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -130,6 +130,11 @@ else fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then + if test -z "$SPARK_TOOLS_JAR"; then + echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 + echo "You need to build spark before running $1." 1>&2 + exit 1 + fi CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" fi diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index f7f853559468a..89eec7d4b7f61 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -7,5 +7,6 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: # Settings to quiet third party logs that are too verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/pom.xml b/core/pom.xml index bd6767e03bb9d..8c23842730e37 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -244,6 +244,11 @@ easymockclassextension test + + asm + asm + test + com.novocode junit-interface diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index f7f853559468a..89eec7d4b7f61 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -7,5 +7,6 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: # Settings to quiet third party logs that are too verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index cdfd338081fa2..9c55bfbb47626 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -127,7 +127,7 @@ class Accumulable[R, T] ( Accumulators.register(this, false) } - override def toString = value_.toString + override def toString = if (value_ == null) "null" else value_.toString } /** diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 315ed91f81df3..3f667a4a0f9c5 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -20,25 +20,25 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashSet} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockId, BlockManager, BlockStatus, RDDBlockId, StorageLevel} +import org.apache.spark.storage._ /** - * Spark class responsible for passing RDDs split contents to the BlockManager and making + * Spark class responsible for passing RDDs partition contents to the BlockManager and making * sure a node doesn't 
load two copies of an RDD at once. */ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { - /** Keys of RDD splits that are being computed/loaded. */ + /** Keys of RDD partitions that are being computed/loaded. */ private val loading = new HashSet[RDDBlockId]() - /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ + /** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */ def getOrCompute[T]( rdd: RDD[T], - split: Partition, + partition: Partition, context: TaskContext, storageLevel: StorageLevel): Iterator[T] = { - val key = RDDBlockId(rdd.id, split.index) + val key = RDDBlockId(rdd.id, partition.index) logDebug(s"Looking for partition $key") blockManager.get(key) match { case Some(values) => @@ -46,79 +46,28 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo(s"Another thread is loading $key, waiting for it to finish...") - while (loading.contains(key)) { - try { - loading.wait() - } catch { - case e: Exception => - logWarning(s"Got an exception while waiting for another thread to load $key", e) - } - } - logInfo(s"Finished waiting for $key") - /* See whether someone else has successfully loaded it. The main way this would fail - * is for the RDD-level cache eviction policy if someone else has loaded the same RDD - * partition but we didn't want to make space for it. However, that case is unlikely - * because it's unlikely that two threads would work on the same RDD partition. One - * downside of the current code is that threads wait serially if this does happen. */ - blockManager.get(key) match { - case Some(values) => - return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) - case None => - logInfo(s"Whoever was loading $key failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } + // Acquire a lock for loading this partition + // If another thread already holds the lock, wait for it to finish return its results + val storedValues = acquireLockForPartition[T](key) + if (storedValues.isDefined) { + return new InterruptibleIterator[T](context, storedValues.get) } + + // Otherwise, we have to load the partition ourselves try { - // If we got here, we have to load the split logInfo(s"Partition $key not found, computing it") - val computedValues = rdd.computeOrReadCheckpoint(split, context) + val computedValues = rdd.computeOrReadCheckpoint(partition, context) - // Persist the result, so long as the task is not running locally + // If the task is running locally, do not persist the result if (context.runningLocally) { return computedValues } - // Keep track of blocks with updated statuses - var updatedBlocks = Seq[(BlockId, BlockStatus)]() - val returnValue: Iterator[T] = { - if (storageLevel.useDisk && !storageLevel.useMemory) { - /* In the case that this RDD is to be persisted using DISK_ONLY - * the iterator will be passed directly to the blockManager (rather then - * caching it to an ArrayBuffer first), then the resulting block data iterator - * will be passed back to the user. If the iterator generates a lot of data, - * this means that it doesn't all have to be held in memory at one time. 
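For context, the coordination that `acquireLockForPartition` factors out of this method can be sketched in isolation: a set of in-flight keys guarded by `synchronized`/`wait`/`notifyAll`. This is an illustrative standalone sketch, not the actual CacheManager code; the class and method names are made up.

```scala
import scala.collection.mutable.HashSet

// Minimal sketch of the "loading set" pattern: the first thread to claim a key
// computes it; later threads block until the key is released and then re-check
// the store. Illustrative only, not a Spark API.
class LoadingGuard[K] {
  private val loading = new HashSet[K]()

  /** Returns true if the caller won the race and must compute the value itself. */
  def acquire(key: K): Boolean = loading.synchronized {
    if (loading.add(key)) {
      true
    } else {
      while (loading.contains(key)) {
        loading.wait() // woken up by notifyAll() in release()
      }
      false
    }
  }

  /** Called by the computing thread once the value has been stored (or the attempt failed). */
  def release(key: K): Unit = loading.synchronized {
    loading.remove(key)
    loading.notifyAll()
  }
}
```

The real implementation additionally re-checks the block manager after waking up, since the block may have been evicted by the time the waiting thread runs.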
- * This could also apply to MEMORY_ONLY_SER storage, but we need to make sure - * blocks aren't dropped by the block store before enabling that. */ - updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true) - blockManager.get(key) match { - case Some(values) => - values.asInstanceOf[Iterator[T]] - case None => - logInfo(s"Failure to store $key") - throw new SparkException("Block manager failed to return persisted value") - } - } else { - // In this case the RDD is cached to an array buffer. This will save the results - // if we're dealing with a 'one-time' iterator - val elements = new ArrayBuffer[Any] - elements ++= computedValues - updatedBlocks = blockManager.put(key, elements, storageLevel, tellMaster = true) - elements.iterator.asInstanceOf[Iterator[T]] - } - } - - // Update task metrics to include any blocks whose storage status is updated - val metrics = context.taskMetrics - metrics.updatedBlocks = Some(updatedBlocks) - - new InterruptibleIterator(context, returnValue) + // Otherwise, cache the values and keep track of any updates in block statuses + val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks) + context.taskMetrics.updatedBlocks = Some(updatedBlocks) + new InterruptibleIterator(context, cachedValues) } finally { loading.synchronized { @@ -128,4 +77,76 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { } } } + + /** + * Acquire a loading lock for the partition identified by the given block ID. + * + * If the lock is free, just acquire it and return None. Otherwise, another thread is already + * loading the partition, so we wait for it to finish and return the values loaded by the thread. + */ + private def acquireLockForPartition[T](id: RDDBlockId): Option[Iterator[T]] = { + loading.synchronized { + if (!loading.contains(id)) { + // If the partition is free, acquire its lock to compute its value + loading.add(id) + None + } else { + // Otherwise, wait for another thread to finish and return its result + logInfo(s"Another thread is loading $id, waiting for it to finish...") + while (loading.contains(id)) { + try { + loading.wait() + } catch { + case e: Exception => + logWarning(s"Exception while waiting for another thread to load $id", e) + } + } + logInfo(s"Finished waiting for $id") + val values = blockManager.get(id) + if (!values.isDefined) { + /* The block is not guaranteed to exist even after the other thread has finished. + * For instance, the block could be evicted after it was put, but before our get. + * In this case, we still need to load the partition ourselves. */ + logInfo(s"Whoever was loading $id failed; we'll try it ourselves") + loading.add(id) + } + values.map(_.asInstanceOf[Iterator[T]]) + } + } + } + + /** + * Cache the values of a partition, keeping track of any updates in the storage statuses + * of other blocks along the way. + */ + private def putInBlockManager[T]( + key: BlockId, + values: Iterator[T], + storageLevel: StorageLevel, + updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)]): Iterator[T] = { + + if (!storageLevel.useMemory) { + /* This RDD is not to be cached in memory, so we can just pass the computed values + * as an iterator directly to the BlockManager, rather than first fully unrolling + * it in memory. The latter option potentially uses much more memory and risks OOM + * exceptions that can be avoided. 
*/ + updatedBlocks ++= blockManager.put(key, values, storageLevel, tellMaster = true) + blockManager.get(key) match { + case Some(v) => v.asInstanceOf[Iterator[T]] + case None => + logInfo(s"Failure to store $key") + throw new BlockException(key, s"Block manager failed to return cached value for $key!") + } + } else { + /* This RDD is to be cached in memory. In this case we cannot pass the computed values + * to the BlockManager as an iterator and expect to read it back later. This is because + * we may end up dropping a partition from memory store before getting it back, e.g. + * when the entirety of the RDD does not fit in memory. */ + val elements = new ArrayBuffer[Any] + elements ++= values + updatedBlocks ++= blockManager.put(key, elements, storageLevel, tellMaster = true) + elements.iterator.asInstanceOf[Iterator[T]] + } + } + } diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bf3c3a6ceb5ef..9d7374774e9fa 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -150,7 +150,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { try { logDebug("Cleaning shuffle " + shuffleId) - mapOutputTrackerMaster.unregisterShuffle(shuffleId) + shuffleManager.unregisterShuffle(shuffleId) blockManagerMaster.removeShuffle(shuffleId, blocking) listeners.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) @@ -173,7 +173,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager - private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + private def shuffleManager = sc.env.shuffleManager // Used for testing. These methods explicitly blocks until cleanup is completed // to ensure that more reliable testing. 
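The branch inside `putInBlockManager` reduces to one rule: stream the iterator straight to the store when the storage level does not use memory, otherwise unroll it first so it can be replayed. A standalone sketch of that rule, assuming a generic `write` callback rather than Spark's BlockManager API:

```scala
// Illustrative sketch of the decision in putInBlockManager. `write` stands in for
// handing data to a block store; it is an assumed callback, not a Spark API.
def storeAndIterate[T](values: Iterator[T], useMemory: Boolean)
                      (write: TraversableOnce[Any] => Unit): Iterator[T] = {
  if (!useMemory) {
    // Non-memory level: stream the iterator through without unrolling; the caller
    // then re-reads the stored block to obtain a fresh iterator.
    write(values)
    Iterator.empty // placeholder for "re-fetch the stored block"
  } else {
    // Memory level: unroll first, because an in-memory block may be dropped before
    // it can be read back, and the original one-pass iterator cannot be replayed.
    val unrolled = values.toVector
    write(unrolled)
    unrolled.iterator
  }
}
```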
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index c8c194a111aac..09a60571238ea 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -61,7 +61,8 @@ class ShuffleDependency[K, V, C]( val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, - val aggregator: Option[Aggregator[K, V, C]] = None) + val aggregator: Option[Aggregator[K, V, C]] = None, + val mapSideCombine: Boolean = false) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 8dfa8cc4b5b3f..4699354c39582 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -56,7 +56,6 @@ class SparkEnv ( val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, - val mapOutputTracker: MapOutputTracker, val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, val blockManager: BlockManager, @@ -80,7 +79,6 @@ class SparkEnv ( private[spark] def stop() { pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() - mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() blockManager.stop() @@ -202,24 +200,17 @@ object SparkEnv extends Logging { } } - val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf) - } else { - new MapOutputTrackerWorker(conf) - } + val shuffleManager = instantiateClass[ShuffleManager]( + "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - // Have to assign trackerActor after initialization as MapOutputTrackerActor - // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + shuffleManager.initMapOutputTracker(conf, isDriver, actorSystem) val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker) + serializer, conf, securityManager, shuffleManager) val connectionManager = blockManager.connectionManager @@ -247,9 +238,6 @@ object SparkEnv extends Logging { "." } - val shuffleManager = instantiateClass[ShuffleManager]( - "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! 
Specify storage " + @@ -262,7 +250,6 @@ object SparkEnv extends Logging { serializer, closureSerializer, cacheManager, - mapOutputTracker, shuffleManager, broadcastManager, blockManager, diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a3074916d13e7..df42d679b4699 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -30,27 +30,69 @@ import org.apache.spark.storage.BlockManagerId @DeveloperApi sealed trait TaskEndReason +/** + * :: DeveloperApi :: + * Task succeeded. + */ @DeveloperApi case object Success extends TaskEndReason +/** + * :: DeveloperApi :: + * Various possible reasons why a task failed. + */ +@DeveloperApi +sealed trait TaskFailedReason extends TaskEndReason { + /** Error message displayed in the web UI. */ + def toErrorString: String +} + +/** + * :: DeveloperApi :: + * A [[org.apache.spark.scheduler.ShuffleMapTask]] that completed successfully earlier, but we + * lost the executor before the stage completed. This means Spark needs to reschedule the task + * to be re-executed on a different executor. + */ @DeveloperApi -case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it +case object Resubmitted extends TaskFailedReason { + override def toErrorString: String = "Resubmitted (resubmitted due to lost executor)" +} +/** + * :: DeveloperApi :: + * Task failed to fetch shuffle data from a remote node. Probably means we have lost the remote + * executors the task is trying to fetch from, and thus need to rerun the previous stage. + */ @DeveloperApi case class FetchFailed( - bmAddress: BlockManagerId, + bmAddress: BlockManagerId, // Note that bmAddress can be null shuffleId: Int, mapId: Int, reduceId: Int) - extends TaskEndReason + extends TaskFailedReason { + override def toErrorString: String = { + val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)" + } +} +/** + * :: DeveloperApi :: + * Task failed due to a runtime exception. This is the most common failure case and also captures + * user program exceptions. + */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], metrics: Option[TaskMetrics]) - extends TaskEndReason + extends TaskFailedReason { + override def toErrorString: String = { + val stackTraceString = if (stackTrace == null) "null" else stackTrace.mkString("\n") + s"$className ($description}\n$stackTraceString" + } +} /** * :: DeveloperApi :: @@ -58,10 +100,18 @@ case class ExceptionFailure( * it was fetched. */ @DeveloperApi -case object TaskResultLost extends TaskEndReason +case object TaskResultLost extends TaskFailedReason { + override def toErrorString: String = "TaskResultLost (result lost from block manager)" +} +/** + * :: DeveloperApi :: + * Task was killed intentionally and needs to be rescheduled. + */ @DeveloperApi -case object TaskKilled extends TaskEndReason +case object TaskKilled extends TaskFailedReason { + override def toErrorString: String = "TaskKilled (killed intentionally)" +} /** * :: DeveloperApi :: @@ -69,7 +119,9 @@ case object TaskKilled extends TaskEndReason * the task crashed the JVM. 
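With `TaskFailedReason` sealed and `toErrorString` defined on it, a consumer no longer needs to enumerate every concrete failure type to produce display text. A hedged sketch of such a consumer (not part of this patch):

```scala
import org.apache.spark.{Success, TaskEndReason, TaskFailedReason}

// Sketch of a consumer of the new hierarchy: Success plus the sealed TaskFailedReason
// trait cover every TaskEndReason, and toErrorString provides the web-UI-friendly text.
def describe(reason: TaskEndReason): String = reason match {
  case Success                  => "Task succeeded"
  case failed: TaskFailedReason => failed.toErrorString
}
```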
*/ @DeveloperApi -case object ExecutorLostFailure extends TaskEndReason +case object ExecutorLostFailure extends TaskFailedReason { + override def toErrorString: String = "ExecutorLostFailure (executor lost)" +} /** * :: DeveloperApi :: @@ -77,4 +129,6 @@ case object ExecutorLostFailure extends TaskEndReason * deserializing the task result. */ @DeveloperApi -case object UnknownReason extends TaskEndReason +case object UnknownReason extends TaskFailedReason { + override def toErrorString: String = "UnknownReason" +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 14fa9d8135afe..4f3081433a542 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + partitioner: Partitioner) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + numPartitions: Int) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions))) + /** Alias for cogroup. 
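The new Java-side overloads above delegate to the corresponding `rdd.cogroup(other1, other2, other3, ...)` methods. A hedged usage sketch of four-way cogroup in Scala, with made-up sample data:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._ // pair-RDD implicits in this Spark version

// Hedged usage sketch; `sc` is an existing SparkContext and the data is illustrative.
def cogroupExample(sc: SparkContext): Unit = {
  val a = sc.parallelize(Seq(1 -> "a1", 2 -> "a2"))
  val b = sc.parallelize(Seq(1 -> "b1"))
  val c = sc.parallelize(Seq(2 -> "c2"))
  val d = sc.parallelize(Seq(1 -> "d1", 3 -> "d3"))

  // For every key present in any of the four RDDs, we get one Iterable per RDD.
  val grouped = a.cogroup(b, c, d)
  grouped.collect().foreach { case (k, (as, bs, cs, ds)) =>
    println(s"$k -> ${as.toList} ${bs.toList} ${cs.toList} ${ds.toList}")
  }
}
```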
*/ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.groupWith(other))) @@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) + /** Alias for cogroup. */ + def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3))) + /** * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. @@ -786,6 +828,15 @@ object JavaPairRDD { .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) } + private[spark] + def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( + rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) + : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { + rddToPairRDDFunctions(rdd) + .mapValues(x => + (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + } + def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { new JavaPairRDD[K, V](rdd) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 330569a8d8837..f917cfd1419ec 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -43,8 +43,11 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] - /** Set of partitions in this RDD. */ + @deprecated("Use partitions() instead.", "1.1.0") def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + + /** Set of partitions in this RDD. */ + def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala new file mode 100644 index 0000000000000..a0e8bd403a41d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.history + +import org.apache.spark.ui.SparkUI + +private[spark] case class ApplicationHistoryInfo( + id: String, + name: String, + startTime: Long, + endTime: Long, + lastUpdated: Long, + sparkUser: String) + +private[spark] abstract class ApplicationHistoryProvider { + + /** + * Returns a list of applications available for the history server to show. + * + * @return List of all know applications. + */ + def getListing(): Seq[ApplicationHistoryInfo] + + /** + * Returns the Spark UI for a specific application. + * + * @param appId The application ID. + * @return The application's UI, or null if application is not found. + */ + def getAppUI(appId: String): SparkUI + + /** + * Called when the server is shutting down. + */ + def stop(): Unit = { } + + /** + * Returns configuration data to be shown in the History Server home page. + * + * @return A map with the configuration data. Data is show in the order returned by the map. + */ + def getConfig(): Map[String, String] = Map() + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala new file mode 100644 index 0000000000000..a8c9ac072449f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.io.FileNotFoundException + +import scala.collection.mutable + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.scheduler._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils + +private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider + with Logging { + + // Interval between each check for event log updates + private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", + conf.getInt("spark.history.updateInterval", 10)) * 1000 + + private val logDir = conf.get("spark.history.fs.logDirectory", null) + if (logDir == null) { + throw new IllegalArgumentException("Logging directory must be specified.") + } + + private val fs = Utils.getHadoopFileSystem(logDir) + + // A timestamp of when the disk was last accessed to check for log updates + private var lastLogCheckTimeMs = -1L + + // List of applications, in order from newest to oldest. + @volatile private var appList: Seq[ApplicationHistoryInfo] = Nil + + /** + * A background thread that periodically checks for event log updates on disk. 
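Since `ApplicationHistoryProvider` is the extension point selected by `spark.history.provider` (see the HistoryServer changes further down), a trivial alternative provider would look roughly like this. The class below is hypothetical and not part of the patch; a real implementation needs a `(SparkConf)` constructor and must be compiled into the same package to see these `private[spark]` types.

```scala
package org.apache.spark.deploy.history

import org.apache.spark.SparkConf
import org.apache.spark.ui.SparkUI

// Hypothetical provider sketch: serves a fixed, empty listing. Illustrative only.
private[spark] class InMemoryHistoryProvider(conf: SparkConf)
  extends ApplicationHistoryProvider {

  @volatile private var apps: Seq[ApplicationHistoryInfo] = Nil

  override def getListing(): Seq[ApplicationHistoryInfo] = apps

  // Returning null tells the HistoryServer the application was not found,
  // per the contract documented above.
  override def getAppUI(appId: String): SparkUI = null

  override def getConfig(): Map[String, String] =
    Map("Provider" -> "in-memory example (hypothetical)")
}
```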
+ * + * If a log check is invoked manually in the middle of a period, this thread re-adjusts the + * time at which it performs the next log check to maintain the same period as before. + * + * TODO: Add a mechanism to update manually. + */ + private val logCheckingThread = new Thread("LogCheckingThread") { + override def run() = Utils.logUncaughtExceptions { + while (true) { + val now = getMonotonicTimeMs() + if (now - lastLogCheckTimeMs > UPDATE_INTERVAL_MS) { + Thread.sleep(UPDATE_INTERVAL_MS) + } else { + // If the user has manually checked for logs recently, wait until + // UPDATE_INTERVAL_MS after the last check time + Thread.sleep(lastLogCheckTimeMs + UPDATE_INTERVAL_MS - now) + } + checkForLogs() + } + } + } + + initialize() + + private def initialize() { + // Validate the log directory. + val path = new Path(logDir) + if (!fs.exists(path)) { + throw new IllegalArgumentException( + "Logging directory specified does not exist: %s".format(logDir)) + } + if (!fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException( + "Logging directory specified is not a directory: %s".format(logDir)) + } + + checkForLogs() + logCheckingThread.setDaemon(true) + logCheckingThread.start() + } + + override def getListing() = appList + + override def getAppUI(appId: String): SparkUI = { + try { + val appLogDir = fs.getFileStatus(new Path(logDir, appId)) + loadAppInfo(appLogDir, true)._2 + } catch { + case e: FileNotFoundException => null + } + } + + override def getConfig(): Map[String, String] = + Map(("Event Log Location" -> logDir)) + + /** + * Builds the application list based on the current contents of the log directory. + * Tries to reuse as much of the data already in memory as possible, by not reading + * applications that haven't been updated since last time the logs were checked. + */ + private def checkForLogs() = { + lastLogCheckTimeMs = getMonotonicTimeMs() + logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) + try { + val logStatus = fs.listStatus(new Path(logDir)) + val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() + val logInfos = logDirs.filter { + dir => fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE)) + } + + val currentApps = Map[String, ApplicationHistoryInfo]( + appList.map(app => (app.id -> app)):_*) + + // For any application that either (i) is not listed or (ii) has changed since the last time + // the listing was created (defined by the log dir's modification time), load the app's info. + // Otherwise just reuse what's already in memory. + val newApps = new mutable.ArrayBuffer[ApplicationHistoryInfo](logInfos.size) + for (dir <- logInfos) { + val curr = currentApps.getOrElse(dir.getPath().getName(), null) + if (curr == null || curr.lastUpdated < getModificationTime(dir)) { + try { + newApps += loadAppInfo(dir, false)._1 + } catch { + case e: Exception => logError(s"Failed to load app info from directory $dir.") + } + } else { + newApps += curr + } + } + + appList = newApps.sortBy { info => -info.endTime } + } catch { + case t: Throwable => logError("Exception in checking for event log updates", t) + } + } + + /** + * Parse the application's logs to find out the information we need to build the + * listing page. + * + * When creating the listing of available apps, there is no need to load the whole UI for the + * application. The UI is requested by the HistoryServer (by calling getAppInfo()) when the user + * clicks on a specific application. 
+ * + * @param logDir Directory with application's log files. + * @param renderUI Whether to create the SparkUI for the application. + * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false. + */ + private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = { + val elogInfo = EventLoggingListener.parseLoggingInfo(logDir.getPath(), fs) + val path = logDir.getPath + val appId = path.getName + val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec) + val appListener = new ApplicationEventListener + replayBus.addListener(appListener) + + val ui: SparkUI = if (renderUI) { + val conf = this.conf.clone() + val appSecManager = new SecurityManager(conf) + new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId) + // Do not call ui.bind() to avoid creating a new server for each application + } else { + null + } + + replayBus.replay() + val appInfo = ApplicationHistoryInfo( + appId, + appListener.appName, + appListener.startTime, + appListener.endTime, + getModificationTime(logDir), + appListener.sparkUser) + + if (ui != null) { + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setUIAcls(uiAclsEnabled) + ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) + } + (appInfo, ui) + } + + /** Return when this directory was last modified. */ + private def getModificationTime(dir: FileStatus): Long = { + try { + val logFiles = fs.listStatus(dir.getPath) + if (logFiles != null && !logFiles.isEmpty) { + logFiles.map(_.getModificationTime).max + } else { + dir.getModificationTime + } + } catch { + case t: Throwable => + logError("Exception in accessing modification time of %s".format(dir.getPath), t) + -1L + } + } + + /** Returns the system's mononotically increasing time. */ + private def getMonotonicTimeMs() = System.nanoTime() / (1000 * 1000) + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 180c853ce3096..a958c837c2ff6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -25,20 +25,36 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { + private val pageSize = 20 + def render(request: HttpServletRequest): Seq[Node] = { - val appRows = parent.appIdToInfo.values.toSeq.sortBy { app => -app.lastUpdated } - val appTable = UIUtils.listingTable(appHeader, appRow, appRows) + val requestedPage = Option(request.getParameter("page")).getOrElse("1").toInt + val requestedFirst = (requestedPage - 1) * pageSize + + val allApps = parent.getApplicationList() + val actualFirst = if (requestedFirst < allApps.size) requestedFirst else 0 + val apps = allApps.slice(actualFirst, Math.min(actualFirst + pageSize, allApps.size)) + + val actualPage = (actualFirst / pageSize) + 1 + val last = Math.min(actualFirst + pageSize, allApps.size) - 1 + val pageCount = allApps.size / pageSize + (if (allApps.size % pageSize > 0) 1 else 0) + + val appTable = UIUtils.listingTable(appHeader, appRow, apps) + val providerConfig = parent.getProviderConfig() val content =
-            <li><strong>Event Log Location:</strong> {parent.baseLogDir}</li>
+            { providerConfig.map(e =>
+              <li><strong>{e._1}:</strong> {e._2}</li>
+            ) }
{ - if (parent.appIdToInfo.size > 0) { + if (allApps.size > 0) {

- Showing {parent.appIdToInfo.size}/{parent.getNumApplications} - Completed Application{if (parent.getNumApplications > 1) "s" else ""} + Showing {actualFirst + 1}-{last + 1} of {allApps.size} + + {if (actualPage > 1) <} + {if (actualPage < pageCount) >} +

++ appTable } else { @@ -56,26 +72,20 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Completed", "Duration", "Spark User", - "Log Directory", "Last Updated") private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - val appName = if (info.started) info.name else info.logDirPath.getName - val uiAddress = parent.getAddress + info.ui.basePath - val startTime = if (info.started) UIUtils.formatDate(info.startTime) else "Not started" - val endTime = if (info.completed) UIUtils.formatDate(info.endTime) else "Not completed" - val difference = if (info.started && info.completed) info.endTime - info.startTime else -1L - val duration = if (difference > 0) UIUtils.formatDuration(difference) else "---" - val sparkUser = if (info.started) info.sparkUser else "Unknown user" - val logDirectory = info.logDirPath.getName + val uiAddress = "/history/" + info.id + val startTime = UIUtils.formatDate(info.startTime) + val endTime = UIUtils.formatDate(info.endTime) + val duration = UIUtils.formatDuration(info.endTime - info.startTime) val lastUpdated = UIUtils.formatDate(info.lastUpdated) - {appName} + {info.name} {startTime} {endTime} {duration} - {sparkUser} - {logDirectory} + {info.sparkUser} {lastUpdated} } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a9c11dca5678e..29a78a56c8ed5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -17,14 +17,15 @@ package org.apache.spark.deploy.history -import scala.collection.mutable +import java.util.NoSuchElementException +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} -import org.apache.hadoop.fs.{FileStatus, Path} +import com.google.common.cache._ +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler._ -import org.apache.spark.ui.{WebUI, SparkUI} +import org.apache.spark.ui.{WebUI, SparkUI, UIUtils} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils @@ -38,56 +39,68 @@ import org.apache.spark.util.Utils * application's event logs are maintained in the application's own sub-directory. This * is the same structure as maintained in the event log write code path in * EventLoggingListener. 
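To make the paging arithmetic in `HistoryPage.render` above concrete, here is the same computation traced for one illustrative case (45 completed applications, request `?page=3`, `pageSize = 20`):

```scala
// Worked example of the paging math above; the values are illustrative.
val pageSize = 20
val totalApps = 45                                                        // allApps.size
val requestedPage = 3                                                     // from ?page=3
val requestedFirst = (requestedPage - 1) * pageSize                       // 40
val actualFirst = if (requestedFirst < totalApps) requestedFirst else 0   // 40
val last = Math.min(actualFirst + pageSize, totalApps) - 1                // 44
val actualPage = (actualFirst / pageSize) + 1                             // 3
val pageCount = totalApps / pageSize + (if (totalApps % pageSize > 0) 1 else 0) // 3
// Header rendered: "Showing 41-45 of 45"; no ">" link since actualPage == pageCount.
```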
- * - * @param baseLogDir The base directory in which event logs are found */ class HistoryServer( - val baseLogDir: String, + conf: SparkConf, + provider: ApplicationHistoryProvider, securityManager: SecurityManager, - conf: SparkConf) - extends WebUI(securityManager, HistoryServer.WEB_UI_PORT, conf) with Logging { - - import HistoryServer._ + port: Int) + extends WebUI(securityManager, port, conf) with Logging { - private val fileSystem = Utils.getHadoopFileSystem(baseLogDir) - private val localHost = Utils.localHostName() - private val publicHost = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHost) + // How many applications to retain + private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) - // A timestamp of when the disk was last accessed to check for log updates - private var lastLogCheckTime = -1L + private val appLoader = new CacheLoader[String, SparkUI] { + override def load(key: String): SparkUI = { + val ui = provider.getAppUI(key) + if (ui == null) { + throw new NoSuchElementException() + } + attachSparkUI(ui) + ui + } + } - // Number of completed applications found in this directory - private var numCompletedApplications = 0 + private val appCache = CacheBuilder.newBuilder() + .maximumSize(retainedApplications) + .removalListener(new RemovalListener[String, SparkUI] { + override def onRemoval(rm: RemovalNotification[String, SparkUI]) = { + detachSparkUI(rm.getValue()) + } + }) + .build(appLoader) + + private val loaderServlet = new HttpServlet { + protected override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { + val parts = Option(req.getPathInfo()).getOrElse("").split("/") + if (parts.length < 2) { + res.sendError(HttpServletResponse.SC_BAD_REQUEST, + s"Unexpected path info in request (URI = ${req.getRequestURI()}") + return + } - @volatile private var stopped = false + val appId = parts(1) - /** - * A background thread that periodically checks for event log updates on disk. - * - * If a log check is invoked manually in the middle of a period, this thread re-adjusts the - * time at which it performs the next log check to maintain the same period as before. - * - * TODO: Add a mechanism to update manually. - */ - private val logCheckingThread = new Thread { - override def run(): Unit = Utils.logUncaughtExceptions { - while (!stopped) { - val now = System.currentTimeMillis - if (now - lastLogCheckTime > UPDATE_INTERVAL_MS) { - checkForLogs() - Thread.sleep(UPDATE_INTERVAL_MS) - } else { - // If the user has manually checked for logs recently, wait until - // UPDATE_INTERVAL_MS after the last check time - Thread.sleep(lastLogCheckTime + UPDATE_INTERVAL_MS - now) + // Note we don't use the UI retrieved from the cache; the cache loader above will register + // the app's UI, and all we need to do is redirect the user to the same URI that was + // requested, and the proper data should be served at that point. + try { + appCache.get(appId) + res.sendRedirect(res.encodeRedirectURL(req.getRequestURI())) + } catch { + case e: Exception => e.getCause() match { + case nsee: NoSuchElementException => + val msg =
<div class="row-fluid">Application {appId} not found.</div>
+ res.setStatus(HttpServletResponse.SC_NOT_FOUND) + UIUtils.basicSparkPage(msg, "Not Found").foreach( + n => res.getWriter().write(n.toString)) + + case cause: Exception => throw cause } } } } - // A mapping of application ID to its history information, which includes the rendered UI - val appIdToInfo = mutable.HashMap[String, ApplicationHistoryInfo]() - initialize() /** @@ -98,108 +111,23 @@ class HistoryServer( */ def initialize() { attachPage(new HistoryPage(this)) - attachHandler(createStaticHandler(STATIC_RESOURCE_DIR, "/static")) + attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + + val contextHandler = new ServletContextHandler + contextHandler.setContextPath("/history") + contextHandler.addServlet(new ServletHolder(loaderServlet), "/*") + attachHandler(contextHandler) } /** Bind to the HTTP server behind this web interface. */ override def bind() { super.bind() - logCheckingThread.start() - } - - /** - * Check for any updates to event logs in the base directory. This is only effective once - * the server has been bound. - * - * If a new completed application is found, the server renders the associated SparkUI - * from the application's event logs, attaches this UI to itself, and stores metadata - * information for this application. - * - * If the logs for an existing completed application are no longer found, the server - * removes all associated information and detaches the SparkUI. - */ - def checkForLogs() = synchronized { - if (serverInfo.isDefined) { - lastLogCheckTime = System.currentTimeMillis - logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTime)) - try { - val logStatus = fileSystem.listStatus(new Path(baseLogDir)) - val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() - val logInfos = logDirs - .sortBy { dir => getModificationTime(dir) } - .map { dir => (dir, EventLoggingListener.parseLoggingInfo(dir.getPath, fileSystem)) } - .filter { case (dir, info) => info.applicationComplete } - - // Logging information for applications that should be retained - val retainedLogInfos = logInfos.takeRight(RETAINED_APPLICATIONS) - val retainedAppIds = retainedLogInfos.map { case (dir, _) => dir.getPath.getName } - - // Remove any applications that should no longer be retained - appIdToInfo.foreach { case (appId, info) => - if (!retainedAppIds.contains(appId)) { - detachSparkUI(info.ui) - appIdToInfo.remove(appId) - } - } - - // Render the application's UI if it is not already there - retainedLogInfos.foreach { case (dir, info) => - val appId = dir.getPath.getName - if (!appIdToInfo.contains(appId)) { - renderSparkUI(dir, info) - } - } - - // Track the total number of completed applications observed this round - numCompletedApplications = logInfos.size - - } catch { - case e: Exception => logError("Exception in checking for event log updates", e) - } - } else { - logWarning("Attempted to check for event log updates before binding the server.") - } - } - - /** - * Render a new SparkUI from the event logs if the associated application is completed. - * - * HistoryServer looks for a special file that indicates application completion in the given - * directory. If this file exists, the associated application is regarded to be completed, in - * which case the server proceeds to render the SparkUI. Otherwise, the server does nothing. 
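The `appCache` above is the heart of the new lazy-loading design: a Guava `LoadingCache` builds and attaches an application UI on first access and detaches it when the entry is evicted, bounded by `spark.history.retainedApplications`. The same pattern in miniature, with plain strings standing in for `SparkUI` instances:

```scala
import com.google.common.cache.{CacheBuilder, CacheLoader, RemovalListener, RemovalNotification}

// Miniature of the load-on-miss / detach-on-eviction pattern used above.
// Strings stand in for SparkUI; attach/detach stand in for attachSparkUI/detachSparkUI.
val loader = new CacheLoader[String, String] {
  override def load(appId: String): String = {
    val ui = s"UI for $appId"   // in the real code: provider.getAppUI(appId)
    println(s"attach $ui")      // in the real code: attachSparkUI(ui)
    ui
  }
}

val cache = CacheBuilder.newBuilder()
  .maximumSize(2)               // the real default is spark.history.retainedApplications = 50
  .removalListener(new RemovalListener[String, String] {
    override def onRemoval(rm: RemovalNotification[String, String]): Unit =
      println(s"detach ${rm.getValue}")   // in the real code: detachSparkUI(rm.getValue())
  })
  .build(loader)

cache.get("app-1")   // miss: load() runs and the UI is attached
cache.get("app-1")   // hit: served from the cache
```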
- */ - private def renderSparkUI(logDir: FileStatus, elogInfo: EventLoggingInfo) { - val path = logDir.getPath - val appId = path.getName - val replayBus = new ReplayListenerBus(elogInfo.logPaths, fileSystem, elogInfo.compressionCodec) - val appListener = new ApplicationEventListener - replayBus.addListener(appListener) - val appConf = conf.clone() - val appSecManager = new SecurityManager(appConf) - val ui = new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId) - - // Do not call ui.bind() to avoid creating a new server for each application - replayBus.replay() - if (appListener.applicationStarted) { - appSecManager.setUIAcls(HISTORY_UI_ACLS_ENABLED) - appSecManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) - attachSparkUI(ui) - val appName = appListener.appName - val sparkUser = appListener.sparkUser - val startTime = appListener.startTime - val endTime = appListener.endTime - val lastUpdated = getModificationTime(logDir) - ui.setAppName(appName + " (completed)") - appIdToInfo(appId) = ApplicationHistoryInfo(appId, appName, startTime, endTime, - lastUpdated, sparkUser, path, ui) - } } /** Stop the server and close the file system. */ override def stop() { super.stop() - stopped = true - fileSystem.close() + provider.stop() } /** Attach a reconstructed UI to this server. Only valid after bind(). */ @@ -215,27 +143,20 @@ class HistoryServer( ui.getHandlers.foreach(detachHandler) } - /** Return the address of this server. */ - def getAddress: String = "http://" + publicHost + ":" + boundPort + /** + * Returns a list of available applications, in descending order according to their end time. + * + * @return List of all known applications. + */ + def getApplicationList() = provider.getListing() - /** Return the number of completed applications found, whether or not the UI is rendered. */ - def getNumApplications: Int = numCompletedApplications + /** + * Returns the provider configuration to show in the listing page. + * + * @return A map with the provider's configuration. + */ + def getProviderConfig() = provider.getConfig() - /** Return when this directory was last modified. 
*/ - private def getModificationTime(dir: FileStatus): Long = { - try { - val logFiles = fileSystem.listStatus(dir.getPath) - if (logFiles != null && !logFiles.isEmpty) { - logFiles.map(_.getModificationTime).max - } else { - dir.getModificationTime - } - } catch { - case e: Exception => - logError("Exception in accessing modification time of %s".format(dir.getPath), e) - -1L - } - } } /** @@ -251,30 +172,31 @@ class HistoryServer( object HistoryServer { private val conf = new SparkConf - // Interval between each check for event log updates - val UPDATE_INTERVAL_MS = conf.getInt("spark.history.updateInterval", 10) * 1000 - - // How many applications to retain - val RETAINED_APPLICATIONS = conf.getInt("spark.history.retainedApplications", 250) - - // The port to which the web UI is bound - val WEB_UI_PORT = conf.getInt("spark.history.ui.port", 18080) - - // set whether to enable or disable view acls for all applications - val HISTORY_UI_ACLS_ENABLED = conf.getBoolean("spark.history.ui.acls.enable", false) - - val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR - def main(argStrings: Array[String]) { initSecurity() - val args = new HistoryServerArguments(argStrings) + val args = new HistoryServerArguments(conf, argStrings) val securityManager = new SecurityManager(conf) - val server = new HistoryServer(args.logDir, securityManager, conf) + + val providerName = conf.getOption("spark.history.provider") + .getOrElse(classOf[FsHistoryProvider].getName()) + val provider = Class.forName(providerName) + .getConstructor(classOf[SparkConf]) + .newInstance(conf) + .asInstanceOf[ApplicationHistoryProvider] + + val port = conf.getInt("spark.history.ui.port", 18080) + + val server = new HistoryServer(conf, provider, securityManager, port) server.bind() + Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { + override def run() = { + server.stop() + } + }) + // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } - server.stop() } def initSecurity() { @@ -291,17 +213,3 @@ object HistoryServer { } } - - -private[spark] case class ApplicationHistoryInfo( - id: String, - name: String, - startTime: Long, - endTime: Long, - lastUpdated: Long, - sparkUser: String, - logDirPath: Path, - ui: SparkUI) { - def started = startTime != -1 - def completed = endTime != -1 -} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 943c061743dbd..be9361b754fc3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -17,17 +17,14 @@ package org.apache.spark.deploy.history -import java.net.URI - -import org.apache.hadoop.fs.Path - +import org.apache.spark.SparkConf import org.apache.spark.util.Utils /** * Command-line parser for the master. 
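A hedged example of setting the configuration keys introduced in this patch programmatically; the provider class refers to the hypothetical sketch earlier, and in a real deployment these would normally be passed as JVM system properties when launching the HistoryServer:

```scala
import org.apache.spark.SparkConf

// Illustrative configuration only; the directory path and provider class name are made up.
val conf = new SparkConf()
  .set("spark.history.provider", "org.apache.spark.deploy.history.InMemoryHistoryProvider")
  .set("spark.history.retainedApplications", "50")
  .set("spark.history.ui.port", "18080")
  .set("spark.history.fs.logDirectory", "/tmp/spark-events") // required by the default FsHistoryProvider
```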
*/ -private[spark] class HistoryServerArguments(args: Array[String]) { - var logDir = "" +private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { + private var logDir: String = null parse(args.toList) @@ -45,32 +42,36 @@ private[spark] class HistoryServerArguments(args: Array[String]) { case _ => printUsageAndExit(1) } - validateLogDir() - } - - private def validateLogDir() { - if (logDir == "") { - System.err.println("Logging directory must be specified.") - printUsageAndExit(1) - } - val fileSystem = Utils.getHadoopFileSystem(new URI(logDir)) - val path = new Path(logDir) - if (!fileSystem.exists(path)) { - System.err.println("Logging directory specified does not exist: %s".format(logDir)) - printUsageAndExit(1) - } - if (!fileSystem.getFileStatus(path).isDir) { - System.err.println("Logging directory specified is not a directory: %s".format(logDir)) - printUsageAndExit(1) + if (logDir != null) { + conf.set("spark.history.fs.logDirectory", logDir) } } private def printUsageAndExit(exitCode: Int) { System.err.println( - "Usage: HistoryServer [options]\n" + - "\n" + - "Options:\n" + - " -d DIR, --dir DIR Location of event log files") + """ + |Usage: HistoryServer + | + |Configuration options can be set by setting the corresponding JVM system property. + |History Server options are always available; additional options depend on the provider. + | + |History Server options: + | + | spark.history.ui.port Port where server will listen for connections + | (default 18080) + | spark.history.acls.enable Whether to enable view acls for all applications + | (default false) + | spark.history.provider Name of history provider class (defaults to + | file system-based provider) + | spark.history.retainedApplications Max number of application UIs to keep loaded in memory + | (default 50) + |FsHistoryProvider options: + | + | spark.history.fs.logDirectory Directory where app logs are stored (required) + | spark.history.fs.updateInterval How often to reload log data from storage (in seconds, + | default 10) + |""".stripMargin) System.exit(exitCode) } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 6433aac1c23e0..467317dd9b44c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -77,6 +77,7 @@ private[spark] class ExecutorRunner( * @param message the exception message which caused the executor's death */ private def killProcess(message: Option[String]) { + var exitCode: Option[Int] = None if (process != null) { logInfo("Killing process!") process.destroy() @@ -87,9 +88,9 @@ private[spark] class ExecutorRunner( if (stderrAppender != null) { stderrAppender.stop() } - val exitCode = process.waitFor() - worker ! ExecutorStateChanged(appId, execId, state, message, Some(exitCode)) + exitCode = Some(process.waitFor()) } + worker ! 
ExecutorStateChanged(appId, execId, state, message, exitCode) } /** Stop this executor runner, including killing the process it launched */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 6a5ffb1b71bfb..b389cb546de6c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -120,7 +120,7 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w
- UIUtils.basicSparkPage(content, logType + " log page for " + appId) + UIUtils.basicSparkPage(content, logType + " log page for " + appId.getOrElse("unknown app")) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 2279d77c91c89..b5fd334f40203 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,25 +19,26 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import akka.actor._ -import akka.remote._ +import scala.concurrent.Await -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import akka.actor.{Actor, ActorSelection, Props} +import akka.pattern.Patterns +import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} + +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, - cores: Int) - extends Actor - with ExecutorBackend - with Logging { + cores: Int, + sparkProperties: Seq[(String, String)]) extends Actor with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") @@ -52,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend( } override def receive = { - case RegisteredExecutor(sparkProperties) => + case RegisteredExecutor => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, @@ -101,26 +102,33 @@ private[spark] object CoarseGrainedExecutorBackend { workerUrl: Option[String]) { SparkHadoopUtil.get.runAsSparkUser { () => - // Debug code - Utils.checkHost(hostname) - - val conf = new SparkConf - // Create a new ActorSystem to run the backend, because we can't create a - // SparkEnv / Executor before getting started with all our system properties, etc - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - conf, new SecurityManager(conf)) - // set it - val sparkHostPort = hostname + ":" + boundPort - actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, - sparkHostPort, cores), - name = "Executor") - workerUrl.foreach { - url => - actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") - } - actorSystem.awaitTermination() - + // Debug code + Utils.checkHost(hostname) + + // Bootstrap to fetch the driver's Spark properties. 
+ val executorConf = new SparkConf + val (fetcher, _) = AkkaUtils.createActorSystem( + "driverPropsFetcher", hostname, 0, executorConf, new SecurityManager(executorConf)) + val driver = fetcher.actorSelection(driverUrl) + val timeout = AkkaUtils.askTimeout(executorConf) + val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) + val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] + fetcher.shutdown() + + // Create a new ActorSystem using driver's Spark properties to run the backend. + val driverConf = new SparkConf().setAll(props) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "sparkExecutor", hostname, 0, driverConf, new SecurityManager(driverConf)) + // set it + val sparkHostPort = hostname + ":" + boundPort + actorSystem.actorOf( + Props(classOf[CoarseGrainedExecutorBackend], + driverUrl, executorId, sparkHostPort, cores, props), + name = "Executor") + workerUrl.foreach { url => + actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + } + actorSystem.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index baee7a216a7c3..2debc4d75a546 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,8 +26,8 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util.{AkkaUtils, Utils} @@ -180,7 +180,7 @@ private[spark] class Executor( attemptedTask = Some(task) logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) + env.shuffleManager.mapOutputTracker.updateEpoch(task.epoch) // Run the actual task and measure its runtime. 
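The Executor hunk continuing just below also swaps the hard-coded 1024-byte reserve for AkkaUtils.reservedSizeBytes when deciding whether a task result is shipped directly over Akka or staged in the BlockManager. A hedged sketch of that decision; the helper name and the 200 KB figure are illustrative, not Spark's API:

// Send directly only if the serialized result fits under the frame size minus the reserve;
// otherwise store it as a block and send back just the block id.
def sendDirectly(resultSize: Int, akkaFrameSize: Int, reservedBytes: Int): Boolean =
  resultSize < akkaFrameSize - reservedBytes

// With a 10 MB frame and 200 KB reserved, an 11 MB result goes the indirect route:
sendDirectly(resultSize = 11 * 1024 * 1024,
             akkaFrameSize = 10 * 1024 * 1024,
             reservedBytes = 200 * 1024)  // false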
taskStart = System.currentTimeMillis() @@ -212,7 +212,7 @@ private[spark] class Executor( val serializedDirectResult = ser.serialize(directResult) logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit) val serializedResult = { - if (serializedDirectResult.limit >= akkaFrameSize - 1024) { + if (serializedDirectResult.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { logInfo("Storing result for " + taskId + " in local BlockManager") val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala index 3d34960653f5d..e07cb31cbe4ba 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -27,3 +27,4 @@ import org.apache.spark.TaskState.TaskState private[spark] trait ExecutorBackend { def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) } + diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 3b6298a26d7c5..5285ec82c1b64 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -17,11 +17,6 @@ package org.apache.spark.network -import org.apache.spark._ -import org.apache.spark.SparkSaslServer - -import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} - import java.net._ import java.nio._ import java.nio.channels._ @@ -41,7 +36,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_) + channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_) } channel.configureBlocking(false) @@ -89,7 +84,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private def disposeSasl() { if (sparkSaslServer != null) { - sparkSaslServer.dispose(); + sparkSaslServer.dispose() } if (sparkSaslClient != null) { @@ -328,15 +323,13 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Is highly unlikely unless there was an unclean close of socket, etc registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) callOnExceptionCallback(e) - // ignore - return true } } + true } override def write(): Boolean = { @@ -546,7 +539,7 @@ private[spark] class ReceivingConnection( /* println("Filled buffer at " + System.currentTimeMillis) */ val bufferMessage = inbox.getMessageForChunk(currentChunk).get if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip + bufferMessage.flip() bufferMessage.finishTime = System.currentTimeMillis logDebug("Finished receiving [" + bufferMessage + "] from " + "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index cf1c985c2fff9..8a1cdb812962e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ 
b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -249,7 +249,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, def run() { try { while(!selectorThread.isInterrupted) { - while (! registerRequests.isEmpty) { + while (!registerRequests.isEmpty) { val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) conn.connect() @@ -308,7 +308,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, // Some keys within the selectors list are invalid/closed. clear them. val allKeys = selector.keys().iterator() - while (allKeys.hasNext()) { + while (allKeys.hasNext) { val key = allKeys.next() try { if (! key.isValid) { @@ -341,7 +341,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, if (0 != selectedKeysCount) { val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext()) { + while (selectedKeys.hasNext) { val key = selectedKeys.next selectedKeys.remove() try { @@ -419,62 +419,63 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, connectionsByKey -= connection.key try { - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - connectionsAwaitingSasl -= connection.connectionId + connection match { + case sendingConnection: SendingConnection => + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + connectionsAwaitingSasl -= connection.connectionId + + messageStatuses.synchronized { + messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) + .foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.markDone() + } + }) - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId }) + } + case receivingConnection: ReceivingConnection => + val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() + logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (! 
sendingConnectionOpt.isDefined) { - logError("Corresponding SendingConnectionManagerId not found") - return - } + val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) + if (!sendingConnectionOpt.isDefined) { + logError("Corresponding SendingConnectionManagerId not found") + return + } - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() + val sendingConnection = sendingConnectionOpt.get + connectionsById -= remoteConnectionManagerId + sendingConnection.close() - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - assert (sendingConnectionManagerId == remoteConnectionManagerId) + assert(sendingConnectionManagerId == remoteConnectionManagerId) - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() + messageStatuses.synchronized { + for (s <- messageStatuses.values + if s.connectionManagerId == sendingConnectionManagerId) { + logInfo("Notifying " + s) + s.synchronized { + s.attempted = true + s.acked = false + s.markDone() + } } - } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } + case _ => logError("Unsupported type of connection.") } } finally { // So that the selection keys can be removed. @@ -517,13 +518,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, logDebug("Client sasl completed for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId waitingConn.getAuthenticated().synchronized { - waitingConn.getAuthenticated().notifyAll(); + waitingConn.getAuthenticated().notifyAll() } return } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken); + replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -533,7 +534,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, return } val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId.toString()) + securityMsg.getConnectionId.toString) val message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) @@ -630,13 +631,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, case bufferMessage: BufferMessage => { if (authEnabled) { val res = handleAuthentication(connection, bufferMessage) - if (res == true) { + if (res) { // message was security negotiation so skip the rest logDebug("After handleAuth result was true, returning") return } } - if (bufferMessage.hasAckId) { + if (bufferMessage.hasAckId()) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { @@ -646,7 +647,6 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, case None => { throw new Exception("Could not find reference for received ack 
message " + message.id) - null } } } @@ -668,7 +668,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, if (ackMessage.isDefined) { if (!ackMessage.get.isInstanceOf[BufferMessage]) { logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass()) + + ackMessage.get.getClass) } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { logDebug("Response to " + bufferMessage + " does not have ack id set") ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala index b82edb6850d23..57f7586883af1 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala @@ -32,6 +32,6 @@ private[spark] case class ConnectionManagerId(host: String, port: Int) { private[spark] object ConnectionManagerId { def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6a3f698444283..f1f4b4324edfd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -57,7 +57,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, */ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { val part = new RangePartitioner(numPartitions, self, ascending) - val shuffled = new ShuffledRDD[K, V, P](self, part) + val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering) shuffled.mapPartitions(iter => { val buf = iter.toArray if (ascending) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index fe36c80e0be84..fc9beb166befe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -90,21 +90,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) self.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) - } else if (mapSideCombine) { - val combined = self.mapPartitionsWithContext((context, iter) => { - aggregator.combineValuesByKey(iter, context) - }, preservesPartitioning = true) - val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) - .setSerializer(serializer) - partitioned.mapPartitionsWithContext((context, iter) => { - new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context)) - }, preservesPartitioning = true) } else { - // Don't apply map-side combiner. 
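The PairRDDFunctions.combineByKey rewrite in this hunk (continuing just below) drops the hand-rolled map-side combine and instead builds a single ShuffledRDD carrying the aggregator and the mapSideCombine flag. A hedged usage sketch; it assumes a running SparkContext `sc` and, for Spark of this vintage, `import org.apache.spark.SparkContext._` for the pair-RDD implicits:

// reduceByKey funnels through the combineByKey path changed here, so the aggregation
// now travels with the ShuffledRDD itself.
val counts = sc.parallelize(Seq("a" -> 1, "b" -> 1, "a" -> 1)).reduceByKey(_ + _)
counts.collect()  // Array((a,2), (b,1)), in some order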
- val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer) - values.mapPartitionsWithContext((context, iter) => { - new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) - }, preservesPartitioning = true) + new ShuffledRDD[K, V, C, (K, C)](self, partitioner) + .setSerializer(serializer) + .setAggregator(aggregator) + .setMapSideCombine(mapSideCombine) } } @@ -401,7 +391,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (self.partitioner == Some(partitioner)) { self } else { - new ShuffledRDD[K, V, (K, V)](self, partitioner) + new ShuffledRDD[K, V, V, (K, V)](self, partitioner) } } @@ -567,6 +557,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) new FlatMappedValuesRDD(self, cleanF) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + partitioner: Partitioner) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner) + cg.mapValues { case Seq(vs, w1s, w2s, w3s) => + (vs.asInstanceOf[Seq[V]], + w1s.asInstanceOf[Seq[W1]], + w2s.asInstanceOf[Seq[W2]], + w3s.asInstanceOf[Seq[W3]]) + } + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -599,6 +611,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -633,6 +655,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, new HashPartitioner(numPartitions)) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + numPartitions: Int) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, new HashPartitioner(numPartitions)) + } + /** Alias for cogroup. */ def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner(self, other)) @@ -644,6 +679,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } + /** Alias for cogroup. 
*/ + def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * Return an RDD with the pairs from `this` whose keys are not in `other`. * @@ -721,7 +762,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) outputFormatClass: Class[_ <: NewOutputFormat[_, _]], conf: Configuration = self.context.hadoopConfiguration) { - val job = new NewAPIHadoopJob(conf) + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). + val hadoopConf = conf + val job = new NewAPIHadoopJob(hadoopConf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) @@ -754,22 +797,25 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) outputFormatClass: Class[_ <: OutputFormat[_, _]], conf: JobConf = new JobConf(self.context.hadoopConfiguration), codec: Option[Class[_ <: CompressionCodec]] = None) { - conf.setOutputKeyClass(keyClass) - conf.setOutputValueClass(valueClass) + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). + val hadoopConf = conf + hadoopConf.setOutputKeyClass(keyClass) + hadoopConf.setOutputValueClass(valueClass) // Doesn't work in Scala 2.9 due to what may be a generics bug // TODO: Should we uncomment this for Scala 2.10? // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", outputFormatClass.getName) + hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) for (c <- codec) { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + hadoopConf.setCompressMapOutput(true) + hadoopConf.set("mapred.output.compress", "true") + hadoopConf.setMapOutputCompressorClass(c) + hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) + hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) - saveAsHadoopDataset(conf) + hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath(hadoopConf, + SparkHadoopWriter.createPathFromString(path, hadoopConf)) + saveAsHadoopDataset(hadoopConf) } /** @@ -779,7 +825,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * configured for a Hadoop MapReduce job. */ def saveAsNewAPIHadoopDataset(conf: Configuration) { - val job = new NewAPIHadoopJob(conf) + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). + val hadoopConf = conf + val job = new NewAPIHadoopJob(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id @@ -836,9 +884,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * MapReduce job. */ def saveAsHadoopDataset(conf: JobConf) { - val outputFormatInstance = conf.getOutputFormat - val keyClass = conf.getOutputKeyClass - val valueClass = conf.getOutputValueClass + // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
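A hedged usage sketch for the four-way cogroup/groupWith overloads added above; it assumes a running SparkContext `sc` and the pair-RDD implicits in scope:

val a = sc.parallelize(Seq(1 -> "a1", 2 -> "a2"))
val b = sc.parallelize(Seq(1 -> "b1"))
val c = sc.parallelize(Seq(2 -> "c1"))
val d = sc.parallelize(Seq(1 -> "d1", 2 -> "d2"))

// For each key, one Iterable of values from each of the four RDDs.
val grouped = a.cogroup(b, c, d)
grouped.collect().foreach(println)
// prints, roughly, (1,([a1],[b1],[],[d1])) and (2,([a2],[],[c1],[d2]));
// the exact collection types in the output may differ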
+ val hadoopConf = conf + val outputFormatInstance = hadoopConf.getOutputFormat + val keyClass = hadoopConf.getOutputKeyClass + val valueClass = hadoopConf.getOutputValueClass if (outputFormatInstance == null) { throw new SparkException("Output format class not set") } @@ -848,18 +898,18 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (valueClass == null) { throw new SparkException("Output value class not set") } - SparkHadoopUtil.get.addCredentials(conf) + SparkHadoopUtil.get.addCredentials(hadoopConf) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(conf) - conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) + val ignoredFs = FileSystem.get(hadoopConf) + hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) } - val writer = new SparkHadoopWriter(conf) + val writer = new SparkHadoopWriter(hadoopConf) writer.preSetup() def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cebfd109d825f..4e841bc992bff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -340,7 +340,7 @@ abstract class RDD[T: ClassTag]( // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, (Int, T)](mapPartitionsWithIndex(distributePartition), + new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition), new HashPartitioner(numPartitions)), numPartitions).values } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index bb108ef163c56..bf02f68d0d3d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer @@ -35,23 +35,48 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * @param part the partitioner used to partition the RDD * @tparam K the key class. * @tparam V the value class. + * @tparam C the combiner class. */ @DeveloperApi -class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( - @transient var prev: RDD[P], +class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( + @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) extends RDD[P](prev.context, Nil) { private var serializer: Option[Serializer] = None + private var keyOrdering: Option[Ordering[K]] = None + + private var aggregator: Option[Aggregator[K, V, C]] = None + + private var mapSideCombine: Boolean = false + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ - def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { + def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { this.serializer = Option(serializer) this } + /** Set key ordering for RDD's shuffle. 
*/ + def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = { + this.keyOrdering = Option(keyOrdering) + this + } + + /** Set aggregator for RDD's shuffle. */ + def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = { + this.aggregator = Option(aggregator) + this + } + + /** Set mapSideCombine flag for RDD's shuffle. */ + def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = { + this.mapSideCombine = mapSideCombine + this + } + override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializer)) + List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } override val partitioner = Some(part) @@ -61,7 +86,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]] + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) .read() .asInstanceOf[Iterator[P]] diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b3ebaa547de0d..d7f20bb1d1bd8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{NotSerializableException, PrintWriter, StringWriter} +import java.io.{NotSerializableException} import java.util.Properties import java.util.concurrent.atomic.AtomicInteger @@ -37,6 +37,7 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.{MapStatus, ShuffleManager} import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} @@ -59,7 +60,7 @@ class DAGScheduler( private[scheduler] val sc: SparkContext, private[scheduler] val taskScheduler: TaskScheduler, listenerBus: LiveListenerBus, - mapOutputTracker: MapOutputTrackerMaster, + shuffleManager: ShuffleManager, blockManagerMaster: BlockManagerMaster, env: SparkEnv, clock: Clock = SystemClock) @@ -72,7 +73,7 @@ class DAGScheduler( sc, taskScheduler, sc.listenerBus, - sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + sc.env.shuffleManager, sc.env.blockManager.master, sc.env) } @@ -241,18 +242,17 @@ class DAGScheduler( : Stage = { val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) - if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { - val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - for (i <- 0 until locs.size) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + if (shuffleManager.containsShuffle(shuffleDep.shuffleId)) { + val mapStatusArray = shuffleManager.getShuffleMetadata(shuffleDep.shuffleId) + for (i <- 0 until mapStatusArray.size) { + stage.outputLocs(i) = Option(mapStatusArray(i)).toList // locs(i) will be null if missing } - stage.numAvailableOutputs = locs.count(_ != null) + stage.numAvailableOutputs = mapStatusArray.count(_ != null) } else { - // 
Kind of ugly: need to register RDDs with the cache and map output tracker here + // Kind of ugly: need to register RDDs with the cache and shuffleManager here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + shuffleManager.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size, shuffleDep) } stage } @@ -866,7 +866,7 @@ class DAGScheduler( // epoch incremented to refetch them. // TODO: Only increment the epoch number if this is not the first time // we registered these map outputs. - mapOutputTracker.registerMapOutputs( + shuffleManager.registerMapOutputs( stage.shuffleDep.get.shuffleId, stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, changeEpoch = true) @@ -915,7 +915,7 @@ class DAGScheduler( val mapStage = shuffleToMapStage(shuffleId) if (mapId != -1) { mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + shuffleManager.unregisterMapOutput(shuffleId, mapId, bmAddress) } logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name + "); marking it for resubmission") @@ -955,7 +955,7 @@ class DAGScheduler( * stray fetch failures from possibly retriggering the detection of a node as lost. */ private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) { - val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) + val currentEpoch = maybeEpoch.getOrElse(shuffleManager.mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { failedEpoch(execId) = currentEpoch logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) @@ -964,10 +964,10 @@ class DAGScheduler( for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + shuffleManager.registerMapOutputs(shuffleId, locs, changeEpoch = true) } if (shuffleToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() + shuffleManager.mapOutputTracker.incrementEpoch() } clearCacheLocs() } else { @@ -1038,7 +1038,7 @@ class DAGScheduler( private def failJobAndIndependentStages(job: ActiveJob, failureReason: String, resultStage: Option[Stage]) { val error = new SparkException(failureReason) - job.listener.jobFailed(error) + var ableToCancelStages = true val shouldInterruptThread = if (job.properties == null) false @@ -1062,18 +1062,26 @@ class DAGScheduler( // This is the only job that uses this stage, so fail the stage if it is running. 
val stage = stageIdToStage(stageId) if (runningStages.contains(stage)) { - taskScheduler.cancelTasks(stageId, shouldInterruptThread) - val stageInfo = stageToInfos(stage) - stageInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + try { // cancelTasks will fail if a SchedulerBackend does not implement killTask + taskScheduler.cancelTasks(stageId, shouldInterruptThread) + val stageInfo = stageToInfos(stage) + stageInfo.stageFailed(failureReason) + listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + } catch { + case e: UnsupportedOperationException => + logInfo(s"Could not cancel tasks for stage $stageId", e) + ableToCancelStages = false + } } } } } - cleanupStateForJobAndIndependentStages(job, resultStage) - - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + if (ableToCancelStages) { + job.listener.jobFailed(error) + cleanupStateForJobAndIndependentStages(job, resultStage) + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + } } /** @@ -1155,7 +1163,11 @@ private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler) case x: Exception => logError("eventProcesserActor failed due to the error %s; shutting down SparkContext" .format(x.getMessage)) - dagScheduler.doCancelAllJobs() + try { + dagScheduler.doCancelAllJobs() + } catch { + case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) + } dagScheduler.sc.stop() Stop } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 0e8d551e4b2ab..bbf9f7388b074 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,11 +17,12 @@ package org.apache.spark.scheduler +import scala.language.existentials + import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap -import scala.language.existentials import org.apache.spark._ import org.apache.spark.rdd.{RDD, RDDCheckpointData} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 0098b5a59d1a5..1a02075346912 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -25,11 +25,8 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.{RDD, RDDCheckpointData} -import org.apache.spark.serializer.Serializer -import org.apache.spark.storage._ -import org.apache.spark.shuffle.ShuffleWriter +import org.apache.spark.shuffle.{ShuffleWriter, MapStatus} private[spark] object ShuffleMapTask { @@ -147,9 +144,7 @@ private[spark] class ShuffleMapTask( try { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) - for (elem <- rdd.iterator(split, context)) { - writer.write(elem.asInstanceOf[Product2[Any, Any]]) - } + writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) return writer.stop(success = true).get } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 
378cf1aaebe7b..82163eadd56e9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -75,9 +75,11 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +@DeveloperApi case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) extends SparkListenerEvent +@DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent /** An event used in the listener to shutdown the listener daemon thread. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 9a4be43ee219f..286e3e4db5c93 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import org.apache.spark._ import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.MapStatus import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite @@ -106,6 +107,8 @@ private[spark] class Stage( id } + def attemptId: Int = nextAttemptId + val name = callSite.short val details = callSite.long diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 1481d70db42e1..4c96b9e5fef60 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -21,6 +21,10 @@ import java.nio.ByteBuffer import org.apache.spark.util.SerializableBuffer +/** + * Description of a task that gets passed onto executors to be executed, usually created by + * [[TaskSetManager.resourceOffer]]. + */ private[spark] class TaskDescription( val taskId: Long, val executorId: String, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 4c62e4dc0bac8..6aecdfe8e6656 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -27,10 +27,12 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, + val attempt: Int, val launchTime: Long, val executorId: String, val host: String, - val taskLocality: TaskLocality.TaskLocality) { + val taskLocality: TaskLocality.TaskLocality, + val speculative: Boolean) { /** * The time when the task started remotely getting the result. 
Will not be set if the diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 17292b4c15b8b..aeac0c758a7eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -95,7 +95,7 @@ private[spark] class TaskSchedulerImpl( var backend: SchedulerBackend = null - val mapOutputTracker = SparkEnv.get.mapOutputTracker + val shuffleManager = SparkEnv.get.shuffleManager var schedulableBuilder: SchedulableBuilder = null var rootPool: Pool = null @@ -210,11 +210,14 @@ private[spark] class TaskSchedulerImpl( SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname + // Also track if new executor is added + var newExecAvail = false for (o <- offers) { executorIdToHost(o.executorId) = o.host if (!executorsByHost.contains(o.host)) { executorsByHost(o.host) = new HashSet[String]() executorAdded(o.executorId, o.host) + newExecAvail = true } } @@ -227,12 +230,15 @@ private[spark] class TaskSchedulerImpl( for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( taskSet.parent.name, taskSet.name, taskSet.runningTasks)) + if (newExecAvail) { + taskSet.executorAdded() + } } // Take each TaskSet in our scheduling order, and then offer it each node in increasing order // of locality levels so that it gets a chance to launch local tasks on all of them. var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { + for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { do { launchedTask = false for (i <- 0 until shuffledOffers.size) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index f3bd0797aa035..259addbcd19ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -118,7 +118,7 @@ private[spark] class TaskSetManager( private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). val allPendingTasks = new ArrayBuffer[Int] @@ -140,7 +140,7 @@ private[spark] class TaskSetManager( val recentExceptions = HashMap[String, (Int, Long)]() // Figure out the current map output tracker epoch and set it on all tasks - val epoch = sched.mapOutputTracker.getEpoch + val epoch = sched.shuffleManager.mapOutputTracker.getEpoch logDebug("Epoch for " + taskSet + ": " + epoch) for (t <- tasks) { t.epoch = epoch @@ -153,8 +153,8 @@ private[spark] class TaskSetManager( } // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - val myLocalityLevels = computeValidLocalityLevels() - val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + var myLocalityLevels = computeValidLocalityLevels() + var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level // Delay scheduling variables: we keep track of our current locality level and the time we // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. 
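A hedged, illustrative sketch (not Spark code) of why resourceOffers above now walks taskSet.myLocalityLevels instead of every TaskLocality value: only the levels a task set actually has pending tasks for are offered, from most local to least local.

object Locality extends Enumeration {
  val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
}

// Levels present for a hypothetical task set, already ordered most-to-least local.
val myLocalityLevels = Seq(Locality.NODE_LOCAL, Locality.ANY)

for (maxLocality <- myLocalityLevels) {
  println(s"offering resources with maxLocality = $maxLocality")
}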
@@ -181,16 +181,14 @@ private[spark] class TaskSetManager( var hadAliveLocations = false for (loc <- tasks(index).preferredLocations) { for (execId <- loc.executorId) { - if (sched.isExecutorAlive(execId)) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) - hadAliveLocations = true - } + addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) } if (sched.hasExecutorsAliveOnHost(loc.host)) { - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) - for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - } + hadAliveLocations = true + } + addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + for (rack <- sched.getRackForHost(loc.host)) { + addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) hadAliveLocations = true } } @@ -337,17 +335,19 @@ private[spark] class TaskSetManager( /** * Dequeue a pending task for a given node and return its index and locality level. * Only search for tasks matching the given locality constraint. + * + * @return An option containing (task index within the task set, locality, is speculative?) */ private def findTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = + : Option[(Int, TaskLocality.Value, Boolean)] = { for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) { - return Some((index, TaskLocality.PROCESS_LOCAL)) + return Some((index, TaskLocality.PROCESS_LOCAL, false)) } if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) { - return Some((index, TaskLocality.NODE_LOCAL)) + return Some((index, TaskLocality.NODE_LOCAL, false)) } } @@ -356,23 +356,25 @@ private[spark] class TaskSetManager( rack <- sched.getRackForHost(host) index <- findTaskFromList(execId, getPendingTasksForRack(rack)) } { - return Some((index, TaskLocality.RACK_LOCAL)) + return Some((index, TaskLocality.RACK_LOCAL, false)) } } // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL)) + return Some((index, TaskLocality.PROCESS_LOCAL, false)) } if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { for (index <- findTaskFromList(execId, allPendingTasks)) { - return Some((index, TaskLocality.ANY)) + return Some((index, TaskLocality.ANY, false)) } } // Finally, if all else has failed, find a speculative task - findSpeculativeTask(execId, host, locality) + findSpeculativeTask(execId, host, locality).map { case (taskIndex, allowedLocality) => + (taskIndex, allowedLocality, true) + } } /** @@ -393,7 +395,7 @@ private[spark] class TaskSetManager( } findTask(execId, host, allowedLocality) match { - case Some((index, taskLocality)) => { + case Some((index, taskLocality, speculative)) => { // Found a task; do some bookkeeping and return a task description val task = tasks(index) val taskId = sched.newTaskId() @@ -402,7 +404,9 @@ private[spark] class TaskSetManager( taskSet.id, index, taskId, execId, host, taskLocality)) // Do various bookkeeping copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) + val attemptNum = taskAttempts(index).size + val info = new TaskInfo( + taskId, index, attemptNum + 1, curTime, execId, host, taskLocality, speculative) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) // Update our locality level for delay scheduling @@ -643,7 +647,9 @@ private[spark] class TaskSetManager( addPendingTask(index, readding=true) } - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage. + // The reason is the next stage wouldn't be able to fetch the data from this dead executor + // so we would need to rerun these tasks on other executors. 
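A hedged sketch of the new findTask contract shown above: each lookup now returns a third Boolean, and only the speculative fallback maps `true` into the pair it gets back. The stand-in function and values are illustrative:

// Stand-in for findSpeculativeTask: (task index, locality) when a candidate exists.
def findSpeculativeTask(): Option[(Int, String)] = Some((7, "RACK_LOCAL"))

// Regular branches return `false`; the fallback tags its result as speculative.
val picked: Option[(Int, String, Boolean)] =
  findSpeculativeTask().map { case (index, locality) => (index, locality, true) }
// => Some((7,RACK_LOCAL,true)), which resourceOffer records in the new TaskInfo fields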
if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index @@ -725,10 +731,12 @@ private[spark] class TaskSetManager( private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { + if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 && + pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) { levels += PROCESS_LOCAL } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { + if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 && + pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { @@ -738,4 +746,21 @@ private[spark] class TaskSetManager( logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) levels.toArray } + + // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available + def executorAdded() { + def newLocAvail(index: Int): Boolean = { + for (loc <- tasks(index).preferredLocations) { + if (sched.hasExecutorsAliveOnHost(loc.host) || + sched.getRackForHost(loc.host).isDefined) { + return true + } + } + false + } + logInfo("Re-computing pending task lists.") + pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_)) + myLocalityLevels = computeValidLocalityLevels() + localityWaits = myLocalityLevels.map(getLocalityWait) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index ca74069ef885c..318e16552201c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,21 +20,21 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.TaskDescription import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { + case object RetrieveSparkProps extends CoarseGrainedClusterMessage + // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) extends CoarseGrainedClusterMessage - case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) - extends CoarseGrainedClusterMessage + case object RegisteredExecutor extends CoarseGrainedClusterMessage case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index e47a060683a2d..05d01b0c821f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -75,7 +75,7 @@ class CoarseGrainedSchedulerBackend(scheduler: 
TaskSchedulerImpl, actorSystem: A sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor(sparkProperties) + sender ! RegisteredExecutor executorActor(executorId) = sender executorHost(executorId) = Utils.parseHostPort(hostPort)._1 totalCores(executorId) = cores @@ -124,6 +124,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) + case RetrieveSparkProps => + sender ! sparkProperties } // Make fake resource offers on all executors @@ -143,14 +145,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A for (task <- tasks.flatten) { val ser = SparkEnv.get.closureSerializer.newInstance() val serializedTask = ser.serialize(task) - if (serializedTask.limit >= akkaFrameSize - 1024) { + if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => try { - var msg = "Serialized task %s:%d was %d bytes which " + - "exceeds spark.akka.frameSize (%d bytes). " + - "Consider using broadcast variables for large values." - msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize) + var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + + "spark.akka.frameSize or using broadcast variables for large values." + msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, + AkkaUtils.reservedSizeBytes) taskSet.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index a089a02d42170..c717e7c621a8f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -185,8 +185,8 @@ private[spark] class MesosSchedulerBackend( synchronized { // Build a big list of the offerable workers, and remember their indices so that we can // figure out which Offer to reply to for each worker - val offerableIndices = new ArrayBuffer[Int] val offerableWorkers = new ArrayBuffer[WorkerOffer] + val offerableIndices = new HashMap[String, Int] def enoughMemory(o: Offer) = { val mem = getResource(o.getResourcesList, "mem") @@ -195,7 +195,7 @@ private[spark] class MesosSchedulerBackend( } for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { - offerableIndices += index + offerableIndices.put(offer.getSlaveId.getValue, index) offerableWorkers += new WorkerOffer( offer.getSlaveId.getValue, offer.getHostname, @@ -206,14 +206,13 @@ private[spark] class MesosSchedulerBackend( val taskLists = scheduler.resourceOffers(offerableWorkers) // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => Collections.emptyList[MesosTaskInfo]()) + val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]()) for ((taskList, index) <- taskLists.zipWithIndex) { if (!taskList.isEmpty) { - val offerNum = offerableIndices(index) - val slaveId = offers(offerNum).getSlaveId.getValue - slaveIdsWithExecutors += slaveId - 
mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size) for (taskDesc <- taskList) { + val slaveId = taskDesc.executorId + val offerNum = offerableIndices(slaveId) + slaveIdsWithExecutors += slaveId taskIdToSlaveId(taskDesc.taskId) = slaveId mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 43f0e18a0cbe0..9b95ccca0443e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -97,7 +97,8 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! ReviveOffers } - override def defaultParallelism() = totalCores + override def defaultParallelism() = + scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { localActor ! KillTask(taskId, interruptThread) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 5286f7b4c211a..5aec0ec05f08d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,7 +27,7 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.MapStatus import org.apache.spark.storage._ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} @@ -64,6 +64,9 @@ class KryoSerializer(conf: SparkConf) kryo.register(cls) } + // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. + kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) + // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) @@ -183,3 +186,50 @@ private[serializer] object KryoSerializer { classOf[Array[Byte]] ) } + +/** + * A Kryo serializer for serializing results returned by asJavaIterable. + * + * The underlying object is scala.collection.convert.Wrappers$IterableWrapper. + * Kryo deserializes this into an AbstractCollection, which unfortunately doesn't work. + */ +private class JavaIterableWrapperSerializer + extends com.esotericsoftware.kryo.Serializer[java.lang.Iterable[_]] { + + import JavaIterableWrapperSerializer._ + + override def write(kryo: Kryo, out: KryoOutput, obj: java.lang.Iterable[_]): Unit = { + // If the object is the wrapper, simply serialize the underlying Scala Iterable object. + // Otherwise, serialize the object itself. 
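The JavaIterableWrapperSerializer introduced above (its write/read bodies continue below) exists because asJavaIterable hands back scala.collection.convert.Wrappers$IterableWrapper, which plain Kryo cannot round-trip, per the comment above. A hedged illustration of the wrapper class it targets (Scala 2.10-era conversions):

import scala.collection.JavaConversions.asJavaIterable

val javaIter: java.lang.Iterable[Int] = asJavaIterable(Seq(1, 2, 3))
// Prints something like "scala.collection.convert.Wrappers$IterableWrapper",
// which default Kryo would deserialize into an unusable AbstractCollection.
println(javaIter.getClass.getName)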
+ if (obj.getClass == wrapperClass && underlyingMethodOpt.isDefined) { + kryo.writeClassAndObject(out, underlyingMethodOpt.get.invoke(obj)) + } else { + kryo.writeClassAndObject(out, obj) + } + } + + override def read(kryo: Kryo, in: KryoInput, clz: Class[java.lang.Iterable[_]]) + : java.lang.Iterable[_] = { + kryo.readClassAndObject(in) match { + case scalaIterable: Iterable[_] => + scala.collection.JavaConversions.asJavaIterable(scalaIterable) + case javaIterable: java.lang.Iterable[_] => + javaIterable + } + } +} + +private object JavaIterableWrapperSerializer extends Logging { + // The class returned by asJavaIterable (scala.collection.convert.Wrappers$IterableWrapper). + val wrapperClass = + scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass + + // Get the underlying method so we can use it to get the Scala collection for serialization. + private val underlyingMethodOpt = { + try Some(wrapperClass.getDeclaredMethod("underlying")) catch { + case e: Exception => + logError("Failed to find the underlying field in " + wrapperClass, e) + None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala similarity index 50% rename from core/src/main/scala/org/apache/spark/FetchFailedException.scala rename to core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 8eaa26bdb1b5b..71c08e9d5a8c3 100644 --- a/core/src/main/scala/org/apache/spark/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -15,31 +15,38 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle import org.apache.spark.storage.BlockManagerId +import org.apache.spark.{FetchFailed, TaskEndReason} +/** + * Failed to fetch a shuffle block. The executor catches this exception and propagates it + * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage. + * + * Note that bmAddress can be null. + */ private[spark] class FetchFailedException( - taskEndReason: TaskEndReason, - message: String, - cause: Throwable) + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int) extends Exception { - def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, - cause: Throwable) = - this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), - cause) - - def this (shuffleId: Int, reduceId: Int, cause: Throwable) = - this(FetchFailed(null, shuffleId, -1, reduceId), - "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) - - override def getMessage(): String = message + override def getMessage: String = + "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId) +} - override def getCause(): Throwable = cause - - def toTaskEndReason: TaskEndReason = taskEndReason +/** + * Failed to get shuffle metadata from [[org.apache.spark.MapOutputTracker]]. 
+ */ +private[spark] class MetadataFetchFailedException( + shuffleId: Int, + reduceId: Int, + message: String) + extends FetchFailedException(null, shuffleId, -1, reduceId) { + override def getMessage: String = message } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/shuffle/MapOutputTracker.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/MapOutputTracker.scala rename to core/src/main/scala/org/apache/spark/shuffle/MapOutputTracker.scala index ee82d9fa7874b..e8a29efcfb04f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/MapOutputTracker.scala @@ -15,19 +15,23 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} +import scala.collection.mutable.{HashMap, HashSet, Map} import scala.concurrent.Await import akka.actor._ import akka.pattern.ask -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.BlockManagerId + +import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.util._ +import org.apache.spark.storage.BlockManagerId + +import scala.collection.mutable +import scala.concurrent.Await private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) @@ -168,8 +172,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } else { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing all output locations for shuffle " + shuffleId)) + throw new MetadataFetchFailedException( + shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) } } else { statuses.synchronized { @@ -185,6 +189,13 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } + def incrementEpoch() { + epochLock.synchronized { + epoch += 1 + logDebug("Increasing epoch to " + epoch) + } + } + /** * Called from executors to update the epoch number, potentially clearing old outputs * because of a fetch failure. 
Each worker task calls this with the latest epoch @@ -279,13 +290,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } - def incrementEpoch() { - epochLock.synchronized { - epoch += 1 - logDebug("Increasing epoch to " + epoch) - } - } - def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { var statuses: Array[MapStatus] = null var epochGotten: Long = -1 @@ -371,8 +375,8 @@ private[spark] object MapOutputTracker { statuses.map { status => if (status == null) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) + throw new MetadataFetchFailedException( + shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) } else { (status.location, decompressSize(status.compressedSizes(reduceId))) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/shuffle/MapStatus.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala rename to core/src/main/scala/org/apache/spark/shuffle/MapStatus.scala index d3f63ff92ac6f..041fe7ea601eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/MapStatus.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler +package org.apache.spark.shuffle import java.io.{Externalizable, ObjectInput, ObjectOutput} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 9c859b8b4a118..c38dce9c703ea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -17,24 +17,112 @@ package org.apache.spark.shuffle -import org.apache.spark.{TaskContext, ShuffleDependency} +import scala.concurrent.Await + +import akka.actor.{Props, ActorSystem} + +import org.apache.spark._ +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{AkkaUtils, Utils} + /** * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles * with it, and executors (or tasks running locally in the driver) can ask to read and write data. - * - * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and - * boolean isDriver as parameters. */ private[spark] trait ShuffleManager { + + protected var isDriver: Boolean = true + protected[spark] var mapOutputTracker: MapOutputTracker = _ + + /** + * initialize the mapOutputTracker + */ + def initMapOutputTracker(conf: SparkConf, isDriver: Boolean, actorSystem: ActorSystem) { + this.isDriver = isDriver + if (isDriver) { + val masterCls = Class.forName(conf.get("spark.shuffle.mapOutputTrackerMasterClass", + "org.apache.spark.shuffle.MapOutputTrackerMaster")) + mapOutputTracker = masterCls.getConstructor(classOf[SparkConf]).newInstance(conf). 
+ asInstanceOf[MapOutputTrackerMaster] + mapOutputTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + conf)), "MapOutputTracker") + } else { + val workerCls = Class.forName(conf.get("spark.shuffle.mapOutputTrackerWorkerClass", + "org.apache.spark.shuffle.MapOutputTrackerWorker")) + mapOutputTracker = workerCls.getConstructor(classOf[SparkConf]).newInstance(conf). + asInstanceOf[MapOutputTrackerWorker] + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + val url = s"akka.tcp://spark@$driverHost:$driverPort/user/MapOutputTracker" + val timeout = AkkaUtils.lookupTimeout(conf) + mapOutputTracker.trackerActor = Await.result( + actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } + } + /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ def registerShuffle[K, V, C]( shuffleId: Int, numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (isDriver) { + mapOutputTracker.asInstanceOf[MapOutputTrackerMaster].registerShuffle(shuffleId, numMaps) + } + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + def unregisterShuffle(shuffleId: Int) { + mapOutputTracker.unregisterShuffle(shuffleId) + } + + def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { + if (isDriver) { + mapOutputTracker.asInstanceOf[MapOutputTrackerMaster].registerMapOutput(shuffleId, mapId, + status) + } + } + + /** Register multiple map output information for the given shuffle */ + def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { + if (isDriver) { + mapOutputTracker.asInstanceOf[MapOutputTrackerMaster].registerMapOutputs( + shuffleId, statuses, changeEpoch) + } + } + + def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + if (isDriver) { + mapOutputTracker.asInstanceOf[MapOutputTrackerMaster].unregisterMapOutput(shuffleId, mapId, + bmAddress) + } + } + + def containsShuffle(shuffleId: Int): Boolean = { + if (isDriver) { + mapOutputTracker.asInstanceOf[MapOutputTrackerMaster].containsShuffle(shuffleId) + } else { + false + } + } + + // TODO: MapStatus should be customizable + def getShuffleMetadata(shuffleId: Int): Array[MapStatus] = { + if (isDriver) { + val serLocs = mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]. + getSerializedMapOutputStatuses(shuffleId) + MapOutputTracker.deserializeMapStatuses(serLocs) + } else { + null + } + } + /** Get a writer for a given partition. Called on executors by map tasks. */ def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] @@ -49,9 +137,8 @@ private[spark] trait ShuffleManager { endPartition: Int, context: TaskContext): ShuffleReader[K, C] - /** Remove a shuffle's metadata from the ShuffleManager. */ - def unregisterShuffle(shuffleId: Int) - /** Shut down this ShuffleManager. 
*/ - def stop(): Unit + def stop() { + mapOutputTracker.stop() + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index ead3ebd652ca5..17111827da815 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -17,14 +17,12 @@ package org.apache.spark.shuffle -import org.apache.spark.scheduler.MapStatus - /** * Obtained inside a map task to write out records to the shuffle system. */ private[spark] trait ShuffleWriter[K, V] { - /** Write a record to this task's output */ - def write(record: Product2[K, V]): Unit + /** Write a bunch of records to this task's output */ + def write(records: Iterator[_ <: Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index b05b6ea345df3..cc44a34ee5fa8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -20,11 +20,12 @@ package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import org.apache.spark._ import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator -import org.apache.spark._ private[hash] object BlockStoreShuffleFetcher extends Logging { def fetch[T]( @@ -38,7 +39,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = SparkEnv.get.shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, + reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -63,7 +65,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block") diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 5b0940ecce29d..12456d23cd2a1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -25,13 +25,6 @@ import org.apache.spark.shuffle._ * mapper (possibly reusing these across waves of tasks). */ class HashShuffleManager(conf: SparkConf) extends ShuffleManager { - /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. 
*/ - override def registerShuffle[K, V, C]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } /** * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). @@ -51,10 +44,4 @@ class HashShuffleManager(conf: SparkConf) extends ShuffleManager { : ShuffleWriter[K, V] = { new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Unit = {} - - /** Shut down this ShuffleManager. */ - override def stop(): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index f6a790309a587..d45258c0a492b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,9 +17,9 @@ package org.apache.spark.shuffle.hash +import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -import org.apache.spark.TaskContext class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -31,10 +31,24 @@ class HashShuffleReader[K, C]( require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") + private val dep = handle.dependency + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, - Serializer.getSerializer(handle.dependency.serializer)) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, + Serializer.getSerializer(dep.serializer)) + + if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + } else { + new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + } + } else if (dep.aggregator.isEmpty && dep.mapSideCombine) { + throw new IllegalStateException("Aggregator is empty for map-side combine") + } else { + iter + } } /** Close this reader */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 4c6749098c110..5fdde436010c3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -17,12 +17,11 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} -import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.shuffle.{MapStatus, MapOutputTracker, BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.storage.{BlockObjectWriter} import org.apache.spark.serializer.Serializer import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.scheduler.MapStatus class HashShuffleWriter[K, V]( handle: BaseShuffleHandle[K, V, _], @@ -37,14 +36,27 @@ class HashShuffleWriter[K, V]( private val blockManager = SparkEnv.get.blockManager private val shuffleBlockManager = 
blockManager.shuffleBlockManager - private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) + private val ser = Serializer.getSerializer(dep.serializer.orNull) private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) - /** Write a record to this task's output */ - override def write(record: Product2[K, V]): Unit = { - val pair = record.asInstanceOf[Product2[Any, Any]] - val bucketId = dep.partitioner.getPartition(pair._1) - shuffle.writers(bucketId).write(pair) + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + val iter = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + dep.aggregator.get.combineValuesByKey(records, context) + } else { + records + } + } else if (dep.aggregator.isEmpty && dep.mapSideCombine) { + throw new IllegalStateException("Aggregator is empty for map-side combine") + } else { + records + } + + for (elem <- iter) { + val bucketId = dep.partitioner.getPartition(elem._1) + shuffle.writers(bucketId).write(elem) + } } /** Close this writer, passing along whether the map completed */ @@ -56,7 +68,7 @@ class HashShuffleWriter[K, V]( stopping = true if (success) { try { - return Some(commitWritesAndBuildStatus()) + Some(commitWritesAndBuildStatus()) } catch { case e: Exception => revertWrites() @@ -64,7 +76,7 @@ class HashShuffleWriter[K, V]( } } else { revertWrites() - return None + None } } finally { // Release the writers back to the shuffle block manager. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d2f7baf928b62..0606db4e281f9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -30,6 +30,7 @@ import sun.nio.ch.DirectBuffer import org.apache.spark._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.shuffle.{ShuffleManager} import org.apache.spark.network._ import org.apache.spark.serializer.Serializer import org.apache.spark.util._ @@ -47,7 +48,7 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) + shuffleManager: ShuffleManager) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) @@ -98,7 +99,7 @@ private[spark] class BlockManager( private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), + Props(new BlockManagerSlaveActor(this, shuffleManager)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) // Pending re-registration action being executed asynchronously or null if none is pending. 
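Editor's note: the iterator-based ShuffleWriter.write above lets the writer optionally run map-side combine before records are bucketed by the partitioner. A self-contained sketch of that flow, with plain Scala collections standing in for the real aggregator, partitioner and block writers (these stand-ins are assumptions, not Spark classes):

    object MapSideCombineSketch {
      def main(args: Array[String]): Unit = {
        val records = Iterator(("a", 1), ("b", 2), ("a", 3), ("c", 4))
        val numBuckets = 2
        val mapSideCombine = true

        // Map-side combine: fold values for the same key before anything is written.
        val iter: Iterator[(String, Int)] =
          if (mapSideCombine) {
            val combined = scala.collection.mutable.HashMap.empty[String, Int]
            records.foreach { case (k, v) => combined(k) = combined.getOrElse(k, 0) + v }
            combined.iterator
          } else {
            records
          }

        // Bucket each (possibly combined) record by key, as the writer does per reduce partition.
        val buckets = Array.fill(numBuckets)(scala.collection.mutable.ArrayBuffer.empty[(String, Int)])
        for (elem <- iter) {
          val bucketId = math.abs(elem._1.hashCode) % numBuckets
          buckets(bucketId) += elem
        }
        buckets.zipWithIndex.foreach { case (b, i) => println(s"bucket $i: ${b.toList}") }
      }
    }
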
@@ -139,9 +140,9 @@ private[spark] class BlockManager( serializer: Serializer, conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) = { + shuffleManager: ShuffleManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker) + conf, securityManager, shuffleManager) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 6d4db064dff58..7e56541be9485 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -21,8 +21,9 @@ import scala.concurrent.Future import akka.actor.{ActorRef, Actor} -import org.apache.spark.{Logging, MapOutputTracker} +import org.apache.spark.Logging import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.shuffle.ShuffleManager /** * An actor to take commands from the master to execute options. For example, @@ -31,7 +32,7 @@ import org.apache.spark.storage.BlockManagerMessages._ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) + shuffleManager: ShuffleManager) extends Actor with Logging { import context.dispatcher @@ -51,8 +52,8 @@ class BlockManagerSlaveActor( case RemoveShuffle(shuffleId) => doAsync[Boolean]("removing shuffle " + shuffleId, sender) { - if (mapOutputTracker != null) { - mapOutputTracker.unregisterShuffle(shuffleId) + if (shuffleManager != null) { + shuffleManager.unregisterShuffle(shuffleId) } blockManager.shuffleBlockManager.removeShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index 9a9be047c7245..b9b53b1a2f118 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -24,11 +24,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging /** - * Abstract class to store blocks + * Abstract class to store blocks. */ -private[spark] -abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) : PutResult +private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { + + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult /** * Put in a block and, possibly, also return its content as either bytes or another Iterator. @@ -37,11 +37,17 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { * @return a PutResult that contains the size of the data, as well as the values put if * returnValues is true (if not, the result's data field can be null) */ - def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel, - returnValues: Boolean) : PutResult + def putValues( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean): PutResult - def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, - returnValues: Boolean) : PutResult + def putValues( + blockId: BlockId, + values: ArrayBuffer[Any], + level: StorageLevel, + returnValues: Boolean): PutResult /** * Return the size of a block in bytes. 
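Editor's note: the reformatted BlockStore.putValues signatures above return a PutResult that carries the stored size and, when returnValues is set, the data itself; the MemoryStore hunk just below additionally threads the blocks dropped to make room through the same result. A toy illustration of that contract, using stand-in types that are assumptions here rather than the real org.apache.spark.storage classes:

    // Toy stand-ins purely to illustrate the contract; the real PutResult/BlockStatus carry more.
    final case class ToyPutResult(
        size: Long,
        data: Option[Iterator[Any]],                 // populated only when returnValues = true
        droppedBlocks: Seq[(String, String)] = Nil)  // blocks evicted to make room

    object ToyStore {
      private val entries = scala.collection.mutable.HashMap.empty[String, Vector[Any]]

      def putValues(blockId: String, values: Iterator[Any], returnValues: Boolean): ToyPutResult = {
        val materialized = values.toVector
        entries(blockId) = materialized
        val size = materialized.size.toLong // a real store estimates bytes, not element count
        ToyPutResult(size, if (returnValues) Some(materialized.iterator) else None)
      }
    }

    // Usage: ToyStore.putValues("rdd_0_0", Iterator(1, 2, 3), returnValues = true)
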
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 084a566c48560..71f66c826c5b3 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -58,11 +58,11 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) val elements = new ArrayBuffer[Any] elements ++= values val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - tryToPut(blockId, elements, sizeEstimate, true) - PutResult(sizeEstimate, Left(values.toIterator)) + val putAttempt = tryToPut(blockId, elements, sizeEstimate, deserialized = true) + PutResult(sizeEstimate, Left(values.toIterator), putAttempt.droppedBlocks) } else { - tryToPut(blockId, bytes, bytes.limit, false) - PutResult(bytes.limit(), Right(bytes.duplicate())) + val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) + PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index a107c5182b3be..ccf13e639df54 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -22,9 +22,10 @@ import java.util.concurrent.ArrayBlockingQueue import akka.actor._ import util.Random -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.hash.HashShuffleManager /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -99,9 +100,10 @@ private[spark] object ThreadingTest { val blockManagerMaster = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), conf) + val shuffleManager = new HashShuffleManager(conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf)) + new SecurityManager(conf), shuffleManager) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index b3ac2320f3431..a2535e3c1c41f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -194,11 +194,16 @@ private[spark] object JettyUtils extends Logging { case s: Success[_] => (server, server.getConnectors.head.getLocalPort) case f: Failure[_] => + val nextPort = (currentPort + 1) % 65536 server.stop() pool.stop() - logInfo("Failed to create UI at port, %s. Trying again.".format(currentPort)) - logInfo("Error was: " + f.toString) - connect((currentPort + 1) % 65536) + val msg = s"Failed to create UI on port $currentPort. Trying again on port $nextPort." 
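Editor's note: the connect() retry in the JettyUtils hunk here moves to (currentPort + 1) % 65536 and, just below, logs only a warning for the common address-already-in-use failure but a full error otherwise. A standalone sketch of the same bind-and-retry idea using a plain ServerSocket (the port and retry count are arbitrary choices, not Spark's):

    import java.net.{BindException, ServerSocket}
    import scala.annotation.tailrec
    import scala.util.{Failure, Success, Try}

    object BindWithRetry {
      @tailrec
      def connect(port: Int, attemptsLeft: Int): ServerSocket = {
        Try(new ServerSocket(port)) match {
          case Success(socket) => socket
          case Failure(e) if attemptsLeft > 0 =>
            val nextPort = (port + 1) % 65536
            e match {
              case _: BindException => println(s"Port $port in use, trying $nextPort")
              case other            => println(s"Failed to bind $port (${other.getMessage}), trying $nextPort")
            }
            connect(nextPort, attemptsLeft - 1)
          case Failure(e) => throw e
        }
      }

      def main(args: Array[String]): Unit = {
        val socket = connect(4040, attemptsLeft = 16)
        println(s"Bound to port ${socket.getLocalPort}")
        socket.close()
      }
    }
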
+ if (f.toString.contains("Address already in use")) { + logWarning(s"$msg - $f") + } else { + logError(msg, f.exception) + } + connect(nextPort) } } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index b08f308fda1dd..856273e1d4e21 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -51,6 +51,7 @@ private[spark] abstract class WebUI( def getTabs: Seq[WebUITab] = tabs.toSeq def getHandlers: Seq[ServletContextHandler] = handlers.toSeq + def getSecurityManager: SecurityManager = securityManager /** Attach a tab to this UI, along with all of its attached pages. */ def attachTab(tab: WebUITab) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 396cbcbc8d268..381a5443df8b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.jobs import scala.collection.mutable.{HashMap, ListBuffer} -import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, Success} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -51,6 +51,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { var totalShuffleRead = 0L var totalShuffleWrite = 0L + // TODO: Should probably consolidate all following into a single hash map. val stageIdToTime = HashMap[Int, Long]() val stageIdToShuffleRead = HashMap[Int, Long]() val stageIdToShuffleWrite = HashMap[Int, Long]() @@ -183,14 +184,17 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { // Remove by taskId, rather than by TaskInfo, in case the TaskInfo is from storage tasksActive.remove(info.taskId) - val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { - case e: ExceptionFailure => - stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 - (Some(e), e.metrics) - case _ => + case org.apache.spark.Success => stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1 (None, Option(taskEnd.taskMetrics)) + case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics + stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + (Some(e.toErrorString), e.metrics) + case e: TaskFailedReason => // All other failure cases + stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + (Some(e.toErrorString), None) } stageIdToTime.getOrElseUpdate(sid, 0L) @@ -218,7 +222,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { stageIdToDiskBytesSpilled(sid) += diskBytesSpilled val taskMap = stageIdToTaskData.getOrElse(sid, HashMap[Long, TaskUIData]()) - taskMap(info.taskId) = new TaskUIData(info, metrics, failureInfo) + taskMap(info.taskId) = new TaskUIData(info, metrics, errorMessage) stageIdToTaskData(sid) = taskMap } } @@ -253,7 +257,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { case class TaskUIData( taskInfo: TaskInfo, taskMetrics: Option[TaskMetrics] = None, - exception: Option[ExceptionFailure] = None) + errorMessage: Option[String] = None) private object JobProgressListener { val DEFAULT_POOL_NAME 
= "default" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 4bce472036f7d..8e3d5d1cd4c6b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -95,8 +95,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
// scalastyle:on val taskHeaders: Seq[String] = - Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++ - Seq("Duration", "GC Time", "Result Ser Time") ++ + Seq( + "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", + "Launch Time", "Duration", "GC Time") ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ {if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++ @@ -210,10 +211,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean, bytesSpilled: Boolean) (taskData: TaskUIData): Seq[Node] = { - def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] = - trace.map(e => {e.toString}) - - taskData match { case TaskUIData(info, metrics, exception) => + taskData match { case TaskUIData(info, metrics, errorMessage) => val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) @@ -248,6 +246,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {info.index} {info.taskId} + { + if (info.speculative) s"${info.attempt} (speculative)" else info.attempt.toString + } {info.status} {info.taskLocality} {info.host} @@ -258,9 +259,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} + {if (shuffleRead) { {shuffleReadReadable} @@ -283,12 +287,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { }} - {exception.map { e => - - {e.className} ({e.description})
- {fmtStackTrace(e.stackTrace)} -
- }.getOrElse("")} + {errorMessage.map { e =>
{e}
}.getOrElse("")} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index a3f824a4e1f57..30971f769682f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -91,13 +91,13 @@ private[ui] class StageTableBase( {s.name} - val details = if (s.details.nonEmpty) ( + val details = if (s.details.nonEmpty) { +show details - ) + } listener.stageIdToDescription.get(s.stageId) .map(d =>
{d}
{nameLink} {killLink}
) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index a8d12bb2a0165..9930c717492f2 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -121,4 +121,7 @@ private[spark] object AkkaUtils extends Logging { def maxFrameSizeBytes(conf: SparkConf): Int = { conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024 } + + /** Space reserved for extra data in an Akka message besides serialized task or task result. */ + val reservedSizeBytes = 200 * 1024 } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 7cecbfe62a382..6245b4b8023c2 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -32,6 +32,8 @@ import org.apache.spark.storage._ import org.apache.spark._ private[spark] object JsonProtocol { + // TODO: Remove this file and put JSON serialization into each individual class. + private implicit val format = DefaultFormats /** ------------------------------------------------- * @@ -194,10 +196,12 @@ private[spark] object JsonProtocol { def taskInfoToJson(taskInfo: TaskInfo): JValue = { ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ + ("Attempt" -> taskInfo.attempt) ~ ("Launch Time" -> taskInfo.launchTime) ~ ("Executor ID" -> taskInfo.executorId) ~ ("Host" -> taskInfo.host) ~ ("Locality" -> taskInfo.taskLocality.toString) ~ + ("Speculative" -> taskInfo.speculative) ~ ("Getting Result Time" -> taskInfo.gettingResultTime) ~ ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ @@ -487,16 +491,19 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] + val attempt = (json \ "Attempt").extractOpt[Int].getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] val executorId = (json \ "Executor ID").extract[String] val host = (json \ "Host").extract[String] val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) + val speculative = (json \ "Speculative").extractOpt[Boolean].getOrElse(false) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] val serializedSize = (json \ "Serialized Size").extract[Int] - val taskInfo = new TaskInfo(taskId, index, launchTime, executorId, host, taskLocality) + val taskInfo = + new TaskInfo(taskId, index, attempt, launchTime, executorId, host, taskLocality, speculative) taskInfo.gettingResultTime = gettingResultTime taskInfo.finishTime = finishTime taskInfo.failed = failed diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 8e9c3036d09c2..1d5467060623c 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -125,16 +125,16 @@ private[spark] object FileAppender extends Logging { val validatedParams: Option[(Long, String)] = rollingInterval match { case "daily" => logInfo(s"Rolling executor logs enabled for $file with daily rolling") - Some(24 * 60 * 60 * 1000L, "--YYYY-MM-dd") + Some(24 * 60 * 60 * 1000L, "--yyyy-MM-dd") case 
"hourly" => logInfo(s"Rolling executor logs enabled for $file with hourly rolling") - Some(60 * 60 * 1000L, "--YYYY-MM-dd--HH") + Some(60 * 60 * 1000L, "--yyyy-MM-dd--HH") case "minutely" => logInfo(s"Rolling executor logs enabled for $file with rolling every minute") - Some(60 * 1000L, "--YYYY-MM-dd--HH-mm") + Some(60 * 1000L, "--yyyy-MM-dd--HH-mm") case IntParam(seconds) => logInfo(s"Rolling executor logs enabled for $file with rolling $seconds seconds") - Some(seconds * 1000L, "--YYYY-MM-dd--HH-mm-ss") + Some(seconds * 1000L, "--yyyy-MM-dd--HH-mm-ss") case _ => logWarning(s"Illegal interval for rolling executor logs [$rollingInterval], " + s"rolling logs not enabled") diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 1bbbd20cf076f..e579421676343 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -19,7 +19,7 @@ package org.apache.spark.util.logging import java.io.{File, FileFilter, InputStream} -import org.apache.commons.io.FileUtils +import com.google.common.io.Files import org.apache.spark.SparkConf import RollingFileAppender._ @@ -83,7 +83,7 @@ private[spark] class RollingFileAppender( logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") if (activeFile.exists) { if (!rolloverFile.exists) { - FileUtils.moveFile(activeFile, rolloverFile) + Files.move(activeFile, rolloverFile) logInfo(s"Rolled over $activeFile to $rolloverFile") } else { // In case the rollover file name clashes, make a unique file name. @@ -100,7 +100,7 @@ private[spark] class RollingFileAppender( logWarning(s"Rollover file $rolloverFile already exists, " + s"rolled over $activeFile to file $altRolloverFile") - FileUtils.moveFile(activeFile, altRolloverFile) + Files.move(activeFile, altRolloverFile) } } else { logWarning(s"File $activeFile does not exist") diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 84e5c3c917dcb..d7b7219e179d0 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--YYYY-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e46298c6a9e63..b2868b59ce6c6 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -18,9 +18,13 @@ package org.apache.spark; import java.io.*; +import java.net.URI; import java.util.*; import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; + import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; @@ -304,6 +308,66 @@ public void cogroup() { cogrouped.collect(); } + @SuppressWarnings("unchecked") + @Test + public void cogroup3() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new 
Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 2), + new Tuple2("Apples", 3) + )); + JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 21), + new Tuple2("Apples", 42) + )); + + JavaPairRDD, Iterable, Iterable>> cogrouped = + categories.cogroup(prices, quantities); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + + + cogrouped.collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void cogroup4() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 2), + new Tuple2("Apples", 3) + )); + JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 21), + new Tuple2("Apples", 42) + )); + JavaPairRDD countries = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", "BR"), + new Tuple2("Apples", "US") + )); + + JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = + categories.cogroup(prices, quantities, countries); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); + + cogrouped.collect(); + } + @SuppressWarnings("unchecked") @Test public void leftOuterJoin() { @@ -678,7 +742,7 @@ public void persist() { public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics()); - Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); + Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } @Test @@ -705,7 +769,7 @@ public void textFiles() throws IOException { } @Test - public void wholeTextFiles() throws IOException { + public void wholeTextFiles() throws Exception { byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8"); @@ -721,7 +785,7 @@ public void wholeTextFiles() throws IOException { List> result = readRDD.collect(); for (Tuple2 res : result) { - Assert.assertEquals(res._2(), container.get(res._1())); + Assert.assertEquals(res._2(), container.get(new URI(res._1()).getPath())); } } diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala index 4ab870e751778..1db4266869372 100644 --- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark -import org.scalatest.FunSuite +import scala.concurrent.Await import akka.actor._ -import org.apache.spark.scheduler.MapStatus +import org.scalatest.FunSuite import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AkkaUtils -import scala.concurrent.Await +import org.apache.spark.shuffle._ + /** * 
Test the AkkaUtils with various security settings. diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f64f3c9036034..fc00458083a33 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("ShuffledRDD") { testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD - new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) + new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) }) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 13b415cccb647..266f21bf08201 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -32,6 +32,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId} +import org.apache.spark.shuffle.MapOutputTrackerMaster class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -413,5 +414,6 @@ class CleanerTester( } private def blockManager = sc.env.blockManager - private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + private def mapOutputTrackerMaster = sc.env.shuffleManager.mapOutputTracker. + asInstanceOf[MapOutputTrackerMaster] } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 47112ce66d695..476a2706316d5 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -56,15 +56,18 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { } // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) + val c = new ShuffledRDD[Int, + NonJavaSerializableClass, + NonJavaSerializableClass, + (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS)) + c.setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 10) // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + val statuses = SparkEnv.get.shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, id) assert(statuses.forall(s => s._2 > 0)) } } @@ -78,8 +81,11 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { } // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. 
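Editor's note: in the FileAppender and RollingPolicy hunks earlier in this patch, the rollover patterns switch from "YYYY" (week-based year) to "yyyy" (calendar year). The difference only surfaces near the year boundary, where the week-based year can already belong to the following year. A quick standalone check (Java 7+, where 'Y' is defined; locale pinned to US so the output is deterministic):

    import java.text.SimpleDateFormat
    import java.util.{Calendar, Locale}

    object WeekYearPitfall {
      def main(args: Array[String]): Unit = {
        val cal = Calendar.getInstance(Locale.US)
        cal.set(2014, Calendar.DECEMBER, 31)
        val date = cal.getTime
        val calendarYear = new SimpleDateFormat("yyyy-MM-dd", Locale.US).format(date)
        val weekYear     = new SimpleDateFormat("YYYY-MM-dd", Locale.US).format(date)
        println(s"yyyy-MM-dd -> $calendarYear") // 2014-12-31
        println(s"YYYY-MM-dd -> $weekYear")     // 2015-12-31: the week-based year has already rolled over
      }
    }
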
- val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf)) + val c = new ShuffledRDD[Int, + NonJavaSerializableClass, + NonJavaSerializableClass, + (Int, NonJavaSerializableClass)](b, new HashPartitioner(3)) + c.setSerializer(new KryoSerializer(conf)) assert(c.count === 10) } @@ -94,14 +100,14 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // NOTE: The default Java serializer doesn't create zero-sized blocks. // So, use Kryo - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + val statuses = SparkEnv.get.shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, id) statuses.map(x => x._2) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -120,13 +126,13 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + val statuses = SparkEnv.get.shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, id) statuses.map(x => x._2) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -141,8 +147,8 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2)) - .collect() + val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs, + new HashPartitioner(2)).collect() data.foreach { pair => results should contain (pair) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 94fba102865b3..67e3be21c3c93 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -77,6 +77,22 @@ class SparkContextSchedulerCreationSuite } } + test("local-default-parallelism") { + val defaultParallelism = System.getProperty("spark.default.parallelism") + System.setProperty("spark.default.parallelism", "16") + val sched = createTaskScheduler("local") + + sched.backend match { + case s: LocalBackend => assert(s.defaultParallelism() === 16) + case _ => fail() + } + + Option(defaultParallelism) match { + case Some(v) => System.setProperty("spark.default.parallelism", v) + case _ => System.clearProperty("spark.default.parallelism") + } + } + test("simr") { createTaskScheduler("simr://uri").backend match { case s: SimrSchedulerBackend => // OK diff --git 
a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 0b9004448a63e..447e38ec9dbd0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { )) } + test("groupWith3") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val joined = rdd1.groupWith(rdd2, rdd3).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'))), + (2, (List(1), List('y', 'z'), List())), + (3, (List(1), List(), List('b'))), + (4, (List(), List('w'), List('c', 'd'))) + )) + } + + test("groupWith4") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val rdd4 = sc.parallelize(Array((2, '@'))) + val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'), List())), + (2, (List(1), List('y', 'z'), List(), List('@'))), + (3, (List(1), List(), List('b'), List())), + (4, (List(), List('w'), List('c', 'd'), List())) + )) + } + test("zero-partition RDD") { val emptyDir = Files.createTempDir() emptyDir.deleteOnExit() diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 0e5625b7645d5..0f9cbe213ea17 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -276,7 +276,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { // we can optionally shuffle to keep the upstream parallel val coalesced5 = data.coalesce(1, shuffle = true) val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd. 
- asInstanceOf[ShuffledRDD[_, _, _]] != null + asInstanceOf[ShuffledRDD[_, _, _, _]] != null assert(isEquals) // when shuffling, we can increase the number of partitions @@ -509,7 +509,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val n = 1000000 val data = sc.parallelize(1 to n, 2) - + for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements @@ -704,11 +704,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(ancestors3.count(_.isInstanceOf[MappedRDD[_, _]]) === 2) // Any ancestors before the shuffle are not considered - assert(ancestors4.size === 1) - assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1) - assert(ancestors5.size === 4) - assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1) - assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 1) + assert(ancestors4.size === 0) + assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0) + assert(ancestors5.size === 3) + assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1) + assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0) assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index efef9d26dadca..f77661ccbd1c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -35,7 +35,7 @@ class CoarseGrainedSchedulerBackendSuite extends FunSuite with LocalSparkContext val thrown = intercept[SparkException] { larger.collect() } - assert(thrown.getMessage.contains("Consider using broadcast variables for large values")) + assert(thrown.getMessage.contains("using broadcast variables for large values")) val smaller = sc.parallelize(1 to 4).collect() assert(smaller.size === 4) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 45368328297d3..318bf17fc7b2b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import org.apache.spark.shuffle.hash.HashShuffleManager + import scala.Tuple2 import scala.collection.mutable.{HashSet, HashMap, Map} import scala.language.reflectiveCalls @@ -28,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, FunSuiteLike} import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.shuffle.{MapStatus, ShuffleManager, MapOutputTrackerMaster} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite @@ -55,7 +58,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def stop() = {} override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager - taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) + taskSet.tasks.foreach(_.epoch = shuffleManager.mapOutputTracker.getEpoch) taskSets += taskSet } override def cancelTasks(stageId: Int, interruptThread: Boolean) { @@ 
-80,7 +83,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } } - var mapOutputTracker: MapOutputTrackerMaster = null + var shuffleManager: ShuffleManager = null var scheduler: DAGScheduler = null var dagEventProcessTestActor: TestActorRef[DAGSchedulerEventProcessActor] = null @@ -115,17 +118,18 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.successfulStages.clear() sparkListener.failedStages.clear() + failure = null sc.addSparkListener(sparkListener) taskSets.clear() cancelledStages.clear() cacheLocations.clear() results.clear() - mapOutputTracker = new MapOutputTrackerMaster(conf) + shuffleManager = sc.env.shuffleManager scheduler = new DAGScheduler( sc, taskScheduler, sc.listenerBus, - mapOutputTracker, + shuffleManager, blockManagerMaster, sc.env) { override def runLocally(job: ActiveJob) { @@ -314,6 +318,53 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } + test("job cancellation no-kill backend") { + // make sure that the DAGScheduler doesn't crash when the TaskScheduler + // doesn't implement killTask() + val noKillTaskScheduler = new TaskScheduler() { + override def rootPool: Pool = null + override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def start() = {} + override def stop() = {} + override def submitTasks(taskSet: TaskSet) = { + taskSets += taskSet + } + override def cancelTasks(stageId: Int, interruptThread: Boolean) { + throw new UnsupportedOperationException + } + override def setDAGScheduler(dagScheduler: DAGScheduler) = {} + override def defaultParallelism() = 2 + } + val noKillScheduler = new DAGScheduler( + sc, + noKillTaskScheduler, + sc.listenerBus, + shuffleManager, + blockManagerMaster, + sc.env) { + override def runLocally(job: ActiveJob) { + // don't bother with the thread while unit testing + runLocallyWithinThread(job) + } + } + dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( + Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system) + val rdd = makeRdd(1, Nil) + val jobId = submit(rdd, Array(0)) + cancel(jobId) + // Because the job wasn't actually cancelled, we shouldn't have received a failure message. + assert(failure === null) + + // When the task set completes normally, state should be correctly updated. 
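Editor's note: the "job cancellation no-kill backend" test above checks that cancelling a job does not crash the scheduler when the backend's cancelTasks throws UnsupportedOperationException. A simplified, standalone analogue of the defensive pattern being exercised (a sketch of the behaviour, not the DAGScheduler's actual code):

    trait TaskCanceller {
      def cancelTasks(stageId: Int, interruptThread: Boolean): Unit
    }

    object SafeCancel {
      def cancelStage(scheduler: TaskCanceller, stageId: Int, interruptThread: Boolean): Unit = {
        try {
          scheduler.cancelTasks(stageId, interruptThread)
        } catch {
          case e: UnsupportedOperationException =>
            // A backend without kill support lands here: the request is logged and
            // swallowed, and the scheduler keeps running instead of crashing.
            println(s"Could not cancel tasks for stage $stageId: $e")
        }
      }

      def main(args: Array[String]): Unit = {
        val noKill = new TaskCanceller {
          def cancelTasks(stageId: Int, interruptThread: Boolean): Unit =
            throw new UnsupportedOperationException
        }
        cancelStage(noKill, stageId = 0, interruptThread = false)
        println("still alive after cancellation attempt")
      }
    }
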
+ complete(taskSets(0), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.isEmpty) + assert(sparkListener.successfulStages.contains(0)) + } + test("run trivial shuffle") { val shuffleMapRdd = makeRdd(2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -323,7 +374,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + assert(shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) @@ -350,7 +401,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) + assert(shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) + === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty @@ -363,9 +415,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F val reduceRdd = makeRdd(2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) // pretend we were told hostA went away - val oldEpoch = mapOutputTracker.getEpoch + val oldEpoch = shuffleManager.mapOutputTracker.getEpoch runEvent(ExecutorLost("exec-hostA")) - val newEpoch = mapOutputTracker.getEpoch + val newEpoch = shuffleManager.mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) val noAccum = Map[Long, Any]() val taskSet = taskSets(0) @@ -378,7 +430,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + assert(shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) @@ -477,7 +529,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + assert(shuffleManager.mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index abd7b22310f1a..6df0a080961b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -181,7 +181,7 @@ class 
SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {2} // Shuffle map stage + result stage val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get - stageInfo3.rddInfos.size should be {2} // ShuffledRDD, MapPartitionsRDD + stageInfo3.rddInfos.size should be {1} // ShuffledRDD stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true} stageInfo3.rddInfos.exists(_.name == "Trois") should be {true} } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6f1fd25764544..59a618956a356 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -77,6 +77,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) + + def addExecutor(execId: String, host: String) { + executors.put(execId, host) + } } class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { @@ -400,6 +404,36 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(sched.taskSetsFailed.contains(taskSet.id)) } + test("new executors get added") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execB")), + Seq(TaskLocation("host2", "execC")), + Seq()) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // All tasks added to no-pref list since no preferred location is available + assert(manager.pendingTasksWithNoPrefs.size === 4) + // Only ANY is valid + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + // Add a new executor + sched.addExecutor("execD", "host1") + manager.executorAdded() + // Task 0 and 1 should be removed from no-pref list + assert(manager.pendingTasksWithNoPrefs.size === 2) + // Valid locality should contain NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY))) + // Add another executor + sched.addExecutor("execC", "host2") + manager.executorAdded() + // No-pref list now only contains task 3 + assert(manager.pendingTasksWithNoPrefs.size === 1) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index cdd6b3d8feed7..79280d1a06653 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -128,6 +128,21 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(1.0 until 1000000.0 by 2.0) } + test("asJavaIterable") { + // Serialize a collection wrapped by asJavaIterable + val ser = 
new KryoSerializer(conf).newInstance() + val a = ser.serialize(scala.collection.convert.WrapAsJava.asJavaIterable(Seq(12345))) + val b = ser.deserialize[java.lang.Iterable[Int]](a) + assert(b.iterator().next() === 12345) + + // Serialize a normal Java collection + val col = new java.util.ArrayList[Int] + col.add(54321) + val c = ser.serialize(col) + val d = ser.deserialize[java.lang.Iterable[Int]](c) + assert(d.iterator().next() === 54321) + } + test("custom registrator") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/MapOutputTrackerSuite.scala similarity index 99% rename from core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/MapOutputTrackerSuite.scala index 95ba273f16a71..3cfae8db84f48 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/MapOutputTrackerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle import scala.concurrent.Await @@ -23,7 +23,7 @@ import akka.actor._ import akka.testkit.TestActorRef import org.scalatest.FunSuite -import org.apache.spark.scheduler.MapStatus +import org.apache.spark._ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AkkaUtils diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index d7dbe5164b7f6..03d22dbd8fe0f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,25 +20,22 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays +import scala.language.{implicitConversions, postfixOps} + import akka.actor._ -import org.apache.spark.SparkConf -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.SpanSugar._ +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} -import scala.language.implicitConversions -import scala.language.postfixOps - class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter with PrivateMethodTester { private val conf = new SparkConf(false) @@ -49,7 +46,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter var oldArch: String = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new
thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -137,7 +134,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("master + 1 manager interaction") { store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -168,9 +165,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("master + 2 managers interaction") { store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -186,7 +183,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("removing block") { store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -235,7 +232,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("removing rdd") { store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -269,10 +266,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("removing broadcast") { store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val driverStore = store val executorStore = new BlockManager("executor", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -342,7 +339,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -359,7 +356,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("reregistration on block update") { store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -379,7 +376,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -417,7 +414,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("in-memory LRU storage") { store = new BlockManager("", actorSystem, master, 
serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -437,7 +434,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("in-memory LRU storage with serialization") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -457,7 +454,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("in-memory LRU for partitions of same RDD") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -477,7 +474,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("in-memory LRU for partitions of multiple RDDs") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -504,7 +501,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false) if (tachyonUnitTestEnabled) { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -521,7 +518,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("on-disk storage") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -535,7 +532,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("disk and memory storage") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -551,7 +548,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("disk and memory storage with getLocalBytes") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -567,7 +564,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("disk and memory storage with serialization") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -583,7 +580,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("disk and memory storage with serialization and getLocalBytes") { store = new BlockManager("", actorSystem, master, 
serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -599,7 +596,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("LRU with mixed storage levels") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -622,7 +619,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("in-memory LRU with streams") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -647,7 +644,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("LRU with mixed storage levels and streams") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -694,7 +691,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("overly large block") { store = new BlockManager("", actorSystem, master, serializer, 500, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -706,7 +703,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter try { conf.set("spark.shuffle.compress", "true") store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -715,7 +712,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.shuffle.compress", "false") store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -724,7 +721,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.broadcast.compress", "true") store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -733,7 +730,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.broadcast.compress", "false") store = new BlockManager("exec4", actorSystem, master, 
serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() @@ -741,7 +738,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.rdd.compress", "true") store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() @@ -749,7 +746,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.rdd.compress", "false") store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() @@ -757,7 +754,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Check that any other block types are also kept uncompressed store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -772,7 +769,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) // The put should fail since a1 is not serializable. class UnserializableClass @@ -836,7 +833,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("updated block statuses") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val list = List.fill(2)(new Array[Byte](200)) val bigList = List.fill(8)(new Array[Byte](200)) @@ -891,7 +888,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("query block statuses") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val list = List.fill(2)(new Array[Byte](200)) // Tell master. By LRU, only list2 and list3 remains. 
@@ -931,7 +928,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("get matching blocks") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) val list = List.fill(2)(new Array[Byte](10)) // insert some blocks @@ -965,7 +962,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, shuffleManager) store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index c3a14f48de38e..fa43b66c6cb5a 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.{LocalSparkContext, SparkConf, Success} +import org.apache.spark._ import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -66,7 +66,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) - var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0, null, null, 0, null) val taskType = Utils.getFormattedClassName(task) @@ -75,7 +75,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc .shuffleRead == 1000) // finish a task with unknown executor-id, nothing should happen - taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL) + taskInfo = + new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) @@ -84,7 +85,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated duration shuffleReadMetrics.remoteBytesRead = 1000 taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) - taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) @@ -94,11 +95,39 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated duration shuffleReadMetrics.remoteBytesRead = 1000 taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) - taskInfo = new TaskInfo(1236L, 
0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL) + taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail()) .shuffleRead == 1000) } + + test("test task success vs failure counting for different task end reasons") { + val conf = new SparkConf() + val listener = new JobProgressListener(conf) + val metrics = new TaskMetrics() + val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) + taskInfo.finishTime = 1 + val task = new ShuffleMapTask(0, null, null, 0, null) + val taskType = Utils.getFormattedClassName(task) + + // Go through all the failure cases to make sure we are counting them as failures. + val taskFailedReasons = Seq( + Resubmitted, + new FetchFailed(null, 0, 0, 0), + new ExceptionFailure("Exception", "description", null, None), + TaskResultLost, + TaskKilled, + ExecutorLostFailure, + UnknownReason) + for (reason <- taskFailedReasons) { + listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, reason, taskInfo, metrics)) + assert(listener.stageIdToTasksComplete.get(task.stageId) === None) + } + + // Make sure we count success as success. + listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, metrics)) + assert(listener.stageIdToTasksComplete.get(task.stageId) === Some(1)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 02e228945bbd9..ca37d707b06ca 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -18,13 +18,16 @@ package org.apache.spark.util import java.io._ +import java.nio.charset.Charset import scala.collection.mutable.HashSet import scala.reflect._ -import org.apache.commons.io.{FileUtils, IOUtils} -import org.apache.spark.{Logging, SparkConf} import org.scalatest.{BeforeAndAfter, FunSuite} + +import com.google.common.io.Files + +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { @@ -41,11 +44,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { test("basic file appender") { val testString = (1 to 1000).mkString(", ") - val inputStream = IOUtils.toInputStream(testString) + val inputStream = new ByteArrayInputStream(testString.getBytes(Charset.forName("UTF-8"))) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(FileUtils.readFileToString(testFile) === testString) + assert(Files.toString(testFile, Charset.forName("UTF-8")) === testString) } test("rolling file appender - time-based rolling") { @@ -93,7 +96,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val allGeneratedFiles = new HashSet[String]() val items = (1 to 10).map { _.toString * 10000 } for (i <- 0 until items.size) { - testOutputStream.write(items(i).getBytes("UTF8")) + testOutputStream.write(items(i).getBytes(Charset.forName("UTF-8"))) testOutputStream.flush() allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles( 
testFile.getParentFile.toString, testFile.getName).map(_.toString) @@ -197,7 +200,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // send data to appender through the input stream, and wait for the data to be written val expectedText = textToAppend.mkString("") for (i <- 0 until textToAppend.size) { - outputStream.write(textToAppend(i).getBytes("UTF8")) + outputStream.write(textToAppend(i).getBytes(Charset.forName("UTF-8"))) outputStream.flush() Thread.sleep(sleepTimeBetweenTexts) } @@ -212,7 +215,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) assert(generatedFiles.size > 1) val allText = generatedFiles.map { file => - FileUtils.readFileToString(file) + Files.toString(file, Charset.forName("UTF-8")) }.mkString("") assert(allText === expectedText) generatedFiles diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index f72389b6b323f..6c49870455873 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -35,10 +35,11 @@ class JsonProtocolSuite extends FunSuite { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) val stageCompleted = SparkListenerStageCompleted(makeStageInfo(101, 201, 301, 401L, 501L)) - val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 444L)) - val taskGettingResult = SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 3000L)) + val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 1, 444L, false)) + val taskGettingResult = + SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 5, 3000L, true)) val taskEnd = SparkListenerTaskEnd(1, "ShuffleMapTask", Success, - makeTaskInfo(123L, 234, 345L), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800)) + makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800)) val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties) val jobEnd = SparkListenerJobEnd(20, JobSucceeded) val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]]( @@ -73,7 +74,7 @@ class JsonProtocolSuite extends FunSuite { test("Dependent Classes") { testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) - testTaskInfo(makeTaskInfo(999L, 888, 777L)) + testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics(33333L, 44444L, 55555L, 66666L, 7, 8)) testBlockManagerId(BlockManagerId("Hong", "Kong", 500, 1000)) @@ -269,10 +270,12 @@ class JsonProtocolSuite extends FunSuite { private def assertEquals(info1: TaskInfo, info2: TaskInfo) { assert(info1.taskId === info2.taskId) assert(info1.index === info2.index) + assert(info1.attempt === info2.attempt) assert(info1.launchTime === info2.launchTime) assert(info1.executorId === info2.executorId) assert(info1.host === info2.host) assert(info1.taskLocality === info2.taskLocality) + assert(info1.speculative === info2.speculative) assert(info1.gettingResultTime === info2.gettingResultTime) assert(info1.finishTime === info2.finishTime) assert(info1.failed === info2.failed) @@ -366,7 +369,7 @@ class JsonProtocolSuite extends FunSuite { private def assertJsonStringEquals(json1: String, json2: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - 
formatJsonString(json1) === formatJsonString(json2) + assert(formatJsonString(json1) === formatJsonString(json2)) } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { @@ -449,12 +452,12 @@ class JsonProtocolSuite extends FunSuite { } private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val rddInfos = (1 to a % 5).map { i => makeRddInfo(a % i, b % i, c % i, d % i, e % i) } + val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) } new StageInfo(a, "greetings", b, rddInfos, "details") } - private def makeTaskInfo(a: Long, b: Int, c: Long) = { - new TaskInfo(a, b, c, "executor", "your kind sir", TaskLocality.NODE_LOCAL) + private def makeTaskInfo(a: Long, b: Int, c: Int, d: Long, speculative: Boolean) = { + new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL, speculative) } private def makeTaskMetrics(a: Long, b: Long, c: Long, d: Long, e: Int, f: Int) = { @@ -493,55 +496,77 @@ class JsonProtocolSuite extends FunSuite { private val stageSubmittedJsonString = """ {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":100,"Stage Name": - "greetings","Number of Tasks":200,"RDD Info":{"RDD ID":100,"Name":"mayor","Storage - Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, - "Replication":1},"Number of Partitions":200,"Number of Cached Partitions":300, - "Memory Size":400,"Disk Size":500,"Tachyon Size":0},"Emitted Task Size Warning":false}, - "Properties":{"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} + "greetings","Number of Tasks":200,"RDD Info":[],"Details":"details", + "Emitted Task Size Warning":false},"Properties":{"France":"Paris","Germany":"Berlin", + "Russia":"Moscow","Ukraine":"Kiev"}} """ private val stageCompletedJsonString = """ {"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":101,"Stage Name": - "greetings","Number of Tasks":201,"RDD Info":{"RDD ID":101,"Name":"mayor","Storage + "greetings","Number of Tasks":201,"RDD Info":[{"RDD ID":101,"Name":"mayor","Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, "Replication":1},"Number of Partitions":201,"Number of Cached Partitions":301, - "Memory Size":401,"Disk Size":501,"Tachyon Size":0},"Emitted Task Size Warning":false}} + "Memory Size":401,"Tachyon Size":0,"Disk Size":501}],"Details":"details", + "Emitted Task Size Warning":false}} """ private val taskStartJsonString = """ - {"Event":"SparkListenerTaskStart","Stage ID":111,"Task Info":{"Task ID":222, - "Index":333,"Launch Time":444,"Executor ID":"executor","Host":"your kind sir", - "Locality":"NODE_LOCAL","Getting Result Time":0,"Finish Time":0,"Failed":false, - "Serialized Size":0}} - """ + |{"Event":"SparkListenerTaskStart","Stage ID":111,"Task Info":{"Task ID":222, + |"Index":333,"Attempt":1,"Launch Time":444,"Executor ID":"executor","Host":"your kind sir", + |"Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0, + |"Failed":false,"Serialized Size":0}} + """.stripMargin private val taskGettingResultJsonString = """ - {"Event":"SparkListenerTaskGettingResult","Task Info":{"Task ID":1000,"Index": - 2000,"Launch Time":3000,"Executor ID":"executor","Host":"your kind sir", - "Locality":"NODE_LOCAL","Getting Result Time":0,"Finish Time":0,"Failed":false, - "Serialized Size":0}} - """ + |{"Event":"SparkListenerTaskGettingResult","Task Info": + | {"Task ID":1000,"Index":2000,"Attempt":5,"Launch Time":3000,"Executor 
ID":"executor", + | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":true,"Getting Result Time":0, + | "Finish Time":0,"Failed":false,"Serialized Size":0 + | } + |} + """.stripMargin private val taskEndJsonString = """ - {"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask", - "Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":123,"Index": - 234,"Launch Time":345,"Executor ID":"executor","Host":"your kind sir", - "Locality":"NODE_LOCAL","Getting Result Time":0,"Finish Time":0,"Failed": - false,"Serialized Size":0},"Task Metrics":{"Host Name":"localhost", - "Executor Deserialize Time":300,"Executor Run Time":400,"Result Size":500, - "JVM GC Time":600,"Result Serialization Time":700,"Memory Bytes Spilled": - 800,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Shuffle Finish Time": - 900,"Total Blocks Fetched":1500,"Remote Blocks Fetched":800,"Local Blocks Fetched": - 700,"Fetch Wait Time":900,"Remote Bytes Read":1000},"Shuffle Write Metrics": - {"Shuffle Bytes Written":1200,"Shuffle Write Time":1500},"Updated Blocks": - [{"Block ID":{"Type":"RDDBlockId","RDD ID":0,"Split Index":0},"Status": - {"Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false, - "Deserialized":false,"Replication":2},"Memory Size":0,"Disk Size":0,"Tachyon Size":0}}]}} - """ + |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask", + |"Task End Reason":{"Reason":"Success"}, + |"Task Info":{ + | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor", + | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false, + | "Getting Result Time":0,"Finish Time":0,"Failed":false,"Serialized Size":0 + |}, + |"Task Metrics":{ + | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400, + | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700, + | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0, + | "Shuffle Read Metrics":{ + | "Shuffle Finish Time":900, + | "Total Blocks Fetched":1500, + | "Remote Blocks Fetched":800, + | "Local Blocks Fetched":700, + | "Fetch Wait Time":900, + | "Remote Bytes Read":1000 + | }, + | "Shuffle Write Metrics":{ + | "Shuffle Bytes Written":1200, + | "Shuffle Write Time":1500}, + | "Updated Blocks":[ + | {"Block ID":"rdd_0_0", + | "Status":{ + | "Storage Level":{ + | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false, + | "Replication":2 + | }, + | "Memory Size":0,"Tachyon Size":0,"Disk Size":0 + | } + | } + | ] + | } + |} + """.stripMargin private val jobStartJsonString = """ diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index e15fd59a5a8bb..ef7178bcdf5c2 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.util.random import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.commons.math3.stat.inference.ChiSquareTest + import org.apache.spark.util.Utils.times import scala.language.reflectiveCalls @@ -33,45 +35,30 @@ class XORShiftRandomSuite extends FunSuite with Matchers { } /* - * This test is based on a chi-squared test for randomness. 
The values are hard-coded - * so as not to create Spark's dependency on apache.commons.math3 just to call one - * method for calculating the exact p-value for a given number of random numbers - * and bins. In case one would want to move to a full-fledged test based on - * apache.commons.math3, the relevant class is here: - * org.apache.commons.math3.stat.inference.ChiSquareTest + * This test is based on a chi-squared test for randomness. */ test ("XORShift generates valid random numbers") { val f = fixture - val numBins = 10 - // create 10 bins - val bins = Array.fill(numBins)(0) + val numBins = 10 // create 10 bins + val numRows = 5 // create 5 rows + val bins = Array.ofDim[Long](numRows, numBins) - // populate bins based on modulus of the random number - times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1} + // populate bins based on modulus of the random number for each row + for (r <- 0 to numRows-1) { + times(f.hundMil) {bins(r)(math.abs(f.xorRand.nextInt) % numBins) += 1} + } - /* since the seed is deterministic, until the algorithm is changed, we know the result will be - * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, - * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) - * significance level. However, should the RNG implementation change, the test should still - * pass at the same significance level. The chi-squared test done in R gave the following - * results: - * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, - * 10000790, 10002286, 9998699)) - * Chi-squared test for given probabilities - * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, - * 10002286, 9998699) - * X-squared = 11.975, df = 9, p-value = 0.2147 - * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million - * random numbers - * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared - * is greater than or equal to that number. + /* + * Perform the chi square test on the 5 rows of randomly generated numbers evenly divided into + * 10 bins. 
chiSquareTest returns true iff the null hypothesis (that the classifications + * represented by the counts in the columns of the input 2-way table are independent of the + * rows) can be rejected with 100 * (1 - alpha) percent confidence, where alpha is prespeficied + * as 0.05 */ - val binSize = f.hundMil/numBins - val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum - xSquared should be < (16.9196) - + val chiTest = new ChiSquareTest + assert(chiTest.chiSquareTest(bins, 0.05) === false) } test ("XORShift with zero seed") { diff --git a/dev/audit-release/blank_maven_build/pom.xml b/dev/audit-release/blank_maven_build/pom.xml index 047659e4a8b7c..02dd9046c9a49 100644 --- a/dev/audit-release/blank_maven_build/pom.xml +++ b/dev/audit-release/blank_maven_build/pom.xml @@ -28,10 +28,6 @@ Spray.cc repository http://repo.spray.cc - - Akka repository - http://repo.akka.io/releases - Spark Staging Repo ${spark.release.repository} diff --git a/dev/audit-release/blank_sbt_build/build.sbt b/dev/audit-release/blank_sbt_build/build.sbt index 1cf52743f27f4..696c7f651837c 100644 --- a/dev/audit-release/blank_sbt_build/build.sbt +++ b/dev/audit-release/blank_sbt_build/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" % System.getenv.get("SPARK_MODULE") % resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/maven_app_core/pom.xml b/dev/audit-release/maven_app_core/pom.xml index 76a381f8e17e0..b516396825573 100644 --- a/dev/audit-release/maven_app_core/pom.xml +++ b/dev/audit-release/maven_app_core/pom.xml @@ -28,10 +28,6 @@ Spray.cc repository http://repo.spray.cc - - Akka repository - http://repo.akka.io/releases - Spark Staging Repo ${spark.release.repository} diff --git a/dev/audit-release/sbt_app_core/build.sbt b/dev/audit-release/sbt_app_core/build.sbt index 97a8cc3a4e095..291b1d6440bac 100644 --- a/dev/audit-release/sbt_app_core/build.sbt +++ b/dev/audit-release/sbt_app_core/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-core" % System.getenv.get("S resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_ganglia/build.sbt b/dev/audit-release/sbt_app_ganglia/build.sbt index 55db675c722d1..6d9474acf5bbc 100644 --- a/dev/audit-release/sbt_app_ganglia/build.sbt +++ b/dev/audit-release/sbt_app_ganglia/build.sbt @@ -27,5 +27,4 @@ libraryDependencies += "org.apache.spark" %% "spark-ganglia-lgpl" % System.geten resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_graphx/build.sbt b/dev/audit-release/sbt_app_graphx/build.sbt index 66f2db357d49b..dd11245e67d44 100644 --- a/dev/audit-release/sbt_app_graphx/build.sbt +++ b/dev/audit-release/sbt_app_graphx/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-graphx" % System.getenv.get( resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_hive/build.sbt 
b/dev/audit-release/sbt_app_hive/build.sbt index 7ac1be729c561..a0d4f25da5842 100644 --- a/dev/audit-release/sbt_app_hive/build.sbt +++ b/dev/audit-release/sbt_app_hive/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-hive" % System.getenv.get("S resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_sql/build.sbt b/dev/audit-release/sbt_app_sql/build.sbt index 6e0ad3b4b2960..9116180f71a44 100644 --- a/dev/audit-release/sbt_app_sql/build.sbt +++ b/dev/audit-release/sbt_app_sql/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-sql" % System.getenv.get("SP resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_streaming/build.sbt b/dev/audit-release/sbt_app_streaming/build.sbt index 492e5e7c8d763..cb369d516dd16 100644 --- a/dev/audit-release/sbt_app_streaming/build.sbt +++ b/dev/audit-release/sbt_app_streaming/build.sbt @@ -25,5 +25,4 @@ libraryDependencies += "org.apache.spark" %% "spark-streaming" % System.getenv.g resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), - "Akka Repository" at "http://repo.akka.io/releases/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index ffb70096d6014..c44320239bbbf 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -130,7 +130,9 @@ def merge_pr(pr_num, target_ref): merge_message_flags += ["-m", title] if body != None: - merge_message_flags += ["-m", body] + # We remove @ symbols from the body to avoid triggering e-mails + # to people every time someone creates a public fork of Spark. + merge_message_flags += ["-m", body.replace("@", "")] authors = "\n".join(["Author: %s" % a for a in distinct_authors]) diff --git a/dev/mima b/dev/mima index b68800d6d0173..7857294f61caf 100755 --- a/dev/mima +++ b/dev/mima @@ -18,6 +18,7 @@ # set -o pipefail +set -e # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" diff --git a/dev/run-tests b/dev/run-tests index c82a47ebb618b..d9df020f7563c 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,6 +21,9 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR +export SPARK_HADOOP_VERSION=2.3.0 +export SPARK_YARN=true + # Remove work directory rm -rf ./work diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 4ba20e590f2c2..b30ab1e5218c0 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -136,21 +136,31 @@

{{ page.title }}

- + }); + diff --git a/docs/configuration.md b/docs/configuration.md index b84104cc7e653..f1297c05a89df 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -237,6 +237,32 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryFraction. + + spark.shuffle.manager + org.apache.spark.shuffle.hash.HashShuffleManager + + Specify the ShuffleManager class in SparkEnv. The default HashShuffleManager uses hashing and creates + one output file per reduce partition on each mapper. + + + + spark.shuffle.mapOutputTrackerMasterClass + org.apache.spark.shuffle.MapOutputTrackerMaster + + Specify the MapOutputTrackerMaster initialized in ShuffleManager. The default MapOutputTrackerMaster runs on the + driver and uses TimeStampedHashMap to keep track of map output information (includes the block manager address that + the task ran on as well as the sizes of outputs for each reducer), which allows old output information based on + a TTL. + + + + spark.shuffle.mapOutputTrackerWorkerClass + org.apache.spark.shuffle.MapOutputTrackerWorker + + Specify the MapOutputTrackerWorker initialized in ShuffleManager. The default MapOutputTrackerWorker runs on workers + and fetches map output information from the driver's MapOutputTrackerMaster. + + spark.shuffle.compress true diff --git a/docs/index.md b/docs/index.md index 1a4ff3dbf57be..4ac0982ae54f1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,7 +6,7 @@ title: Spark Overview Apache Spark is a fast and general-purpose cluster computing system. It provides high-level APIs in Java, Scala and Python, and an optimized engine that supports general execution graphs. -It also supports a rich set of higher-level tools including [Shark](http://shark.cs.berkeley.edu) (Hive on Spark), [Spark SQL](sql-programming-guide.html) for structured data, [MLlib](mllib-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). +It also supports a rich set of higher-level tools including [Spark SQL](sql-programming-guide.html) for SQL and structured data processing, [MLlib](mllib-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). # Downloading @@ -109,10 +109,9 @@ options for deployment: **External Resources:** * [Spark Homepage](http://spark.apache.org) -* [Shark](http://shark.cs.berkeley.edu): Apache Hive over Spark * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and - exercises about Spark, Shark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), + exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), [slides](http://ampcamp.berkeley.edu/3/) and [exercises](http://ampcamp.berkeley.edu/3/exercises/) are available online for free. 
* [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), diff --git a/docs/monitoring.md b/docs/monitoring.md index 2b9e9e5bd7ea0..84073fe4d949a 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -35,11 +35,13 @@ If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of application through Spark's history server, provided that the application's event logs exist. You can start a the history server by executing: - ./sbin/start-history-server.sh + ./sbin/start-history-server.sh -The base logging directory must be supplied, and should contain sub-directories that each -represents an application's event logs. This creates a web interface at -`http://:18080` by default. The history server can be configured as follows: +When using the file-system provider class (see spark.history.provider below), the base logging +directory must be supplied in the spark.history.fs.logDirectory configuration option, +and should contain sub-directories that each represents an application's event logs. This creates a +web interface at `http://:18080` by default. The history server can be configured as +follows: @@ -69,7 +71,14 @@ represents an application's event logs. This creates a web interface at
  <tr><th>Environment Variable</th><th>Meaning</th></tr>
  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
  <tr>
-    <td>spark.history.updateInterval</td>
+    <td>spark.history.provider</td>
+    <td>org.apache.spark.deploy.history.FsHistoryProvider</td>
+    <td>Name of the class implementing the application history backend. Currently there is only
+      one implementation, provided by Spark, which looks for application logs stored in the
+      file system.</td>
+  </tr>
+  <tr>
+    <td>spark.history.fs.updateInterval</td>
     <td>10</td>
     <td>The period, in seconds, at which information displayed by this history server is updated.
@@ -78,7 +87,7 @@ represents an application's event logs. This creates a web interface at
spark.history.retainedApplications25050 The number of application UIs to retain. If this cap is exceeded, then the oldest applications will be removed. diff --git a/docs/quick-start.md b/docs/quick-start.md index 64023994771b7..23313d8aa6152 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -266,8 +266,6 @@ version := "1.0" scalaVersion := "{{site.SCALA_VERSION}}" libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}" - -resolvers += "Akka Repository" at "http://repo.akka.io/releases/" {% endhighlight %} For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `simple.sbt` @@ -349,12 +347,6 @@ Note that Spark artifacts are tagged with a Scala version. Simple Project jar 1.0 - - - Akka repository - http://repo.akka.io/releases - - org.apache.spark diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index fecd8f2cc2d48..5d8d603aa3e37 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -95,10 +95,19 @@ Most of the configs are the same for Spark on YARN as for other deployment modes The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc.
+  <td>spark.yarn.jar</td>
+  <td>(none)</td>
+  <td>
+    The location of the Spark jar file, in case overriding the default location is desired.
+    By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be
+    in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't
+    need to be distributed each time an application runs. To point to a jar on HDFS, for example,
+    set this configuration to "hdfs:///some/path".
+  </td>
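As a hedged illustration only (not part of the table above), the property could also be set programmatically from application code; `hdfs:///some/path` is the placeholder location used in the description, and the app name is arbitrary:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: reuse a Spark jar already staged on HDFS so YARN can cache it on the
// nodes instead of shipping it with every application. The path is the placeholder
// from the description above, not a real location.
object YarnJarSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .setAppName("YarnJarSketch")
      .set("spark.yarn.jar", "hdfs:///some/path")
    val sc = new SparkContext(conf)
    sc.parallelize(1 to 10).count()  // trivial job, just to exercise the context
    sc.stop()
  }
}
```

The same property could equally be placed in `conf/spark-defaults.conf` so every submission picks it up.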
-By default, Spark on YARN will use a Spark jar installed locally, but the Spark JAR can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a JAR on HDFS, `export SPARK_JAR=hdfs:///some/path`. - # Launching Spark on YARN Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. @@ -119,7 +128,7 @@ For example: --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ - --executor-cores 1 + --executor-cores 1 \ lib/spark-examples*.jar \ 10 @@ -156,7 +165,20 @@ all environment variables used for launching each container. This process is use classpath problems in particular. (Note that enabling this requires admin privileges on cluster settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). -# Important Notes +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom log4j.properties using spark-submit, by adding it to the "--files" list of files + to be uploaded with the application. +- add "-Dlog4j.configuration=" to "spark.driver.extraJavaOptions" + (for the driver) or "spark.executor.extraJavaOptions" (for executors). Note that if using a file, + the "file:" protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). + +# Important notes - Before Hadoop 2.2, YARN does not support cores in container resource requests. Thus, when running against an earlier version, the numbers of cores given via command line arguments cannot be passed to YARN. Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 454057aa0d279..31f9771223e51 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -19,4 +19,4 @@ # cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py $@ +PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index a40311d9fcf02..e22d93bd31bc2 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -203,6 +203,8 @@ def get_spark_shark_version(opts): # Attempt to resolve an appropriate AMI given the architecture and # region of the request. 
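The YARN docs above replace the old `SPARK_JAR` environment variable with the `spark.yarn.jar` property and describe shipping a custom `log4j.properties` with `--files`. A minimal, hedged sketch of setting these properties programmatically; the jar path, log4j file name, and app name are placeholders, not values from this patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hedged sketch: the property names come from the docs above; the HDFS path and
// log4j file name are illustrative placeholders.
object YarnConfigSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("yarn-config-sketch")
      // Point YARN at a pre-staged, world-readable Spark jar on HDFS so it can be cached on nodes.
      .set("spark.yarn.jar", "hdfs:///some/path/spark-assembly.jar")
      // Use a log4j.properties uploaded via spark-submit --files; with a "file:" URL instead,
      // the file must exist locally on every node.
      .set("spark.executor.extraJavaOptions", "-Dlog4j.configuration=log4j.properties")
    val sc = new SparkContext(conf)
    try {
      sc.parallelize(1 to 10).count()
    } finally {
      sc.stop()
    }
  }
}
```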
+# Information regarding Amazon Linux AMI instance type was update on 2014-6-20: +# http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ def get_spark_ami(opts): instance_types = { "m1.small": "pvm", @@ -218,10 +220,12 @@ def get_spark_ami(opts): "cc1.4xlarge": "hvm", "cc2.8xlarge": "hvm", "cg1.4xlarge": "hvm", - "hs1.8xlarge": "hvm", - "hi1.4xlarge": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", + "hs1.8xlarge": "pvm", + "hi1.4xlarge": "pvm", + "m3.medium": "pvm", + "m3.large": "pvm", + "m3.xlarge": "pvm", + "m3.2xlarge": "pvm", "cr1.8xlarge": "hvm", "i2.xlarge": "hvm", "i2.2xlarge": "hvm", @@ -526,7 +530,8 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): - # From http://docs.amazonwebservices.com/AWSEC2/latest/UserGuide/index.html?InstanceStorage.html + # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html + # Updated 2014-6-20 disks_by_instance = { "m1.small": 1, "m1.medium": 1, @@ -544,8 +549,10 @@ def get_num_disks(instance_type): "hs1.8xlarge": 24, "cr1.8xlarge": 2, "hi1.4xlarge": 2, - "m3.xlarge": 0, - "m3.2xlarge": 0, + "m3.medium": 1, + "m3.large": 1, + "m3.xlarge": 2, + "m3.2xlarge": 2, "i2.xlarge": 1, "i2.2xlarge": 2, "i2.4xlarge": 4, @@ -559,7 +566,9 @@ def get_num_disks(instance_type): "r3.xlarge": 1, "r3.2xlarge": 1, "r3.4xlarge": 1, - "r3.8xlarge": 2 + "r3.8xlarge": 2, + "g2.2xlarge": 1, + "t1.micro": 0 } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -770,12 +779,16 @@ def real_main(): setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": - response = raw_input("Are you sure you want to destroy the cluster " + - cluster_name + "?\nALL DATA ON ALL NODES WILL BE LOST!!\n" + - "Destroy cluster " + cluster_name + " (y/N): ") + print "Are you sure you want to destroy the cluster %s?" % cluster_name + print "The following instances will be terminated:" + (master_nodes, slave_nodes) = get_existing_cluster( + conn, opts, cluster_name, die_on_error=False) + for inst in master_nodes + slave_nodes: + print "> %s" % inst.public_dns_name + + msg = "ALL DATA ON ALL NODES WILL BE LOST!!\nDestroy cluster %s (y/N): " % cluster_name + response = raw_input(msg) if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) print "Terminating master..." for inst in master_nodes: inst.terminate() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 21443ebbbfb0e..38095e88dcea9 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.receiver.Receiver /** * Input stream that pulls messages from a Kafka Broker. * - * @param kafkaParams Map of kafka configuration paramaters. + * @param kafkaParams Map of kafka configuration parameters. * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. 
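The Kafka receiver below is configured through `kafkaParams` and a topic-to-thread-count map. A hedged usage sketch against `KafkaUtils.createStream`; the Zookeeper address, group id, and topic name are illustrative:

```scala
import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

// Hedged sketch: kafkaParams carries the zookeeper.connect, group.id, and (optionally)
// auto.offset.reset keys used by the receiver; topics maps each topic to its consumer
// thread count, one thread per partition.
object KafkaStreamSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("kafka-sketch").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(2))
    val kafkaParams = Map(
      "zookeeper.connect" -> "localhost:2181",   // placeholder address
      "group.id" -> "example-group",
      "auto.offset.reset" -> "smallest")
    val topics = Map("events" -> 2)              // placeholder topic, two consumer threads
    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2)
    stream.map(_._2).print()
    ssc.start()
    ssc.awaitTermination()
  }
}
```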
@@ -76,29 +76,31 @@ class KafkaReceiver[ // Connection to Kafka var consumerConnector : ConsumerConnector = null - def onStop() { } + def onStop() { + if (consumerConnector != null) { + consumerConnector.shutdown() + } + } def onStart() { - // In case we are using multiple Threads to handle Kafka Messages - val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("group.id")) // Kafka connection properties val props = new Properties() kafkaParams.foreach(param => props.put(param._1, param._2)) + val zkConnect = kafkaParams("zookeeper.connect") // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + kafkaParams("zookeeper.connect")) + logInfo("Connecting to Zookeeper: " + zkConnect) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig) - logInfo("Connected to " + kafkaParams("zookeeper.connect")) + logInfo("Connected to " + zkConnect) - // When autooffset.reset is defined, it is our responsibility to try and whack the + // When auto.offset.reset is defined, it is our responsibility to try and whack the // consumer group zk node. if (kafkaParams.contains("auto.offset.reset")) { - tryZookeeperConsumerGroupCleanup(kafkaParams("zookeeper.connect"), kafkaParams("group.id")) + tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) } val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) @@ -112,10 +114,14 @@ class KafkaReceiver[ val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - - // Start the messages handler for each partition - topicMessageStreams.values.foreach { streams => - streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + val executorPool = Executors.newFixedThreadPool(topics.values.sum) + try { + // Start the messages handler for each partition + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + } + } finally { + executorPool.shutdown() // Just causes threads to terminate after work is done } } @@ -124,30 +130,35 @@ class KafkaReceiver[ extends Runnable { def run() { logInfo("Starting MessageHandler.") - for (msgAndMetadata <- stream) { - store((msgAndMetadata.key, msgAndMetadata.message)) + try { + for (msgAndMetadata <- stream) { + store((msgAndMetadata.key, msgAndMetadata.message)) + } + } catch { + case e: Throwable => logError("Error handling message; exiting", e) } } } - // It is our responsibility to delete the consumer group when specifying autooffset.reset. This + // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper. // // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied - // from Kafkas' ConsoleConsumer. See code related to 'autooffset.reset' when it is set to + // from Kafka's ConsoleConsumer. 
See code related to 'auto.offset.reset' when it is set to // 'smallest'/'largest': // scalastyle:off // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala // scalastyle:on private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { + val dir = "/consumers/" + groupId + logInfo("Cleaning up temporary Zookeeper data under " + dir + ".") + val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) try { - val dir = "/consumers/" + groupId - logInfo("Cleaning up temporary zookeeper data under " + dir + ".") - val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) zk.deleteRecursive(dir) - zk.close() } catch { - case _ : Throwable => // swallow + case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e) + } finally { + zk.close() } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index 1c6d7e59e9a27..d85afa45b1264 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -62,7 +62,8 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRe private[graphx] class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) { def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = { - val rdd = new ShuffledRDD[PartitionID, (VertexId, T), VertexBroadcastMsg[T]](self, partitioner) + val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]]( + self, partitioner) // Set a custom serializer if the data is of int or double type. if (classTag[T] == ClassTag.Int) { @@ -84,7 +85,7 @@ class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) { * Return a copy of the RDD partitioned using the specified partitioner. */ def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = { - new ShuffledRDD[PartitionID, T, MessageToPartition[T]](self, partitioner) + new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner) } } @@ -103,7 +104,7 @@ object MsgRDDFunctions { private[graphx] class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, (VertexId, VD)](self, partitioner) + val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner) // Set a custom serializer if the data is of int or double type. if (classTag[VD] == ClassTag.Int) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index d02e9238adba5..3827ac8d0fd6a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -46,8 +46,8 @@ private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. 
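The Zookeeper cleanup in the Kafka receiver above now logs failures instead of swallowing them and always closes the client. The same try/catch/finally shape, shown standalone; the connection string and group id are placeholders:

```scala
import org.I0Itec.zkclient.ZkClient
import kafka.utils.ZKStringSerializer

// Hedged sketch of the cleanup pattern in tryZookeeperConsumerGroupCleanup above: delete the
// consumer group's node, log rather than swallow failures, and close the client in finally.
object ZkCleanupSketch {
  def cleanupConsumerGroup(zkUrl: String, groupId: String): Unit = {
    val dir = "/consumers/" + groupId
    val zk = new ZkClient(zkUrl, 30 * 1000, 30 * 1000, ZKStringSerializer)
    try {
      zk.deleteRecursive(dir)
    } catch {
      case e: Throwable => println(s"Error cleaning up temporary Zookeeper data under $dir: $e")
    } finally {
      zk.close()
    }
  }

  def main(args: Array[String]): Unit = {
    cleanupConsumerGroup("localhost:2181", "example-group")   // placeholder values
  }
}
```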
*/ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, (PartitionID, Byte), RoutingTableMessage](self, partitioner) - .setSerializer(new RoutingTableMessageSerializer) + new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage]( + self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } diff --git a/make-distribution.sh b/make-distribution.sh index ae52b4976dc25..86868438e75c3 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -84,17 +84,28 @@ while (( "$#" )); do shift done +if [ -z "$JAVA_HOME" ]; then + # Fall back on JAVA_HOME from rpm, if found + if which rpm &>/dev/null; then + RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) + if [ "$RPM_JAVA_HOME" != "%java_home" ]; then + JAVA_HOME=$RPM_JAVA_HOME + echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" + fi + fi +fi + if [ -z "$JAVA_HOME" ]; then echo "Error: JAVA_HOME is not set, cannot proceed." exit -1 fi -VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -if [ $? != 0 ]; then +if ! which mvn &>/dev/null; then echo -e "You need Maven installed to build Spark." echo -e "Download Maven from https://maven.apache.org/" exit -1; fi +VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) JAVA_CMD="$JAVA_HOME"/bin/java JAVA_VERSION=$("$JAVA_CMD" -version 2>&1) diff --git a/mllib/pom.xml b/mllib/pom.xml index 878cb83dbf783..b622f96dd7901 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -84,5 +84,13 @@ scalatest-maven-plugin + + + ../python + + pyspark/mllib/*.py + + + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 8f187c9df5102..7bbed9c8fdbef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -60,7 +60,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. */ - def setConvergenceTol(tolerance: Int): this.type = { + def setConvergenceTol(tolerance: Double): this.type = { this.convergenceTol = tolerance this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 4b1850659a18e..fe7a9033cd5f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -195,4 +195,38 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { assert(lossLBFGS3.length == 6) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } + + test("Optimize via class LBFGS.") { + val regParam = 0.2 + + // Prepare another non-zero weights to compare the loss in the first iteration. 
+ val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) + val convergenceTol = 1e-12 + val maxNumIterations = 10 + + val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) + .setNumCorrections(numCorrections) + .setConvergenceTol(convergenceTol) + .setMaxNumIterations(maxNumIterations) + .setRegParam(regParam) + + val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) + + val numGDIterations = 50 + val stepSize = 1.0 + val (weightGD, _) = GradientDescent.runMiniBatchSGD( + dataRDD, + gradient, + squaredL2Updater, + stepSize, + numGDIterations, + regParam, + miniBatchFrac, + initialWeightsWithIntercept) + + // for class LBFGS and the optimize method, we only look at the weights + assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && + compareDouble(weightLBFGS(1), weightGD(1), 0.02), + "The weight differences between LBFGS and GD should be within 2%.") + } } diff --git a/pom.xml b/pom.xml index 0d46bb4114f73..05f76d566e9d1 100644 --- a/pom.xml +++ b/pom.xml @@ -468,6 +468,13 @@ 3.1 test
+ + + asm + asm + 3.3.1 + test + org.mockito mockito-all diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 042fdfcc47261..1621833e124f5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,8 +34,13 @@ object MimaExcludes { val excludes = SparkBuild.SPARK_VERSION match { case v if v.startsWith("1.1") => - Seq(MimaBuild.excludeSparkPackage("graphx")) ++ Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + Seq( + // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values // for countApproxDistinct* functions, which does not work in Java. We later removed // them, and use the following to tell Mima to not care about them. diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7bb39dc77120b..55a2aa0fc7141 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -293,7 +293,9 @@ object SparkBuild extends Build { "com.novocode" % "junit-interface" % "0.10" % "test", "org.easymock" % "easymockclassextension" % "3.1" % "test", "org.mockito" % "mockito-all" % "1.9.0" % "test", - "junit" % "junit" % "4.10" % "test" + "junit" % "junit" % "4.10" % "test", + // Needed by cglib which is needed by easymock. + "asm" % "asm" % "3.3.1" % "test" ), testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), @@ -461,7 +463,7 @@ object SparkBuild extends Build { def toolsSettings = sharedSettings ++ Seq( name := "spark-tools", - libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v ), + libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v), libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v ) ) ++ assemblySettings ++ extraAssemblySettings @@ -630,9 +632,9 @@ object SparkBuild extends Build { scalaVersion := "2.10.4", retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", + libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", "spark-core").map(sparkPreviousArtifact(_).get intransitive()) ) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 062bec2381a8f..95c54e7a5ad63 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -704,7 +704,7 @@ def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False): [0, 1, 16, 25] """ if partitions == None: - partitions = range(rdd._jrdd.splits().size()) + partitions = range(rdd._jrdd.partitions().size()) javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) # Implementation note: This is implemented as a mapPartitions followed diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 19235d5f79f85..0dbead4415b02 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -43,18 +43,23 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, 
stdin=PIPE, stderr=PIPE, preexec_fn=preexec_func) + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE) - + proc = Popen(command, stdout=PIPE, stdin=PIPE) + try: # Determine which ephemeral port the server started on: - gateway_port = int(proc.stdout.readline()) - except: - error_code = proc.poll() - raise Exception("Launching GatewayServer failed with exit code %d: %s" % - (error_code, "".join(proc.stderr.readlines()))) + gateway_port = proc.stdout.readline() + gateway_port = int(gateway_port) + except ValueError: + (stdout, _) = proc.communicate() + exit_code = proc.poll() + error_msg = "Launching GatewayServer failed" + error_msg += " with exit code %d!" % exit_code if exit_code else "! " + error_msg += "(Warning: unexpected output detected.)\n\n" + error_msg += gateway_port + stdout + raise Exception(error_msg) # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 6f94d26ef86a9..5f3a7e71f7866 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -79,15 +79,15 @@ def dispatch(seq): return _do_python_join(rdd, other, numPartitions, dispatch) -def python_cogroup(rdd, other, numPartitions): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) +def python_cogroup(rdds, numPartitions): + def make_mapper(i): + return lambda (k, v): (k, (i, v)) + vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) + rdd_len = len(vrdds) def dispatch(seq): - vbuf, wbuf = [], [] + bufs = [[] for i in range(rdd_len)] for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return (ResultIterable(vbuf), ResultIterable(wbuf)) - return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch) + bufs[n].append(v) + return tuple(map(ResultIterable, bufs)) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 62a95c84675dd..f64f48e3a4c9c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -321,7 +321,7 @@ def getNumPartitions(self): >>> rdd.getNumPartitions() 2 """ - return self._jrdd.splits().size() + return self._jrdd.partitions().size() def filter(self, f): """ @@ -922,7 +922,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.splits().size() + totalParts = self._jrdd.partitions().size() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1233,7 +1233,7 @@ def _mergeCombiners(iterator): combiners[k] = mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) - + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -1245,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ def createZero(): return copy.deepcopy(zeroValue) - + return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): @@ -1323,12 +1323,20 @@ def mapValues(self, f): map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) - # TODO: support varargs cogroup of several RDDs. 
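The PySpark change that follows generalizes `groupWith`/`cogroup` to any number of RDDs. For comparison, the Scala API already supports cogrouping several RDDs directly; a small hedged sketch with made-up data:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // pair-RDD functions in Spark 1.x

// Hedged sketch: cogroup over three RDDs keyed by the same type, mirroring the varargs
// groupWith added to PySpark below.
object CogroupSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cogroup-sketch").setMaster("local[2]"))
    val w = sc.parallelize(Seq(("a", 5), ("b", 6)))
    val x = sc.parallelize(Seq(("a", 1), ("b", 4)))
    val y = sc.parallelize(Seq(("a", 2)))
    // For each key: (values from w, values from x, values from y)
    w.cogroup(x, y).collect().foreach(println)
    sc.stop()
  }
}
```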
- def groupWith(self, other): + def groupWith(self, other, *others): """ - Alias for cogroup. + Alias for cogroup but with support for multiple RDDs. + + >>> w = sc.parallelize([("a", 5), ("b", 6)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> z = sc.parallelize([("b", 42)]) + >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ + sorted(list(w.groupWith(x, y, z).collect()))) + [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] + """ - return self.cogroup(other) + return python_cogroup((self, other) + others, numPartitions=None) # TODO: add variant with custom parittioner def cogroup(self, other, numPartitions=None): @@ -1342,7 +1350,7 @@ def cogroup(self, other, numPartitions=None): >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) [('a', ([1], [2])), ('b', ([4], []))] """ - return python_cogroup(self, other, numPartitions) + return python_cogroup((self, other), numPartitions) def subtractByKey(self, other, numPartitions=None): """ diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh index 4a90c68763b68..e30493da32a7a 100755 --- a/sbin/start-history-server.sh +++ b/sbin/start-history-server.sh @@ -19,19 +19,18 @@ # Starts the history server on the machine this script is executed on. # -# Usage: start-history-server.sh [] -# Example: ./start-history-server.sh --dir /tmp/spark-events --port 18080 +# Usage: start-history-server.sh +# +# Use the SPARK_HISTORY_OPTS environment variable to set history server configuration. # sbin=`dirname "$0"` sbin=`cd "$sbin"; pwd` -if [ $# -lt 1 ]; then - echo "Usage: ./start-history-server.sh " - echo "Example: ./start-history-server.sh /tmp/spark-events" - exit +if [ $# != 0 ]; then + echo "Using command line arguments for setting the log directory is deprecated. Please " + echo "set the spark.history.fs.logDirectory configuration option instead." + export SPARK_HISTORY_OPTS="$SPARK_HISTORY_OPTS -Dspark.history.fs.logDirectory=$1" fi -LOG_DIR=$1 - -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 --dir "$LOG_DIR" +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 196695a0a188f..ada48eaf5dc0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._ object ScalaReflection { import scala.reflect.runtime.universe._ + case class Schema(dataType: DataType, nullable: Boolean) + /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case s: StructType => - s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) + case Schema(s: StructType, _) => + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) } - /** Returns a catalyst DataType for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T]) + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) - /** Returns a catalyst DataType for the given Scala Type using reflection. 
*/ - def schemaFor(tpe: `Type`): DataType = tpe match { + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor(tpe: `Type`): Schema = tpe match { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - schemaFor(optType) + Schema(schemaFor(optType).dataType, nullable = true) case t if t <:< typeOf[Product] => val params = t.member("": TermName).asMethod.paramss - StructType( - params.head.map(p => - StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true))) + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = schemaFor(p.typeSignature) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => BinaryType + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< typeOf[Array[_]] => sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - ArrayType(schemaFor(elementType)) + Schema(ArrayType(schemaFor(elementType).dataType), nullable = true) case t if t <:< typeOf[Map[_,_]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - MapType(schemaFor(keyType), schemaFor(valueType)) - case t if t <:< typeOf[String] => StringType - case t if t <:< typeOf[Timestamp] => TimestampType - case t if t <:< typeOf[BigDecimal] => DecimalType - case t if t <:< typeOf[java.lang.Integer] => IntegerType - case t if t <:< typeOf[java.lang.Long] => LongType - case t if t <:< typeOf[java.lang.Double] => DoubleType - case t if t <:< typeOf[java.lang.Float] => FloatType - case t if t <:< typeOf[java.lang.Short] => ShortType - case t if t <:< typeOf[java.lang.Byte] => ByteType - case t if t <:< typeOf[java.lang.Boolean] => BooleanType - // TODO: The following datatypes could be marked as non-nullable. 
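With the `Schema(dataType, nullable)` wrapper introduced above, and the primitive cases that follow reporting `nullable = false`, reflection-derived schemas now carry nullability. A hedged illustration with a made-up case class:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.ScalaReflection.Schema

// Hypothetical case class, used only to illustrate the new nullability tracking.
case class Person(name: String, age: Int, salary: Option[Double])

object ReflectionSketch {
  def main(args: Array[String]): Unit = {
    val Schema(dataType, nullable) = ScalaReflection.schemaFor[Person]
    // Expected shape, per the cases above: name and salary stay nullable, age does not.
    //   StructField("name", StringType, nullable = true)
    //   StructField("age", IntegerType, nullable = false)
    //   StructField("salary", DoubleType, nullable = true)
    println(dataType)
    println(nullable)   // true: a Product type is itself treated as nullable
  }
}
```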
- case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType + Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index c9b7cea6a3e5f..2c71d2c7b3563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -45,8 +45,10 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { * that schema. * * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significatly reduces the cost of calcuating the - * projection, but means that it is not safe + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` + * and hold on to the returned [[Row]] before calling `next()`. */ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = @@ -67,7 +69,7 @@ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) } /** - * A mutable wrapper that makes two rows appear appear as a single concatenated row. Designed to + * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. 
*/ class JoinedRow extends Row { @@ -81,6 +83,18 @@ class JoinedRow extends Row { this } + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + def iterator = row1.iterator ++ row2.iterator def length = row1.length + row2.length @@ -124,4 +138,9 @@ class JoinedRow extends Row { } new GenericRow(copiedValues) } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b20b5de8c46eb..fb517e40677ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -257,8 +257,11 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details */ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { - // split the condition expression into 3 parts, - // (canEvaluateInLeftSide, canEvaluateInRightSide, haveToEvaluateWithBothSide) + /** + * Splits join condition expressions into three categories based on the attributes required + * to evaluate them. + * @returns (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) + */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { val (leftEvaluateCondition, rest) = condition.partition(_.references subsetOf left.outputSet) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index a43bef389c4bf..026692abe067d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -105,57 +105,39 @@ object PhysicalOperation extends PredicateHelper { } /** - * A pattern that finds joins with equality conditions that can be evaluated using hashing - * techniques. For inner joins, any filters on top of the join operator are also matched. + * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ -object HashFilteredJoin extends Logging with PredicateHelper { +object ExtractEquiJoinKeys extends Logging with PredicateHelper { /** (joinType, rightKeys, leftKeys, condition, leftChild, rightChild) */ type ReturnType = (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - // All predicates can be evaluated for inner join (i.e., those that are in the ON - // clause and WHERE clause.) - case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) => - logger.debug(s"Considering hash inner join on: ${predicates ++ condition}") - splitPredicates(predicates ++ condition, join) - // All predicates can be evaluated for left semi join (those that are in the WHERE - // clause can only from left table, so they can all be pushed down.) 
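`JoinedRow` above gains `withLeft`/`withRight` mutators and a `toString`. A hedged sketch of how the reused wrapper is driven; the rows are built with `GenericRow` purely for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.{GenericRow, JoinedRow, Row}

// Hedged sketch: the same JoinedRow instance is reused for every input pair, so a result
// must be copy()'d before being retained past the current iteration.
object JoinedRowSketch {
  def main(args: Array[String]): Unit = {
    val joined = new JoinedRow()
    val left: Row = new GenericRow(Array[Any](1, "a"))
    val right: Row = new GenericRow(Array[Any](true))
    joined.withLeft(left)
    val row = joined.withRight(right)
    println(row)                 // "[1,a,true]" with the new toString
    val retained = row.copy()    // safe to hold on to across iterations
    println(retained.length)     // 3
  }
}
```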
- case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) => - logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}") - splitPredicates(predicates ++ condition, join) case join @ Join(left, right, joinType, condition) => - logger.debug(s"Considering hash join on: $condition") - splitPredicates(condition.toSeq, join) - case _ => None - } - - // Find equi-join predicates that can be evaluated before the join, and thus can be used - // as join keys. - def splitPredicates(allPredicates: Seq[Expression], join: Join): Option[ReturnType] = { - val Join(left, right, joinType, _) = join - val (joinPredicates, otherPredicates) = - allPredicates.flatMap(splitConjunctivePredicates).partition { - case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || - (canEvaluate(l, right) && canEvaluate(r, left)) => true - case _ => false + logger.debug(s"Considering join on: $condition") + // Find equi-join predicates that can be evaluated before the join, and thus can be used + // as join keys. + val (joinPredicates, otherPredicates) = + condition.map(splitConjunctivePredicates).getOrElse(Nil).partition { + case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || + (canEvaluate(l, right) && canEvaluate(r, left)) => true + case _ => false + } + + val joinKeys = joinPredicates.map { + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) } - - val joinKeys = joinPredicates.map { - case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) - case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) - } - - // Do not consider this strategy if there are no join keys. - if (joinKeys.nonEmpty) { val leftKeys = joinKeys.map(_._1) val rightKeys = joinKeys.map(_._2) - Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) - } else { - logger.debug(s"Avoiding hash join with no join keys.") - None - } + if (joinKeys.nonEmpty) { + logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) + } else { + None + } + case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala index 7c616788a3830..582334aa42590 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala @@ -21,5 +21,4 @@ abstract class BaseRelation extends LeafNode { self: Product => def tableName: String - def isPartitioned: Boolean = false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala new file mode 100644 index 0000000000000..489d7e9c2437f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.sql.Timestamp + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +case class PrimitiveData( + intField: Int, + longField: Long, + doubleField: Double, + floatField: Float, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + +case class NullableData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean, + stringField: String, + decimalField: BigDecimal, + timestampField: Timestamp, + binaryField: Array[Byte]) + +case class OptionalData( + intField: Option[Int], + longField: Option[Long], + doubleField: Option[Double], + floatField: Option[Float], + shortField: Option[Short], + byteField: Option[Byte], + booleanField: Option[Boolean]) + +case class ComplexData( + arrayField: Seq[Int], + mapField: Map[Int, String], + structField: PrimitiveData) + +class ScalaReflectionSuite extends FunSuite { + import ScalaReflection._ + + test("primitive data") { + val schema = schemaFor[PrimitiveData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = false), + StructField("longField", LongType, nullable = false), + StructField("doubleField", DoubleType, nullable = false), + StructField("floatField", FloatType, nullable = false), + StructField("shortField", ShortType, nullable = false), + StructField("byteField", ByteType, nullable = false), + StructField("booleanField", BooleanType, nullable = false))), + nullable = true)) + } + + test("nullable data") { + val schema = schemaFor[NullableData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = true), + StructField("longField", LongType, nullable = true), + StructField("doubleField", DoubleType, nullable = true), + StructField("floatField", FloatType, nullable = true), + StructField("shortField", ShortType, nullable = true), + StructField("byteField", ByteType, nullable = true), + StructField("booleanField", BooleanType, nullable = true), + StructField("stringField", StringType, nullable = true), + StructField("decimalField", DecimalType, nullable = true), + StructField("timestampField", TimestampType, nullable = true), + StructField("binaryField", BinaryType, nullable = true))), + nullable = true)) + } + + test("optinal data") { + val schema = schemaFor[OptionalData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = true), + StructField("longField", LongType, nullable = true), + StructField("doubleField", DoubleType, nullable = true), + StructField("floatField", FloatType, nullable = true), + StructField("shortField", ShortType, nullable = true), + StructField("byteField", ByteType, nullable = true), + StructField("booleanField", 
BooleanType, nullable = true))), + nullable = true)) + } + + test("complex data") { + val schema = schemaFor[ComplexData] + assert(schema === Schema( + StructType(Seq( + StructField("arrayField", ArrayType(IntegerType), nullable = true), + StructField("mapField", MapType(IntegerType, StringType), nullable = true), + StructField( + "structField", + StructType(Seq( + StructField("intField", IntegerType, nullable = false), + StructField("longField", LongType, nullable = false), + StructField("doubleField", DoubleType, nullable = false), + StructField("floatField", FloatType, nullable = false), + StructField("shortField", ShortType, nullable = false), + StructField("byteField", ByteType, nullable = false), + StructField("booleanField", BooleanType, nullable = false))), + nullable = true))), + nullable = true)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index b378252ba2f55..2fe7f94663996 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -29,9 +29,26 @@ import scala.collection.JavaConverters._ */ trait SQLConf { + /** ************************ Spark SQL Params/Hints ******************* */ + // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? + /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + /** + * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to + * a broadcast value during the physical executions of join operations. Setting this to 0 + * effectively disables auto conversion. + * Hive setting: hive.auto.convert.join.noconditionaltask.size. + */ + private[spark] def autoConvertJoinSize: Int = + get("spark.sql.auto.convert.join.size", "10000").toInt + + /** A comma-separated list of table names marked to be broadcasted during joins. */ + private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "") + + /** ********************** SQLConf functionality methods ************ */ + @transient private val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ab376e5504d35..7edb548678c33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -170,7 +170,11 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(None, tableName, rdd.logicalPlan) + val name = tableName + val newPlan = rdd.logicalPlan transform { + case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = name) + } + catalog.registerTable(None, tableName, newPlan) } /** @@ -186,18 +190,23 @@ class SQLContext(@transient val sparkContext: SparkContext) /** Caches the specified table in-memory. 
*/ def cacheTable(tableName: String): Unit = { - val currentTable = catalog.lookupRelation(None, tableName) - val useCompression = - sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) - val asInMemoryRelation = - InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + val currentTable = table(tableName).queryExecution.analyzed + val asInMemoryRelation = currentTable match { + case _: InMemoryRelation => + currentTable.logicalPlan + + case _ => + val useCompression = + sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) + InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + } catalog.registerTable(None, tableName, asInMemoryRelation) } /** Removes the specified table from the in-memory cache. */ def uncacheTable(tableName: String): Unit = { - EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { + table(tableName).queryExecution.analyzed match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => @@ -213,15 +222,17 @@ class SQLContext(@transient val sparkContext: SparkContext) /** Returns true if the table is currently cached in-memory. */ def isCached(tableName: String): Boolean = { - val relation = catalog.lookupRelation(None, tableName) - EliminateAnalysisOperators(relation) match { + val relation = table(tableName).queryExecution.analyzed + relation match { case _: InMemoryRelation => true case _ => false } } protected[sql] class SparkPlanner extends SparkStrategies { - val sparkContext = self.sparkContext + val sparkContext: SparkContext = self.sparkContext + + val sqlContext: SQLContext = self def numPartitions = self.numShufflePartitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index ff6deeda2394d..790d9ef22cf16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -137,26 +137,25 @@ class JavaSQLContext(val sqlContext: SQLContext) { val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => - val dataType = property.getPropertyType match { - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == java.lang.Short.TYPE => ShortType - case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType - case c: Class[_] if c == java.lang.Long.TYPE => LongType - case c: Class[_] if c == java.lang.Double.TYPE => DoubleType - case c: Class[_] if c == java.lang.Byte.TYPE => ByteType - case c: Class[_] if c == java.lang.Float.TYPE => FloatType - case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - - case c: Class[_] if c == classOf[java.lang.Short] => ShortType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Byte] => ByteType - case c: Class[_] if c == classOf[java.lang.Float] => FloatType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + val (dataType, nullable) = property.getPropertyType match { + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == 
java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) } - // TODO: Nullability could be stricter. - AttributeReference(property.getName, dataType, nullable = true)() + AttributeReference(property.getName, dataType, nullable)() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 34d88fe4bd7de..d85d2d7844e0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.SQLContext /** * :: DeveloperApi :: @@ -41,7 +42,7 @@ case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sc: SparkContext) + child: SparkPlan)(@transient sqlContext: SQLContext) extends UnaryNode with NoBind { override def requiredChildDistribution = @@ -55,7 +56,7 @@ case class Aggregate( } } - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index f46fa0516566f..00010ef6e798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(hashExpressions(r), r)) } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part) + val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -60,7 +60,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(row => mutablePair.update(row, null)) } val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part) + val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._1) @@ -71,7 +71,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(null, r)) } val partitioner = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Null, Row, MutablePair[Null, Row]](rdd, partitioner) + val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 07967fe75e882..27dc091b85812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.{Logging, Row} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.catalyst.plans.{QueryPlan, logical} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.BaseRelation import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan /** * :: DeveloperApi :: @@ -66,19 +66,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { * linking. 
*/ @DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan) - extends logical.LogicalPlan with MultiInstanceRelation { +case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "SparkLogicalPlan") + extends BaseRelation with MultiInstanceRelation { def output = alreadyPlanned.output - def references = Set.empty - def children = Nil + override def references = Set.empty + override def children = Nil override final def newInstance: this.type = { SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) case _ => sys.error("Multiple instance of the same relation detected.") - }).asInstanceOf[this.type] + }, tableName) + .asInstanceOf[this.type] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4694f25d6d630..0925605b7c4d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -21,38 +21,75 @@ import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.parquet._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} +import org.apache.spark.sql.parquet._ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Find left semi joins where at least some predicates can be evaluated by matching hash - // keys using the HashFilteredJoin pattern. - case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) => + // Find left semi joins where at least some predicates can be evaluated by matching join keys + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => val semiJoin = execution.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition)(sparkContext) :: Nil + planLater(left), planLater(right), condition)(sqlContext) :: Nil case _ => Nil } } + /** + * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be + * evaluated by matching hash keys. 
+ */ object HashJoin extends Strategy with PredicateHelper { + private[this] def broadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: LogicalPlan, + right: LogicalPlan, + condition: Option[Expression], + side: BuildSide) = { + val broadcastHashJoin = execution.BroadcastHashJoin( + leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil + } + + def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Find inner joins where at least some predicates can be evaluated by matching hash keys - // using the HashFilteredJoin pattern. - case HashFilteredJoin(Inner, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys( + Inner, + leftKeys, + rightKeys, + condition, + left, + right @ PhysicalOperation(_, _, b: BaseRelation)) + if broadcastTables.contains(b.tableName) => + broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + + case ExtractEquiJoinKeys( + Inner, + leftKeys, + rightKeys, + condition, + left @ PhysicalOperation(_, _, b: BaseRelation), + right) + if broadcastTables.contains(b.tableName) => + broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val hashJoin = - execution.HashJoin(leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) + execution.ShuffledHashJoin( + leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case _ => Nil } } @@ -62,10 +99,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => // Collect all aggregate expressions. val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) // Collect all aggregate expressions that can be computed partially. val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) // Only do partial aggregation if supported by all aggregate expressions. 
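Editor's note: the reworked HashJoin strategy above consults a comma-separated table list (the spark.sql.join.broadcastTables setting exercised by the new JoinSuite test later in this patch) and plans a BroadcastHashJoin whenever one side of an equi-join is a BaseRelation named in that list, falling back to ShuffledHashJoin otherwise. A hedged usage sketch follows; the SQLContext, the table names and usesBroadcastJoin are assumptions, while the conf key and the plan inspection mirror the added test.

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.execution.BroadcastHashJoin

    def usesBroadcastJoin(sqlContext: SQLContext): Boolean = {
      // Tables listed here are eligible to be collected and broadcast as the build side.
      sqlContext.set("spark.sql.join.broadcastTables", "dim_small")

      val physical = sqlContext
        .sql("SELECT * FROM orders JOIN dim_small ON orders.key = dim_small.key")
        .queryExecution.sparkPlan

      // With the hint in place, the equi-join should come out as a BroadcastHashJoin
      // rather than a ShuffledHashJoin.
      physical.collect { case j: BroadcastHashJoin => j }.nonEmpty
    }
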
if (allAggregates.size == partialAggregates.size) { @@ -103,7 +140,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { partial = true, groupingExpressions, partialComputation, - planLater(child))(sparkContext))(sparkContext) :: Nil + planLater(child))(sqlContext))(sqlContext) :: Nil } else { Nil } @@ -115,7 +152,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil + planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil case _ => Nil } } @@ -143,7 +180,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil + execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil case _ => Nil } } @@ -155,9 +192,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val relation = ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { @@ -186,7 +223,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil + ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil case _ => Nil } @@ -211,7 +248,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Distinct(child) => execution.Aggregate( - partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil + partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil case logical.Sort(sortExprs, child) => // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution. 
execution.Sort(sortExprs, global = true, planLater(child)):: Nil @@ -224,7 +261,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil + execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => @@ -233,16 +270,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) execution.ExistingRdd(output, dataAsRdd) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child))(sparkContext) :: Nil + execution.Limit(limit, planLater(child))(sqlContext) :: Nil case Unions(unionChildren) => - execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil + execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil case logical.NoRelation => execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan) => existingPlan :: Nil + case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 8969794c69933..a278f1ca98476 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.{HashPartitioner, SparkConf, SparkContext} +import org.apache.spark.{HashPartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -70,12 +71,12 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: * :: DeveloperApi :: */ @DeveloperApi -case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan { +case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output = children.head.output - override def execute() = sc.union(children.map(_.execute())) + override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil } /** @@ -87,11 +88,12 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends * data to a single partition to compute the global limit. 
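Editor's note: a pattern repeated throughout these hunks is that physical operators now take the SQLContext, rather than a bare SparkContext, in a separate @transient constructor argument list and surface it through otherCopyArgs. A minimal sketch of that shape is below; SomeOperator is a made-up name, and the sketch assumes access to the private[sql] execution package.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.catalyst.expressions.Row
    import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}

    case class SomeOperator(child: SparkPlan)(@transient sqlContext: SQLContext) extends UnaryNode {
      // The second argument list is not part of the case class's product elements, so it has
      // to be returned here for makeCopy/transform to rebuild the node with its context.
      override def otherCopyArgs = sqlContext :: Nil

      override def output = child.output

      // The SparkContext is always reached through the SQLContext now.
      override def execute(): RDD[Row] = sqlContext.sparkContext.union(Seq(child.execute()))
    }
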
*/ @DeveloperApi -case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode { +case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) + extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil override def output = child.output @@ -103,7 +105,7 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte iter.take(limit).map(row => mutablePair.update(false, row)) } val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part) + val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.mapPartitions(_.take(limit).map(_._2)) } @@ -117,8 +119,8 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte */ @DeveloperApi case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sc: SparkContext) extends UnaryNode { - override def otherCopyArgs = sc :: Nil + (@transient sqlContext: SQLContext) extends UnaryNode { + override def otherCopyArgs = sqlContext :: Nil override def output = child.output @@ -129,7 +131,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sc.makeRDD(executeCollect(), 1) + override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1) } /** @@ -203,4 +205,3 @@ object ExistingRdd { case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 8d7a5ba59f96a..32c5f26fe8aa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.execution import scala.collection.mutable.{ArrayBuffer, BitSet} - -import org.apache.spark.SparkContext +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ +import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ @DeveloperApi sealed abstract class BuildSide @@ -35,28 +37,19 @@ case object BuildLeft extends BuildSide @DeveloperApi case object BuildRight extends BuildSide -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class HashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode { +trait HashJoin { + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def 
requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - val (buildPlan, streamedPlan) = buildSide match { + lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) } - val (buildKeys, streamedKeys) = buildSide match { + lazy val (buildKeys, streamedKeys) = buildSide match { case BuildLeft => (leftKeys, rightKeys) case BuildRight => (rightKeys, leftKeys) } @@ -67,73 +60,74 @@ case class HashJoin( @transient lazy val streamSideKeyGenerator = () => new MutableProjection(streamedKeys, streamedPlan.output) - def execute() = { - - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - // TODO: Use Spark's HashMap implementation. - val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new ArrayBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() + def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { + // TODO: Use Spark's HashMap implementation. + + val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList } + matchList += currentRow.copy() } + } - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: ArrayBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 - // Mutable per row objects. - private[this] val joinRow = new JoinedRow + // Mutable per row objects. + private[this] val joinRow = new JoinedRow - private[this] val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator() - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) - override final def next() = { - val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - currentMatchPosition += 1 - ret + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } + currentMatchPosition += 1 + ret + } - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false the streamed iterator runs out of - * tuples. 
- */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) } + } - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true } } } @@ -142,32 +136,49 @@ case class HashJoin( /** * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. */ @DeveloperApi -case class LeftSemiJoinHash( +case class ShuffledHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + buildSide: BuildSide, left: SparkPlan, - right: SparkPlan) extends BinaryNode { + right: SparkPlan) extends BinaryNode with HashJoin { override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - val (buildPlan, streamedPlan) = (right, left) - val (buildKeys, streamedKeys) = (rightKeys, leftKeys) + def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { + (buildIter, streamIter) => joinIterators(buildIter, streamIter) + } + } +} - def output = left.output +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. + */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + val buildSide = BuildRight - def execute() = { + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + override def output = left.output + + def execute() = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null @@ -192,6 +203,43 @@ case class LeftSemiJoinHash( } } + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. 
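Editor's note: the HashJoin trait above factors the build/probe loop out of the individual operators so the shuffled and broadcast variants can share it: build a hash table keyed by the build-side join keys, then stream the other side and emit a joined row per match, skipping rows whose keys are null. A simplified, collections-only illustration of that shape follows; it is not the operator code, just the idea behind joinIterators.

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    // Build side first: group build rows by join key, ignoring null keys.
    // Then probe with the streamed side, emitting (streamed, build) pairs for every match.
    def hashJoin[K, L, R](build: Iterator[(K, R)], stream: Iterator[(K, L)]): Iterator[(L, R)] = {
      val table = HashMap.empty[K, ArrayBuffer[R]]
      for ((k, r) <- build if k != null) {
        table.getOrElseUpdate(k, ArrayBuffer.empty[R]) += r
      }
      for {
        (k, l) <- stream
        if k != null && table.contains(k)
        r <- table(k)
      } yield (l, r)
    }
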
+ */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { + + override def otherCopyArgs = sqlContext :: Nil + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + lazy val broadcastFuture = future { + sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + } + + def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + joinIterators(broadcastRelation.value.iterator, streamedIter) + } + } +} + /** * :: DeveloperApi :: * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys @@ -200,13 +248,13 @@ case class LeftSemiJoinHash( @DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - (@transient sc: SparkContext) + (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil def output = left.output @@ -221,9 +269,9 @@ case class LeftSemiJoinBNL( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - def execute() = { - val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + val broadcastedRelation = + sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow @@ -263,13 +311,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod @DeveloperApi case class BroadcastNestedLoopJoin( streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - (@transient sc: SparkContext) + (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil def output = left.output ++ right.output @@ -284,9 +332,9 @@ case class BroadcastNestedLoopJoin( .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - def execute() = { - val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + val broadcastedRelation = + sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] @@ -337,7 +385,7 @@ case class BroadcastNestedLoopJoin( } // TODO: Breaks lineage. 
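Editor's note: BroadcastHashJoin above starts collecting and broadcasting the build side in a Scala future as soon as the operator is constructed, and only blocks, with a bounded wait, when execute() needs the result; the nested-loop joins in the same file broadcast the smaller relation eagerly instead. A hedged sketch of the start-early/await-late part of that pattern outside the operator (broadcastEarly, sc and smallRdd are assumptions):

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent._
    import scala.concurrent.duration._
    import scala.reflect.ClassTag

    import org.apache.spark.SparkContext
    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.rdd.RDD

    def broadcastEarly[T: ClassTag](sc: SparkContext, smallRdd: RDD[T]): Broadcast[Array[T]] = {
      // Kick off the collect + broadcast in the background as soon as we know we need it.
      val broadcastFuture = future { sc.broadcast(smallRdd.collect()) }

      // ... other planning or setup work could run here ...

      // Block only when the broadcast value is actually required, with an upper bound.
      Await.result(broadcastFuture, 5.minutes)
    }
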
- sc.union( - streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches)) + sqlContext.sparkContext.union( + streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 96c131a7f8af1..9c4771d1a9846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -44,8 +44,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} * @param path The path to the Parquet file. */ private[sql] case class ParquetRelation( - val path: String, - @transient val conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { + path: String, + @transient conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { + self: Product => /** Schema derived from ParquetFile */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 624f2e2fa13f6..ade823b51c9cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -33,10 +33,10 @@ import parquet.hadoop.util.ContextUtil import parquet.io.InvalidRecordException import parquet.schema.MessageType -import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.{Logging, SerializableWritable, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} -import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} /** @@ -49,10 +49,11 @@ case class ParquetTableScan( output: Seq[Attribute], relation: ParquetRelation, columnPruningPred: Seq[Expression])( - @transient val sc: SparkContext) + @transient val sqlContext: SQLContext) extends LeafNode { override def execute(): RDD[Row] = { + val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) ParquetInputFormat.setReadSupportClass( job, @@ -93,7 +94,7 @@ case class ParquetTableScan( .filter(_ != null) // Parquet's record filters may produce null values } - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil /** * Applies a (candidate) projection. 
@@ -104,7 +105,7 @@ case class ParquetTableScan( def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { val success = validateProjection(prunedAttributes) if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc) + ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") this @@ -152,7 +153,7 @@ case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, overwrite: Boolean = false)( - @transient val sc: SparkContext) + @transient val sqlContext: SQLContext) extends UnaryNode with SparkHadoopMapReduceUtil { /** @@ -168,7 +169,7 @@ case class InsertIntoParquetTable( val childRdd = child.execute() assert(childRdd != null) - val job = new Job(sc.hadoopConfiguration) + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) val writeSupport = if (child.output.map(_.dataType).forall(_.isPrimitive)) { @@ -204,7 +205,7 @@ case class InsertIntoParquetTable( override def output = child.output - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs = sqlContext :: Nil /** * Stores the given Row RDD as a Hadoop file. @@ -231,7 +232,7 @@ case class InsertIntoParquetTable( val wrappedConf = new SerializableWritable(job.getConfiguration) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) - val stageId = sc.newRddId() + val stageId = sqlContext.sparkContext.newRddId() val taskIdOffset = if (overwrite) { @@ -270,7 +271,7 @@ case class InsertIntoParquetTable( val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) - sc.runJob(rdd, writeShard _) + sqlContext.sparkContext.runJob(rdd, writeShard _) jobCommitter.commitJob(jobTaskContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c794da4da4069..c3c0dcb1aa00b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,10 +20,30 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ class CachedTableSuite extends QueryTest { TestData // Load test tables. 
+ test("SPARK-1669: cacheTable should be idempotent") { + assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + + cacheTable("testData") + table("testData").queryExecution.analyzed match { + case _: InMemoryRelation => + case _ => + fail("testData should be cached") + } + + cacheTable("testData") + table("testData").queryExecution.analyzed match { + case InMemoryRelation(_, _, _: InMemoryColumnarTableScan) => + fail("cacheTable is not idempotent") + + case _ => + } + } + test("read from cached table and uncache") { TestSQLContext.cacheTable("testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index fb599e1e01e73..e4a64a7a482b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test._ /* Implicits */ @@ -149,102 +148,4 @@ class DslQuerySuite extends QueryTest { test("zero count") { assert(emptyTableData.count() === 0) } - - test("inner join where, one match per row") { - checkAnswer( - upperCaseData.join(lowerCaseData, Inner).where('n === 'N), - Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") - )) - } - - test("inner join ON, one match per row") { - checkAnswer( - upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), - Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") - )) - } - - test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 1).as('y) - checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil - ) - } - - test("inner join, no matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 2).as('y) - checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), - Nil) - } - - test("big inner join, 4 matches per row") { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as('x) - val bigDataY = bigData.as('y) - - checkAnswer( - bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), - testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) - } - - test("cartisian product join") { - checkAnswer( - testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) - } - - test("left outer join") { - checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) - } - - test("right outer join") { - checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - } - - test("full outer join") { - val left = upperCaseData.where('N <= 4).as('left) - val right = upperCaseData.where('N >= 3).as('right) - - checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - 
(null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala new file mode 100644 index 0000000000000..3d7d5eedbe8ed --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ + +class JoinSuite extends QueryTest { + + // Ensures tables are loaded. + TestData + + test("equi-join is hash-join") { + val x = testData2.as('x) + val y = testData2.as('y) + val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed + val planned = planner.HashJoin(join) + assert(planned.size === 1) + } + + test("plans broadcast hash join, given hints") { + + def mkTest(buildSide: BuildSide, leftTable: String, rightTable: String) = { + TestSQLContext.set("spark.sql.join.broadcastTables", + s"${if (buildSide == BuildRight) rightTable else leftTable}") + val rdd = sql(s"""SELECT * FROM $leftTable JOIN $rightTable ON key = a""") + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. 
+ val physical = rdd.queryExecution.sparkPlan + val bhj = physical.collect { case j: BroadcastHashJoin if j.buildSide == buildSide => j } + + assert(bhj.size === 1, "planner does not pick up hint to generate broadcast hash join") + checkAnswer( + rdd, + Seq( + (1, "1", 1, 1), + (1, "1", 1, 2), + (2, "2", 2, 1), + (2, "2", 2, 2), + (3, "3", 3, 1), + (3, "3", 3, 2) + )) + } + + mkTest(BuildRight, "testData", "testData2") + mkTest(BuildLeft, "testData", "testData2") + } + + test("multiple-key equi-join is hash-join") { + val x = testData2.as('x) + val y = testData2.as('y) + val join = x.join(y, Inner, + Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed + val planned = planner.HashJoin(join) + assert(planned.size === 1) + } + + test("inner join where, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join ON, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join, where, multiple matches") { + val x = testData2.where('a === 1).as('x) + val y = testData2.where('a === 1).as('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + (1,1,1,1) :: + (1,1,1,2) :: + (1,2,1,1) :: + (1,2,1,2) :: Nil + ) + } + + test("inner join, no matches") { + val x = testData2.where('a === 1).as('x) + val y = testData2.where('a === 2).as('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + Nil) + } + + test("big inner join, 4 matches per row") { + val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigDataX = bigData.as('x) + val bigDataY = bigData.as('y) + + checkAnswer( + bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), + testData.flatMap( + row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + } + + test("cartisian product join") { + checkAnswer( + testData3.join(testData3), + (1, null, 1, null) :: + (1, null, 2, 2) :: + (2, 2, 1, null) :: + (2, 2, 2, 2) :: Nil) + } + + test("left outer join") { + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + } + + test("right outer join") { + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } + + test("full outer join") { + val left = upperCaseData.where('N <= 4).as('left) + val right = upperCaseData.where('N >= 3).as('right) + + checkAnswer( + left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ef84ead2e6e8b..8e1e1971d968b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -35,7 +35,7 @@ class QueryTest extends PlanTest { case singleItem => Seq(Seq(singleItem)) } - val isSorted = rdd.logicalPlan.collect 
{ case s: logical.Sort => s}.nonEmpty + val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => @@ -48,7 +48,7 @@ class QueryTest extends PlanTest { """.stripMargin) } - if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e9360b0fc7910..bf7fafe952303 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test._ /* Implicits */ @@ -404,5 +406,4 @@ class SQLQuerySuite extends QueryTest { ) clear() } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index df6b118360d01..215618e852eb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -57,21 +57,4 @@ class PlannerSuite extends FunSuite { val planned = PartialAggregation(query) assert(planned.isEmpty) } - - test("equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed - val planned = planner.HashJoin(join) - assert(planned.size === 1) - } - - test("multiple-key equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, - Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed - val planned = planner.HashJoin(join) - assert(planned.size === 1) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 7714eb1b5628a..2ca0c1cdcbeca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -166,7 +166,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val scanner = new ParquetTableScan( ParquetTestData.testData.output, ParquetTestData.testData, - Seq())(TestSQLContext.sparkContext) + Seq())(TestSQLContext) val projected = scanner.pruneColumns(ParquetTypesConverter .convertToAttributes(MessageTypeParser .parseMessageType(ParquetTestData.subTestSchema))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7695242a81601..7aedfcd74189b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -258,7 +258,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => 
s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ))=> + case (seq: Seq[_], ArrayType(typ)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f923d68932f83..90eacf4268780 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -34,9 +34,8 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.SparkLogicalPlan -import org.apache.spark.sql.hive.execution.{HiveTableScan, InsertIntoHiveTable} -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} +import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.hive.execution.HiveTableScan /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -105,7 +104,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with object CreateTables extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case InsertIntoCreatedTable(db, tableName, child) => - val databaseName = db.getOrElse(SessionState.get.getCurrentDatabase) + val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase) createTable(databaseName, tableName, child.output) @@ -259,8 +258,6 @@ private[hive] case class MetastoreRelation new Partition(hiveQlTable, p) } - override def isPartitioned = hiveQlTable.isPartitioned - val tableDesc = new TableDesc( Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ec653efcc8c58..b70104dd5be5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -38,8 +38,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -private[hive] case class DfsCommand(cmd: String) extends Command - private[hive] case class ShellCommand(cmd: String) extends Command private[hive] case class SourceCommand(filePath: String) extends Command @@ -204,6 +202,9 @@ private[hive] object HiveQl { class ParseException(sql: String, cause: Throwable) extends Exception(s"Failed to parse: $sql", cause) + class SemanticException(msg: String) + extends Exception(s"Error in semantic analysis: $msg") + /** * Returns the AST for the given SQL string. 
*/ @@ -224,15 +225,15 @@ private[hive] object HiveQl { SetCommand(Some(key), Some(value)) } } else if (sql.trim.toLowerCase.startsWith("cache table")) { - CacheCommand(sql.drop(12).trim, true) + CacheCommand(sql.trim.drop(12).trim, true) } else if (sql.trim.toLowerCase.startsWith("uncache table")) { - CacheCommand(sql.drop(14).trim, false) + CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.drop(8)) + AddJar(sql.trim.drop(8)) } else if (sql.trim.toLowerCase.startsWith("add file")) { - AddFile(sql.drop(9)) - } else if (sql.trim.startsWith("dfs")) { - DfsCommand(sql) + AddFile(sql.trim.drop(9)) + } else if (sql.trim.toLowerCase.startsWith("dfs")) { + NativeCommand(sql) } else if (sql.trim.startsWith("source")) { SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) } else if (sql.trim.startsWith("!")) { @@ -480,6 +481,7 @@ private[hive] object HiveQl { whereClause :: groupByClause :: orderByClause :: + havingClause :: sortByClause :: clusterByClause :: distributeByClause :: @@ -494,6 +496,7 @@ private[hive] object HiveQl { "TOK_WHERE", "TOK_GROUPBY", "TOK_ORDERBY", + "TOK_HAVING", "TOK_SORTBY", "TOK_CLUSTERBY", "TOK_DISTRIBUTEBY", @@ -558,7 +561,6 @@ private[hive] object HiveQl { withWhere) }.getOrElse(withWhere) - // The projection of the query can either be a normal projection, an aggregation // (if there is a group by) or a script transformation. val withProject = transformation.getOrElse { @@ -576,21 +578,28 @@ private[hive] object HiveQl { val withDistinct = if (selectDistinctClause.isDefined) Distinct(withProject) else withProject + val withHaving = havingClause.map { h => + val havingExpr = h.getChildren.toSeq match { case Seq(hexpr) => nodeToExpr(hexpr) } + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. 
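Editor's note: the command-prefix changes above fix a whitespace bug. The old code dropped a fixed number of characters from the raw SQL string before trimming, so any leading whitespace shifted the offset into the table or resource name; they also route `dfs` commands through NativeCommand now that DfsCommand is gone. A tiny worked example of the trim fix (the sample string is made up):

    val raw = "  CACHE TABLE src"
    raw.drop(12).trim        // => "E src"  (old behaviour: the leading spaces shift the offset)
    raw.trim.drop(12).trim   // => "src"    (new behaviour: normalise first, then strip "cache table ")
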
+ Filter(Cast(havingExpr, BooleanType), withDistinct) + }.getOrElse(withDistinct) + val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct) + Sort(totalOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, Some(perPartitionOrdering), None, None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct) + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), - Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) - case (None, None, None, None) => withDistinct + Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index af7687b40429b..4d0fab4140b21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -64,7 +64,6 @@ private[hive] trait HiveStrategies { val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet val (pruningPredicates, otherPredicates) = predicates.partition { _.references.map(_.exprId).subsetOf(partitionKeyIds) - } pruneFilterProject( @@ -81,16 +80,16 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil - case describe: logical.DescribeCommand => { + + case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { case t: MetastoreRelation => - Seq(DescribeHiveTableCommand( - t, describe.output, describe.isExtended)(context)) + Seq(DescribeHiveTableCommand(t, describe.output, describe.isExtended)(context)) case o: LogicalPlan => Seq(DescribeCommand(planLater(o), describe.output)(context)) } - } + case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala new file mode 100644 index 0000000000000..a40e89e0d382b --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.metastore.api.FieldSchema + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow, Row} +import org.apache.spark.sql.execution.{Command, LeafNode} +import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} + +/** + * Implementation for "describe [extended] table". + * + * :: DeveloperApi :: + */ +@DeveloperApi +case class DescribeHiveTableCommand( + table: MetastoreRelation, + output: Seq[Attribute], + isExtended: Boolean)( + @transient context: HiveContext) + extends LeafNode with Command { + + // Strings with the format like Hive. It is used for result comparison in our unit tests. + lazy val hiveString: Seq[String] = { + val alignment = 20 + val delim = "\t" + + sideEffectResult.map { + case (name, dataType, comment) => + String.format("%-" + alignment + "s", name) + delim + + String.format("%-" + alignment + "s", dataType) + delim + + String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) + } + } + + override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + // Trying to mimic the format of Hive's output. But not exactly the same. + var results: Seq[(String, String, String)] = Nil + + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + results ++= columns.map(field => (field.getName, field.getType, field.getComment)) + if (!partitionColumns.isEmpty) { + val partColumnInfo = + partitionColumns.map(field => (field.getName, field.getType, field.getComment)) + results ++= + partColumnInfo ++ + Seq(("# Partition Information", "", "")) ++ + Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + partColumnInfo + } + + if (isExtended) { + results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) + } + + results + } + + override def execute(): RDD[Row] = { + val rows = sideEffectResult.map { + case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala new file mode 100644 index 0000000000000..ef8bae74530ec --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.hive._ +import org.apache.spark.util.MutablePair + +/** + * :: DeveloperApi :: + * The Hive table scan operator. Column and partition pruning are both handled. + * + * @param attributes Attributes to be fetched from the Hive table. + * @param relation The Hive table be be scanned. + * @param partitionPruningPred An optional partition pruning predicate for partitioned table. + */ +@DeveloperApi +case class HiveTableScan( + attributes: Seq[Attribute], + relation: MetastoreRelation, + partitionPruningPred: Option[Expression])( + @transient val context: HiveContext) + extends LeafNode + with HiveInspectors { + + require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, + "Partition pruning predicates only supported for partitioned tables.") + + // Bind all partition key attribute references in the partition pruning predicate for later + // evaluation. + private[this] val boundPruningPred = partitionPruningPred.map { pred => + require( + pred.dataType == BooleanType, + s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + + BindReferences.bindReference(pred, relation.partitionKeys) + } + + @transient + private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context) + + /** + * The hive object inspector for this table, which can be used to extract values from the + * serialized row representation. + */ + @transient + private[this] lazy val objectInspector = + relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] + + /** + * Functions that extract the requested attributes from the hive output. Partitioned values are + * casted from string to its declared data type. 
+ */ + @transient + protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { + attributes.map { a => + val ordinal = relation.partitionKeys.indexOf(a) + if (ordinal >= 0) { + val dataType = relation.partitionKeys(ordinal).dataType + (_: Any, partitionKeys: Array[String]) => { + castFromString(partitionKeys(ordinal), dataType) + } + } else { + val ref = objectInspector.getAllStructFieldRefs + .find(_.getFieldName == a.name) + .getOrElse(sys.error(s"Can't find attribute $a")) + val fieldObjectInspector = ref.getFieldObjectInspector + + val unwrapHiveData = fieldObjectInspector match { + case _: HiveVarcharObjectInspector => + (value: Any) => value.asInstanceOf[HiveVarchar].getValue + case _: HiveDecimalObjectInspector => + (value: Any) => BigDecimal(value.asInstanceOf[HiveDecimal].bigDecimalValue()) + case _ => + identity[Any] _ + } + + (row: Any, _: Array[String]) => { + val data = objectInspector.getStructFieldData(row, ref) + val hiveData = unwrapData(data, fieldObjectInspector) + if (hiveData != null) unwrapHiveData(hiveData) else null + } + } + } + } + + private[this] def castFromString(value: String, dataType: DataType) = { + Cast(Literal(value), dataType).eval(null) + } + + private def addColumnMetadataToConf(hiveConf: HiveConf) { + // Specifies IDs and internal names of columns to be scanned. + val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer) + val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") + + if (attributes.size == relation.output.size) { + ColumnProjectionUtils.setFullyReadColumns(hiveConf) + } else { + ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) + } + + ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) + + // Specifies types and object inspectors of columns to be scanned. + val structOI = ObjectInspectorUtils + .getStandardObjectInspector( + relation.tableDesc.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val columnTypeNames = structOI + .getAllStructFieldRefs + .map(_.getFieldObjectInspector) + .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) + .mkString(",") + + hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) + hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) + } + + addColumnMetadataToConf(context.hiveconf) + + private def inputRdd = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) + } + + /** + * Prunes partitions not involve the query plan. + * + * @param partitions All partitions of the relation. + * @return Partitions that are involved in the query plan. + */ + private[hive] def prunePartitions(partitions: Seq[HivePartition]) = { + boundPruningPred match { + case None => partitions + case Some(shouldKeep) => partitions.filter { part => + val dataTypes = relation.partitionKeys.map(_.dataType) + val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { + castFromString(value, dataType) + } + + // Only partitioned values are needed here, since the predicate has already been bound to + // partition key attribute references. 
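Editor's note: partition pruning above reuses Catalyst for the string-to-type coercion. Each raw partition value Hive hands back is wrapped in a Literal, cast to the partition key's declared data type, and evaluated with no input row before the bound predicate sees it. A small illustration of that castFromString trick (the literal value and target type are made up):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.catalyst.types.IntegerType

    // Same shape as castFromString: wrap the raw partition value, cast it to the declared
    // type, and evaluate without an input row.
    val hour = Cast(Literal("17"), IntegerType).eval(null)   // => 17
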
+ val row = new GenericRow(castedValues.toArray) + shouldKeep.eval(row).asInstanceOf[Boolean] + } + } + } + + override def execute() = { + inputRdd.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val mutableRow = new GenericMutableRow(attributes.length) + val mutablePair = new MutablePair[Any, Array[String]]() + val buffered = iterator.buffered + + // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern + // matching are avoided intentionally. + val rowsAndPartitionKeys = buffered.head match { + // With partition keys + case _: Array[Any] => + buffered.map { case array: Array[Any] => + val deserializedRow = array(0) + val partitionKeys = array(1).asInstanceOf[Array[String]] + mutablePair.update(deserializedRow, partitionKeys) + } + + // Without partition keys + case _ => + val emptyPartitionKeys = Array.empty[String] + buffered.map { deserializedRow => + mutablePair.update(deserializedRow, emptyPartitionKeys) + } + } + + rowsAndPartitionKeys.map { pair => + var i = 0 + while (i < attributes.length) { + mutableRow(i) = attributeFunctions(i)(pair._1, pair._2) + i += 1 + } + mutableRow: Row + } + } + } + } + + override def output = attributes +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala new file mode 100644 index 0000000000000..c2b0b00aa5852 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConversions._ + +import java.util.{HashMap => JHashMap} + +import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.ql.Context +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.serde2.Serializer +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} +import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class InsertIntoHiveTable( + table: MetastoreRelation, + partition: Map[String, Option[String]], + child: SparkPlan, + overwrite: Boolean) + (@transient sc: HiveContext) + extends UnaryNode { + + val outputClass = newSerializer(table.tableDesc).getSerializedClass + @transient private val hiveContext = new Context(sc.hiveconf) + @transient private val db = Hive.get(sc.hiveconf) + + private def newSerializer(tableDesc: TableDesc): Serializer = { + val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + serializer.initialize(null, tableDesc.getProperties) + serializer + } + + override def otherCopyArgs = sc :: Nil + + def output = child.output + + /** + * Wraps with Hive types based on object inspector. + * TODO: Consolidate all hive OI/data interface code. 
+ */ + protected def wrap(a: (Any, ObjectInspector)): Any = a match { + case (s: String, oi: JavaHiveVarcharObjectInspector) => + new HiveVarchar(s, s.size) + + case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => + new HiveDecimal(bd.underlying()) + + case (row: Row, oi: StandardStructObjectInspector) => + val struct = oi.create() + row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { + case (data, field) => + oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + } + struct + + case (s: Seq[_], oi: ListObjectInspector) => + val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) + seqAsJavaList(wrappedSeq) + + case (m: Map[_, _], oi: MapObjectInspector) => + val keyOi = oi.getMapKeyObjectInspector + val valueOi = oi.getMapValueObjectInspector + val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } + mapAsJavaMap(wrappedMap) + + case (obj, _) => + obj + } + + def saveAsHiveFile( + rdd: RDD[Writable], + valueClass: Class[_], + fileSinkConf: FileSinkDesc, + conf: JobConf, + isCompressed: Boolean) { + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + conf.setOutputValueClass(valueClass) + if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { + throw new SparkException("Output format class not set") + } + // Doesn't work in Scala 2.9 due to what may be a generics bug + // TODO: Should we uncomment this for Scala 2.10? + // conf.setOutputFormat(outputFormatClass) + conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) + if (isCompressed) { + // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", + // and "mapred.output.compression.type" have no impact on ORC because it uses table properties + // to store compression information. + conf.set("mapred.output.compress", "true") + fileSinkConf.setCompressed(true) + fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) + } + conf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf, + SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + + logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + + val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) + writer.preSetup() + + def writeToFile(context: TaskContext, iter: Iterator[Writable]) { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + + writer.setup(context.stageId, context.partitionId, attemptNumber) + writer.open() + + var count = 0 + while(iter.hasNext) { + val record = iter.next() + count += 1 + writer.write(record) + } + + writer.close() + writer.commit() + } + + sc.sparkContext.runJob(rdd, writeToFile _) + writer.commitJob() + } + + override def execute() = result + + /** + * Inserts all the rows in the table into Hive. Row objects are properly serialized with the + * `org.apache.hadoop.hive.serde2.SerDe` and the + * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. + * + * Note: this is run once and then kept to avoid double insertions. 
+ */ + private lazy val result: RDD[Row] = { + val childRdd = child.execute() + assert(childRdd != null) + + // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer + // instances within the closure, since Serializer is not serializable while TableDesc is. + val tableDesc = table.tableDesc + val tableLocation = table.hiveQlTable.getDataLocation + val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) + val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) + val rdd = childRdd.mapPartitions { iter => + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) + iter.map { row => + var i = 0 + while (i < row.length) { + // Casts Strings to HiveVarchars when necessary. + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } + + serializer.serialize(outputData, standardOI) + } + } + + // ORC stores compression information in table properties. While, there are other formats + // (e.g. RCFile) that rely on hadoop configurations to store compression information. + val jobConf = new JobConf(sc.hiveconf) + saveAsHiveFile( + rdd, + outputClass, + fileSinkConf, + jobConf, + sc.hiveconf.getBoolean("hive.exec.compress.output", false)) + + // TODO: Handle dynamic partitioning. + val outputPath = FileOutputFormat.getOutputPath(jobConf) + // Have to construct the format of dbname.tablename. + val qualifiedTableName = s"${table.databaseName}.${table.tableName}" + // TODO: Correctly set holdDDLTime. + // In most of the time, we should have holdDDLTime = false. + // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. + val holdDDLTime = false + if (partition.nonEmpty) { + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" // Should not reach here right now. + } + val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) + db.validatePartitionNameCharacters(partVals) + // inheritTableSpecs is set to true. It should be set to false for a IMPORT query + // which is currently considered as a Hive native command. + val inheritTableSpecs = true + // TODO: Correctly set isSkewedStoreAsSubdir. + val isSkewedStoreAsSubdir = false + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } else { + db.loadTable( + outputPath, + qualifiedTableName, + overwrite, + holdDDLTime) + } + + // It would be nice to just return the childRdd unchanged so insert operations could be chained, + // however for now we return an empty list to simplify compatibility checks with hive, which + // does not return anything for insert operations. + // TODO: implement hive compatibility as rules. 
+ sc.sparkContext.makeRDD(Nil, 1) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala new file mode 100644 index 0000000000000..fe6031678f70f --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow, Row} +import org.apache.spark.sql.execution.{Command, LeafNode} +import org.apache.spark.sql.hive.HiveContext + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class NativeCommand( + sql: String, output: Seq[Attribute])( + @transient context: HiveContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) + + override def execute(): RDD[Row] = { + if (sideEffectResult.size == 0) { + context.emptyResult + } else { + val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) + context.sparkContext.parallelize(rows, 1) + } + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala similarity index 100% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala deleted file mode 100644 index 2de2db28a7e04..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ /dev/null @@ -1,524 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive} -import org.apache.hadoop.hive.ql.metadata.formatting.MetaDataFormatUtils -import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc} -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred._ - -import org.apache.spark -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive._ -import org.apache.spark.util.MutablePair -import org.apache.spark.{TaskContext, SparkException} - -/* Implicits */ -import scala.collection.JavaConversions._ - -/** - * :: DeveloperApi :: - * The Hive table scan operator. Column and partition pruning are both handled. - * - * @param attributes Attributes to be fetched from the Hive table. - * @param relation The Hive table be be scanned. - * @param partitionPruningPred An optional partition pruning predicate for partitioned table. - */ -@DeveloperApi -case class HiveTableScan( - attributes: Seq[Attribute], - relation: MetastoreRelation, - partitionPruningPred: Option[Expression])( - @transient val context: HiveContext) - extends LeafNode - with HiveInspectors { - - require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, - "Partition pruning predicates only supported for partitioned tables.") - - // Bind all partition key attribute references in the partition pruning predicate for later - // evaluation. - private val boundPruningPred = partitionPruningPred.map { pred => - require( - pred.dataType == BooleanType, - s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") - - BindReferences.bindReference(pred, relation.partitionKeys) - } - - @transient - val hadoopReader = new HadoopTableReader(relation.tableDesc, context) - - /** - * The hive object inspector for this table, which can be used to extract values from the - * serialized row representation. - */ - @transient - lazy val objectInspector = - relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] - - /** - * Functions that extract the requested attributes from the hive output. Partitioned values are - * casted from string to its declared data type. 
- */ - @transient - protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { - attributes.map { a => - val ordinal = relation.partitionKeys.indexOf(a) - if (ordinal >= 0) { - val dataType = relation.partitionKeys(ordinal).dataType - (_: Any, partitionKeys: Array[String]) => { - castFromString(partitionKeys(ordinal), dataType) - } - } else { - val ref = objectInspector.getAllStructFieldRefs - .find(_.getFieldName == a.name) - .getOrElse(sys.error(s"Can't find attribute $a")) - val fieldObjectInspector = ref.getFieldObjectInspector - - val unwrapHiveData = fieldObjectInspector match { - case _: HiveVarcharObjectInspector => - (value: Any) => value.asInstanceOf[HiveVarchar].getValue - case _: HiveDecimalObjectInspector => - (value: Any) => BigDecimal(value.asInstanceOf[HiveDecimal].bigDecimalValue()) - case _ => - identity[Any] _ - } - - (row: Any, _: Array[String]) => { - val data = objectInspector.getStructFieldData(row, ref) - val hiveData = unwrapData(data, fieldObjectInspector) - if (hiveData != null) unwrapHiveData(hiveData) else null - } - } - } - } - - private def castFromString(value: String, dataType: DataType) = { - Cast(Literal(value), dataType).eval(null) - } - - private def addColumnMetadataToConf(hiveConf: HiveConf) { - // Specifies IDs and internal names of columns to be scanned. - val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer) - val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") - - if (attributes.size == relation.output.size) { - ColumnProjectionUtils.setFullyReadColumns(hiveConf) - } else { - ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) - } - - ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) - - // Specifies types and object inspectors of columns to be scanned. - val structOI = ObjectInspectorUtils - .getStandardObjectInspector( - relation.tableDesc.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] - - val columnTypeNames = structOI - .getAllStructFieldRefs - .map(_.getFieldObjectInspector) - .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) - .mkString(",") - - hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) - hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) - } - - addColumnMetadataToConf(context.hiveconf) - - @transient - def inputRdd = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) - } - - /** - * Prunes partitions not involve the query plan. - * - * @param partitions All partitions of the relation. - * @return Partitions that are involved in the query plan. - */ - private[hive] def prunePartitions(partitions: Seq[HivePartition]) = { - boundPruningPred match { - case None => partitions - case Some(shouldKeep) => partitions.filter { part => - val dataTypes = relation.partitionKeys.map(_.dataType) - val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { - castFromString(value, dataType) - } - - // Only partitioned values are needed here, since the predicate has already been bound to - // partition key attribute references. 
- val row = new GenericRow(castedValues.toArray) - shouldKeep.eval(row).asInstanceOf[Boolean] - } - } - } - - def execute() = { - inputRdd.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val mutableRow = new GenericMutableRow(attributes.length) - val mutablePair = new MutablePair[Any, Array[String]]() - val buffered = iterator.buffered - - // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern - // matching are avoided intentionally. - val rowsAndPartitionKeys = buffered.head match { - // With partition keys - case _: Array[Any] => - buffered.map { case array: Array[Any] => - val deserializedRow = array(0) - val partitionKeys = array(1).asInstanceOf[Array[String]] - mutablePair.update(deserializedRow, partitionKeys) - } - - // Without partition keys - case _ => - val emptyPartitionKeys = Array.empty[String] - buffered.map { deserializedRow => - mutablePair.update(deserializedRow, emptyPartitionKeys) - } - } - - rowsAndPartitionKeys.map { pair => - var i = 0 - while (i < attributes.length) { - mutableRow(i) = attributeFunctions(i)(pair._1, pair._2) - i += 1 - } - mutableRow: Row - } - } - } - } - - def output = attributes -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class InsertIntoHiveTable( - table: MetastoreRelation, - partition: Map[String, Option[String]], - child: SparkPlan, - overwrite: Boolean) - (@transient sc: HiveContext) - extends UnaryNode { - - val outputClass = newSerializer(table.tableDesc).getSerializedClass - @transient private val hiveContext = new Context(sc.hiveconf) - @transient private val db = Hive.get(sc.hiveconf) - - private def newSerializer(tableDesc: TableDesc): Serializer = { - val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] - serializer.initialize(null, tableDesc.getProperties) - serializer - } - - override def otherCopyArgs = sc :: Nil - - def output = child.output - - /** - * Wraps with Hive types based on object inspector. - * TODO: Consolidate all hive OI/data interface code. - */ - protected def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) - - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) - - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) - } - struct - - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) - - case (obj, _) => - obj - } - - def saveAsHiveFile( - rdd: RDD[Writable], - valueClass: Class[_], - fileSinkConf: FileSinkDesc, - conf: JobConf, - isCompressed: Boolean) { - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - conf.setOutputValueClass(valueClass) - if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { - throw new SparkException("Output format class not set") - } - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? 
- // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) - if (isCompressed) { - // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", - // and "mapred.output.compression.type" have no impact on ORC because it uses table properties - // to store compression information. - conf.set("mapred.output.compress", "true") - fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) - } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath( - conf, - SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) - - logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - - val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) - writer.preSetup() - - def writeToFile(context: TaskContext, iter: Iterator[Writable]) { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt - - writer.setup(context.stageId, context.partitionId, attemptNumber) - writer.open() - - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record) - } - - writer.close() - writer.commit() - } - - sc.sparkContext.runJob(rdd, writeToFile _) - writer.commitJob() - } - - override def execute() = result - - /** - * Inserts all the rows in the table into Hive. Row objects are properly serialized with the - * `org.apache.hadoop.hive.serde2.SerDe` and the - * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. - * - * Note: this is run once and then kept to avoid double insertions. - */ - private lazy val result: RDD[Row] = { - val childRdd = child.execute() - assert(childRdd != null) - - // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer - // instances within the closure, since Serializer is not serializable while TableDesc is. - val tableDesc = table.tableDesc - val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) - val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val rdd = childRdd.mapPartitions { iter => - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] - - - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val outputData = new Array[Any](fieldOIs.length) - iter.map { row => - var i = 0 - while (i < row.length) { - // Casts Strings to HiveVarchars when necessary. - outputData(i) = wrap(row(i), fieldOIs(i)) - i += 1 - } - - serializer.serialize(outputData, standardOI) - } - } - - // ORC stores compression information in table properties. While, there are other formats - // (e.g. RCFile) that rely on hadoop configurations to store compression information. - val jobConf = new JobConf(sc.hiveconf) - saveAsHiveFile( - rdd, - outputClass, - fileSinkConf, - jobConf, - sc.hiveconf.getBoolean("hive.exec.compress.output", false)) - - // TODO: Handle dynamic partitioning. 
- val outputPath = FileOutputFormat.getOutputPath(jobConf) - // Have to construct the format of dbname.tablename. - val qualifiedTableName = s"${table.databaseName}.${table.tableName}" - // TODO: Correctly set holdDDLTime. - // In most of the time, we should have holdDDLTime = false. - // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. - val holdDDLTime = false - if (partition.nonEmpty) { - val partitionSpec = partition.map { - case (key, Some(value)) => key -> value - case (key, None) => key -> "" // Should not reach here right now. - } - val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) - // inheritTableSpecs is set to true. It should be set to false for a IMPORT query - // which is currently considered as a Hive native command. - val inheritTableSpecs = true - // TODO: Correctly set isSkewedStoreAsSubdir. - val isSkewedStoreAsSubdir = false - db.loadPartition( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) - } else { - db.loadTable( - outputPath, - qualifiedTableName, - overwrite, - holdDDLTime) - } - - // It would be nice to just return the childRdd unchanged so insert operations could be chained, - // however for now we return an empty list to simplify compatibility checks with hive, which - // does not return anything for insert operations. - // TODO: implement hive compatibility as rules. - sc.sparkContext.makeRDD(Nil, 1) - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class NativeCommand( - sql: String, output: Seq[Attribute])( - @transient context: HiveContext) - extends LeafNode with Command { - - override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) - - override def execute(): RDD[spark.sql.Row] = { - if (sideEffectResult.size == 0) { - context.emptyResult - } else { - val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) - context.sparkContext.parallelize(rows, 1) - } - } - - override def otherCopyArgs = context :: Nil -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class DescribeHiveTableCommand( - table: MetastoreRelation, - output: Seq[Attribute], - isExtended: Boolean)( - @transient context: HiveContext) - extends LeafNode with Command { - - // Strings with the format like Hive. It is used for result comparison in our unit tests. - lazy val hiveString: Seq[String] = { - val alignment = 20 - val delim = "\t" - - sideEffectResult.map { - case (name, dataType, comment) => - String.format("%-" + alignment + "s", name) + delim + - String.format("%-" + alignment + "s", dataType) + delim + - String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) - } - } - - override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { - // Trying to mimic the format of Hive's output. But not exactly the same. 
- var results: Seq[(String, String, String)] = Nil - - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols - results ++= columns.map(field => (field.getName, field.getType, field.getComment)) - if (!partitionColumns.isEmpty) { - val partColumnInfo = - partitionColumns.map(field => (field.getName, field.getType, field.getComment)) - results ++= - partColumnInfo ++ - Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ - partColumnInfo - } - - if (isExtended) { - results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) - } - - results - } - - override def execute(): RDD[Row] = { - val rows = sideEffectResult.map { - case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) - } - context.sparkContext.parallelize(rows, 1) - } - - override def otherCopyArgs = context :: Nil -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index ad5e24c62c621..9b105308ab7cf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -84,7 +84,7 @@ private[hive] object HiveFunctionRegistry case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType - + // java class case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType @@ -98,7 +98,7 @@ private[hive] object HiveFunctionRegistry case c: Class[_] if c == classOf[java.lang.Byte] => ByteType case c: Class[_] if c == classOf[java.lang.Float] => FloatType case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - + // primitive type case c: Class[_] if c == java.lang.Short.TYPE => ShortType case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType @@ -107,7 +107,7 @@ private[hive] object HiveFunctionRegistry case c: Class[_] if c == java.lang.Byte.TYPE => ByteType case c: Class[_] if c == java.lang.Float.TYPE => FloatType case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) } } @@ -148,7 +148,7 @@ private[hive] trait HiveFunctionFactory { case p: java.lang.Byte => p case p: java.lang.Boolean => p case str: String => str - case p: BigDecimal => p + case p: java.math.BigDecimal => p case p: Array[Byte] => p case p: java.sql.Timestamp => p } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 9f5cf282f7c48..9f1cd703103ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -224,6 +224,27 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.reset() } + test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { + val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) + .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + val results = + 
hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + .collect() + .map(x => Pair(x.getString(0), x.getInt(1))) + + assert(results === Array(Pair("foo", 4))) + TestHive.reset() + } + + test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { + hql("select key, count(*) c from src group by key having c").collect() + } + + test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { + assert(hql("select key from src having key > 490").collect().size < 100) + } + test("Query Hive native command execution result") { val tableName = "test_native_commands" @@ -349,6 +370,16 @@ class HiveQuerySuite extends HiveComparisonTest { } } + test("SPARK-2263: Insert Map values") { + hql("CREATE TABLE m(value MAP)") + hql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + hql("SELECT * FROM m").collect().zip(hql("SELECT * FROM src LIMIT 10").collect()).map { + case (Row(map: Map[Int, String]), Row(key: Int, value: String)) => + assert(map.size === 1) + assert(map.head === (key, value)) + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -439,5 +470,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. - } + +// for SPARK-2180 test +case class HavingRow(key: Int, value: String, attr: Int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index a9e3f42a3adfc..f944d010660eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -122,6 +122,3 @@ class PairUdf extends GenericUDF { override def getDisplayString(p1: Array[String]): String = "" } - - - diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 34434449a0d77..4d7c84f443879 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -26,6 +26,11 @@ import scala.collection.JavaConversions._ * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest { + // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset + // the environment to ensure all referenced tables in this suites are not cached in-memory. + // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. + TestHive.reset() + // Column pruning tests createPruningTest("Column pruning - with partitioned table", diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 8f2267599914c..556f49342977a 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -154,7 +154,7 @@ trait ClientBase extends Logging { } /** Copy the file into HDFS if needed. 
*/ - private def copyRemoteFile( + private[yarn] def copyRemoteFile( dstDir: Path, originalPath: Path, replication: Short, @@ -213,10 +213,19 @@ trait ClientBase extends Logging { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - Map( - ClientBase.SPARK_JAR -> ClientBase.getSparkJar, ClientBase.APP_JAR -> args.userJar, - ClientBase.LOG4J_PROP -> System.getenv(ClientBase.LOG4J_CONF_ENV_KEY) - ).foreach { case(destName, _localPath) => + val oldLog4jConf = Option(System.getenv("SPARK_LOG4J_CONF")) + if (oldLog4jConf.isDefined) { + logWarning( + "SPARK_LOG4J_CONF detected in the system environment. This variable has been " + + "deprecated. Please refer to the \"Launching Spark on YARN\" documentation " + + "for alternatives.") + } + + List( + (ClientBase.SPARK_JAR, ClientBase.sparkJar(sparkConf), ClientBase.CONF_SPARK_JAR), + (ClientBase.APP_JAR, args.userJar, ClientBase.CONF_SPARK_USER_JAR), + ("log4j.properties", oldLog4jConf.getOrElse(null), null) + ).foreach { case(destName, _localPath, confKey) => val localPath: String = if (_localPath != null) _localPath.trim() else "" if (! localPath.isEmpty()) { val localURI = new URI(localPath) @@ -225,6 +234,8 @@ trait ClientBase extends Logging { val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions) distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, destName, statCache) + } else if (confKey != null) { + sparkConf.set(confKey, localPath) } } } @@ -246,6 +257,8 @@ trait ClientBase extends Logging { if (addToClasspath) { cachedSecondaryJarLinks += linkname } + } else if (addToClasspath) { + cachedSecondaryJarLinks += file.trim() } } } @@ -265,14 +278,10 @@ trait ClientBase extends Logging { val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - val log4jConf = System.getenv(ClientBase.LOG4J_CONF_ENV_KEY) - ClientBase.populateClasspath(yarnConf, sparkConf, log4jConf, env, extraCp) + ClientBase.populateClasspath(args, yarnConf, sparkConf, env, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - if (log4jConf != null) { - env(ClientBase.LOG4J_CONF_ENV_KEY) = log4jConf - } // Set the environment variables to be passed on to the executors. distCacheMgr.setDistFilesEnv(env) @@ -285,7 +294,6 @@ trait ClientBase extends Logging { // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments. env("SPARK_YARN_USER_ENV") = userEnvs } - env } @@ -310,6 +318,37 @@ trait ClientBase extends Logging { logInfo("Setting up container launch context") val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) + + // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to + // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's + // SparkContext will not let that set spark* system properties, which is expected behavior for + // Yarn clients. So propagate it through the environment. + // + // Note that to warn the user about the deprecation in cluster mode, some code from + // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition + // described above). + if (args.amClass == classOf[ApplicationMaster].getName) { + sys.env.get("SPARK_JAVA_OPTS").foreach { value => + val warning = + s""" + |SPARK_JAVA_OPTS was detected (set to '$value'). 
+ |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application + | - ./spark-submit with --driver-java-options to set -X options for a driver + | - spark.executor.extraJavaOptions to set -X options for executors + """.stripMargin + logWarning(warning) + for (proc <- Seq("driver", "executor")) { + val key = s"spark.$proc.extraJavaOptions" + if (sparkConf.contains(key)) { + throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") + } + } + env("SPARK_JAVA_OPTS") = value + } + } amContainer.setEnvironment(env) val amMemory = calculateAMMemory(newApp) @@ -341,30 +380,20 @@ trait ClientBase extends Logging { javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // SPARK_JAVA_OPTS is deprecated, but for backwards compatibility: - sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - sparkConf.set("spark.executor.extraJavaOptions", opts) - sparkConf.set("spark.driver.extraJavaOptions", opts) - } - + // Forward the Spark configuration to the application master / executors. // TODO: it might be nicer to pass these as an internal environment variable rather than // as Java options, due to complications with string parsing of nested quotes. - if (args.amClass == classOf[ExecutorLauncher].getName) { - // If we are being launched in client mode, forward the spark-conf options - // onto the executor launcher - for ((k, v) <- sparkConf.getAll) { - javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" - } - } else { - // If we are being launched in standalone mode, capture and forward any spark - // system properties (e.g. set by spark-class). - for ((k, v) <- sys.props.filterKeys(_.startsWith("spark"))) { - javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" - } - sys.props.get("spark.driver.extraJavaOptions").foreach(opts => javaOpts += opts) - sys.props.get("spark.driver.libraryPath").foreach(p => javaOpts += s"-Djava.library.path=$p") + for ((k, v) <- sparkConf.getAll) { + javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" + } + + if (args.amClass == classOf[ApplicationMaster].getName) { + sparkConf.getOption("spark.driver.extraJavaOptions") + .orElse(sys.env.get("SPARK_JAVA_OPTS")) + .foreach(opts => javaOpts += opts) + sparkConf.getOption("spark.driver.libraryPath") + .foreach(p => javaOpts += s"-Djava.library.path=$p") } - javaOpts += ClientBase.getLog4jConfiguration(localResources) // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ @@ -377,7 +406,10 @@ trait ClientBase extends Logging { "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - logInfo("Command for starting the Spark ApplicationMaster: " + commands) + logInfo("Yarn AM launch context:") + logInfo(s" class: ${args.amClass}") + logInfo(s" env: $env") + logInfo(s" command: ${commands.mkString(" ")}") // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList @@ -391,12 +423,39 @@ trait ClientBase extends Logging { object ClientBase extends Logging { val SPARK_JAR: String = "__spark__.jar" val APP_JAR: String = "__app__.jar" - val LOG4J_PROP: String = "log4j.properties" - val LOG4J_CONF_ENV_KEY: String = "SPARK_LOG4J_CONF" val LOCAL_SCHEME = "local" + val CONF_SPARK_JAR = "spark.yarn.jar" + /** + * This is an internal config used to propagate the location of the user's jar file to the + * driver/executors. 
+ */ + val CONF_SPARK_USER_JAR = "spark.yarn.user.jar" + /** + * This is an internal config used to propagate the list of extra jars to add to the classpath + * of executors. + */ val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" + val ENV_SPARK_JAR = "SPARK_JAR" - def getSparkJar = sys.env.get("SPARK_JAR").getOrElse(SparkContext.jarOfClass(this.getClass).head) + /** + * Find the user-defined Spark jar if configured, or return the jar containing this + * class if not. + * + * This method first looks in the SparkConf object for the CONF_SPARK_JAR key, and in the + * user environment if that is not found (for backwards compatibility). + */ + def sparkJar(conf: SparkConf) = { + if (conf.contains(CONF_SPARK_JAR)) { + conf.get(CONF_SPARK_JAR) + } else if (System.getenv(ENV_SPARK_JAR) != null) { + logWarning( + s"$ENV_SPARK_JAR detected in the system environment. This variable has been deprecated " + + s"in favor of the $CONF_SPARK_JAR configuration variable.") + System.getenv(ENV_SPARK_JAR) + } else { + SparkContext.jarOfClass(this.getClass).head + } + } def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) = { val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf) @@ -469,71 +528,74 @@ object ClientBase extends Logging { triedDefault.toOption } + def populateClasspath(args: ClientArguments, conf: Configuration, sparkConf: SparkConf, + env: HashMap[String, String], extraClassPath: Option[String] = None) { + extraClassPath.foreach(addClasspathEntry(_, env)) + addClasspathEntry(Environment.PWD.$(), env) + + // Normally the users app.jar is last in case conflicts with spark jars + if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { + addUserClasspath(args, sparkConf, env) + addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) + ClientBase.populateHadoopClasspath(conf, env) + } else { + addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) + ClientBase.populateHadoopClasspath(conf, env) + addUserClasspath(args, sparkConf, env) + } + + // Append all jar files under the working directory to the classpath. + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env); + } /** - * Returns the java command line argument for setting up log4j. If there is a log4j.properties - * in the given local resources, it is used, otherwise the SPARK_LOG4J_CONF environment variable - * is checked. + * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly + * to the classpath. 
*/ - def getLog4jConfiguration(localResources: HashMap[String, LocalResource]): String = { - var log4jConf = LOG4J_PROP - if (!localResources.contains(log4jConf)) { - log4jConf = System.getenv(LOG4J_CONF_ENV_KEY) match { - case conf: String => - val confUri = new URI(conf) - if (ClientBase.LOCAL_SCHEME.equals(confUri.getScheme())) { - "file://" + confUri.getPath() - } else { - ClientBase.LOG4J_PROP - } - case null => "log4j-spark-container.properties" + private def addUserClasspath(args: ClientArguments, conf: SparkConf, + env: HashMap[String, String]) = { + if (args != null) { + addFileToClasspath(args.userJar, APP_JAR, env) + if (args.addJars != null) { + args.addJars.split(",").foreach { case file: String => + addFileToClasspath(file, null, env) + } } + } else { + val userJar = conf.get(CONF_SPARK_USER_JAR, null) + addFileToClasspath(userJar, APP_JAR, env) + + val cachedSecondaryJarLinks = conf.get(CONF_SPARK_YARN_SECONDARY_JARS, "").split(",") + cachedSecondaryJarLinks.foreach(jar => addFileToClasspath(jar, null, env)) } - " -Dlog4j.configuration=" + log4jConf } - def populateClasspath(conf: Configuration, sparkConf: SparkConf, log4jConf: String, - env: HashMap[String, String], extraClassPath: Option[String] = None) { - - if (log4jConf != null) { - // If a custom log4j config file is provided as a local: URI, add its parent directory to the - // classpath. Note that this only works if the custom config's file name is - // "log4j.properties". - val localPath = getLocalPath(log4jConf) - if (localPath != null) { - val parentPath = new File(localPath).getParent() - YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, parentPath, - File.pathSeparator) + /** + * Adds the given path to the classpath, handling "local:" URIs correctly. + * + * If an alternate name for the file is given, and it's not a "local:" file, the alternate + * name will be added to the classpath (relative to the job's work directory). + * + * If not a "local:" file and no alternate name, the environment is not modified. + * + * @param path Path to add to classpath (optional). + * @param fileName Alternate name for the file (optional). + * @param env Map holding the environment variables. + */ + private def addFileToClasspath(path: String, fileName: String, + env: HashMap[String, String]) : Unit = { + if (path != null) { + scala.util.control.Exception.ignoring(classOf[URISyntaxException]) { + val localPath = getLocalPath(path) + if (localPath != null) { + addClasspathEntry(localPath, env) + return + } } } - - /** Add entry to the classpath. */ - def addClasspathEntry(path: String) = YarnSparkHadoopUtil.addToEnvironment(env, - Environment.CLASSPATH.name, path, File.pathSeparator) - /** Add entry to the classpath. Interpreted as a path relative to the working directory. 
*/ - def addPwdClasspathEntry(entry: String) = - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + entry) - - extraClassPath.foreach(addClasspathEntry) - - val cachedSecondaryJarLinks = - sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS).getOrElse("").split(",") - .filter(_.nonEmpty) - // Normally the users app.jar is last in case conflicts with spark jars - if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { - addPwdClasspathEntry(APP_JAR) - cachedSecondaryJarLinks.foreach(addPwdClasspathEntry) - addPwdClasspathEntry(SPARK_JAR) - ClientBase.populateHadoopClasspath(conf, env) - } else { - addPwdClasspathEntry(SPARK_JAR) - ClientBase.populateHadoopClasspath(conf, env) - addPwdClasspathEntry(APP_JAR) - cachedSecondaryJarLinks.foreach(addPwdClasspathEntry) + if (fileName != null) { + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env); } - // Append all class files and jar files under the working directory to the classpath. - addClasspathEntry(Environment.PWD.$()) - addPwdClasspathEntry("*") } /** @@ -547,4 +609,8 @@ object ClientBase extends Logging { null } + private def addClasspathEntry(path: String, env: HashMap[String, String]) = + YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, + File.pathSeparator) + } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 43dbb2464f929..4ba7133a959ed 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -55,10 +55,12 @@ trait ExecutorRunnableUtil extends Logging { sys.props.get("spark.executor.extraJavaOptions").foreach { opts => javaOpts += opts } + sys.env.get("SPARK_JAVA_OPTS").foreach { opts => + javaOpts += opts + } javaOpts += "-Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) - javaOpts += ClientBase.getLog4jConfiguration(localResources) // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. Since the Executor backend @@ -166,13 +168,8 @@ trait ExecutorRunnableUtil extends Logging { def prepareEnvironment: HashMap[String, String] = { val env = new HashMap[String, String]() - val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - val log4jConf = System.getenv(ClientBase.LOG4J_CONF_ENV_KEY) - ClientBase.populateClasspath(yarnConf, sparkConf, log4jConf, env, extraCp) - if (log4jConf != null) { - env(ClientBase.LOG4J_CONF_ENV_KEY) = log4jConf - } + ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp) // Allow users to specify some environment variables YarnSparkHadoopUtil.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"), diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 412dfe38d55eb..fd2694fe7278d 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -63,7 +63,7 @@ private[spark] class YarnClientSchedulerBackend( // variables. 
List(("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.worker.instances"), + ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 608c6e92624c6..686714dc36488 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -17,22 +17,31 @@ package org.apache.spark.deploy.yarn +import java.io.File import java.net.URI +import com.google.common.io.Files import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment - +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers._ +import org.scalatest.Matchers import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } import scala.util.Try +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils -class ClientBaseSuite extends FunSuite { +class ClientBaseSuite extends FunSuite with Matchers { test("default Yarn application classpath") { ClientBase.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) @@ -68,6 +77,67 @@ class ClientBaseSuite extends FunSuite { } } + private val SPARK = "local:/sparkJar" + private val USER = "local:/userJar" + private val ADDED = "local:/addJar1,local:/addJar2,/addJar3" + + test("Local jar URIs") { + val conf = new Configuration() + val sparkConf = new SparkConf().set(ClientBase.CONF_SPARK_JAR, SPARK) + val env = new MutableHashMap[String, String]() + val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) + + ClientBase.populateClasspath(args, conf, sparkConf, env, None) + + val cp = env("CLASSPATH").split(File.pathSeparator) + s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => + val uri = new URI(entry) + if (ClientBase.LOCAL_SCHEME.equals(uri.getScheme())) { + cp should contain (uri.getPath()) + } else { + cp should not contain (uri.getPath()) + } + }) + cp should contain (Environment.PWD.$()) + cp should contain (s"${Environment.PWD.$()}${File.separator}*") + cp should not contain (ClientBase.SPARK_JAR) + cp should not contain (ClientBase.APP_JAR) + } + + test("Jar path propagation through SparkConf") { + val conf = new Configuration() + val sparkConf = new SparkConf().set(ClientBase.CONF_SPARK_JAR, SPARK) + val yarnConf = new YarnConfiguration() + val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) + + val client = spy(new DummyClient(args, conf, sparkConf, yarnConf)) + doReturn(new Path("/")).when(client).copyRemoteFile(any(classOf[Path]), + 
any(classOf[Path]), anyShort(), anyBoolean()) + + var tempDir = Files.createTempDir(); + try { + client.prepareLocalResources(tempDir.getAbsolutePath()) + sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) + + // The non-local path should be propagated by name only, since it will end up in the app's + // staging dir. + val expected = ADDED.split(",") + .map(p => { + val uri = new URI(p) + if (ClientBase.LOCAL_SCHEME == uri.getScheme()) { + p + } else { + Option(uri.getFragment()).getOrElse(new File(p).getName()) + } + }) + .mkString(",") + + sparkConf.getOption(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS) should be (Some(expected)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = @@ -109,4 +179,18 @@ class ClientBaseSuite extends FunSuite { def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + private class DummyClient( + val args: ClientArguments, + val conf: Configuration, + val sparkConf: SparkConf, + val yarnConf: YarnConfiguration) extends ClientBase { + + override def calculateAMMemory(newApp: GetNewApplicationResponse): Int = + throw new UnsupportedOperationException() + + override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = + throw new UnsupportedOperationException() + + } + } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 117b33f466f85..07ba0a4b30bd7 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -81,6 +81,7 @@ class ExecutorRunnable( val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, localResources) + logInfo(s"Setting up executor with environment: $env") logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands)
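
The reworked ClientBase helpers resolve each jar in one of two ways, and the "Local jar URIs" test above asserts exactly this split: "local:" URIs are added to the classpath by their on-node path, while anything else is distributed into the container's working directory and referenced there by its link name. Below is a minimal, self-contained Scala sketch of that resolution rule, assuming hypothetical names (ClasspathSketch, entryFor) and a plain string standing in for Environment.PWD; it is a simplified illustration alongside the patch, not code from it.

import java.io.File
import java.net.URI

object ClasspathSketch {
  private val LocalScheme = "local"

  /** Classpath entry produced for `path`, or None when nothing should be added. */
  def entryFor(path: String, linkName: Option[String], workDir: String): Option[String] = {
    val uri = new URI(path)
    if (uri.getScheme == LocalScheme) {
      // "local:" URIs refer to files already installed on every node; use the path as-is.
      Some(uri.getPath)
    } else {
      // Anything else is shipped via the distributed cache and shows up in the container's
      // working directory under its link name (e.g. __spark__.jar, __app__.jar).
      linkName.map(name => workDir + File.separator + name)
    }
  }
}

// ClasspathSketch.entryFor("local:/opt/jars/dep.jar", None, "{{PWD}}")                == Some("/opt/jars/dep.jar")
// ClasspathSketch.entryFor("hdfs:///user/me/app.jar", Some("__app__.jar"), "{{PWD}}") == Some("{{PWD}}/__app__.jar")
// ClasspathSketch.entryFor("hdfs:///user/me/extra.jar", None, "{{PWD}}")              == None

Under these assumptions, the hedged examples mirror what the suite checks: the local: jar keeps its path, the user jar appears as __app__.jar relative to the working directory, and a non-local file with no link name contributes nothing.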