diff --git a/README.md b/README.md
index 8dd8b70696aa2..dbf53dcd76b2d 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ To build Spark and its example programs, run:
 (You do not need to do this if you downloaded a pre-built package.)
 More detailed documentation is available from the project site, at
-["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html).
+["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html).
 
 ## Interactive Scala Shell
 
diff --git a/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java
new file mode 100644
index 0000000000000..0ad189633e427
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java;
+
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public interface JavaFutureAction<T> extends Future<T> {
+
+  /**
+   * Returns the job IDs run by the underlying async operation.
+   *
+   * This returns the current snapshot of the job list. Certain operations may run multiple
+   * jobs, so multiple calls to this method may return different lists.
+   */
+  List<Integer> jobIds();
+}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index d89bb50076c9a..80da62c44edc5 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -61,7 +61,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
       val computedValues = rdd.computeOrReadCheckpoint(partition, context)
 
       // If the task is running locally, do not persist the result
-      if (context.runningLocally) {
+      if (context.isRunningLocally) {
         return computedValues
       }
 
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index e8f761eaa5799..d5c8f9d76c476 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -17,20 +17,21 @@
 
 package org.apache.spark
 
-import scala.concurrent._
-import scala.concurrent.duration.Duration
-import scala.util.Try
+import java.util.Collections
+import java.util.concurrent.TimeUnit
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaFutureAction
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
 
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.{Failure, Try}
+
 /**
- * :: Experimental ::
  * A future for the result of an action to support cancellation. This is an extension of the
  * Scala Future interface to support cancellation.
  */
-@Experimental
 trait FutureAction[T] extends Future[T] {
   // Note that we redefine methods of the Future trait here explicitly so we can specify a different
   // documentation (with reference to the word "action").
@@ -69,6 +70,11 @@ trait FutureAction[T] extends Future[T] {
    */
   override def isCompleted: Boolean
 
+  /**
+   * Returns whether the action has been cancelled.
+   */
+  def isCancelled: Boolean
+
   /**
    * The value of this Future.
    *
@@ -96,15 +102,16 @@ trait FutureAction[T] extends Future[T] {
 
 
 /**
- * :: Experimental ::
  * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
  * count, collect, reduce.
  */
-@Experimental
 class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
   extends FutureAction[T] {
 
+  @volatile private var _cancelled: Boolean = false
+
   override def cancel() {
+    _cancelled = true
     jobWaiter.cancel()
   }
 
@@ -143,6 +150,8 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
   }
 
   override def isCompleted: Boolean = jobWaiter.jobFinished
+
+  override def isCancelled: Boolean = _cancelled
 
   override def value: Option[Try[T]] = {
     if (jobWaiter.jobFinished) {
@@ -164,12 +173,10 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
 
 
 /**
- * :: Experimental ::
  * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
  * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
  * action thread if it is being blocked by a job.
  */
-@Experimental
 class ComplexFutureAction[T] extends FutureAction[T] {
 
   // Pointer to the thread that is executing the action. It is set when the action is run.
@@ -222,7 +229,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
       // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
       // command need to be in an atomic block.
val job = this.synchronized { - if (!cancelled) { + if (!isCancelled) { rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) } else { throw new SparkException("Action has been cancelled") @@ -243,10 +250,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { } } - /** - * Returns whether the promise has been cancelled. - */ - def cancelled: Boolean = _cancelled + override def isCancelled: Boolean = _cancelled @throws(classOf[InterruptedException]) @throws(classOf[scala.concurrent.TimeoutException]) @@ -271,3 +275,55 @@ class ComplexFutureAction[T] extends FutureAction[T] { def jobIds = jobs } + +private[spark] +class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) + extends JavaFutureAction[T] { + + import scala.collection.JavaConverters._ + + override def isCancelled: Boolean = futureAction.isCancelled + + override def isDone: Boolean = { + // According to java.util.Future's Javadoc, this returns True if the task was completed, + // whether that completion was due to successful execution, an exception, or a cancellation. + futureAction.isCancelled || futureAction.isCompleted + } + + override def jobIds(): java.util.List[java.lang.Integer] = { + Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava) + } + + private def getImpl(timeout: Duration): T = { + // This will throw TimeoutException on timeout: + Await.ready(futureAction, timeout) + futureAction.value.get match { + case scala.util.Success(value) => converter(value) + case Failure(exception) => + if (isCancelled) { + throw new CancellationException("Job cancelled").initCause(exception) + } else { + // java.util.Future.get() wraps exceptions in ExecutionException + throw new ExecutionException("Exception thrown by job", exception) + } + } + } + + override def get(): T = getImpl(Duration.Inf) + + override def get(timeout: Long, unit: TimeUnit): T = + getImpl(Duration.fromNanos(unit.toNanos(timeout))) + + override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized { + if (isDone) { + // According to java.util.Future's Javadoc, this should return false if the task is completed. + false + } else { + // We're limited in terms of the semantics we can provide here; our cancellation is + // asynchronous and doesn't provide a mechanism to not cancel if the job is running. + futureAction.cancel() + true + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 605df0e929faa..dbbcc23305c50 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,7 +18,8 @@ package org.apache.spark import scala.collection.JavaConverters._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{HashMap, LinkedHashSet} +import org.apache.spark.serializer.KryoSerializer /** * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. @@ -140,6 +141,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + /** + * Use Kryo serialization and register the given set of classes with Kryo. + * If called multiple times, this will append the classes from all calls together. 
+ */ + def registerKryoClasses(classes: Array[Class[_]]): SparkConf = { + val allClassNames = new LinkedHashSet[String]() + allClassNames ++= get("spark.kryo.classesToRegister", "").split(',').filter(!_.isEmpty) + allClassNames ++= classes.map(_.getName) + + set("spark.kryo.classesToRegister", allClassNames.mkString(",")) + set("spark.serializer", classOf[KryoSerializer].getName) + this + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dd3157990ef2d..ac7935b8c231e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -239,6 +239,10 @@ class SparkContext(config: SparkConf) extends Logging { None } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + ui.foreach(_.bind()) + /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) @@ -341,10 +345,6 @@ class SparkContext(config: SparkConf) extends Logging { postEnvironmentUpdate() postApplicationStart() - // Bind the SparkUI after starting the task scheduler - // because certain pages and listeners depend on it - ui.foreach(_.bind()) - private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 0846225e4f992..c38b96528d037 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} @@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = - mapAsJavaMap(rdd.reduceByKeyLocally(func)) + mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) /** Count the number of elements for each key, and return the result to the master as a Map. 
*/ - def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) + def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) /** * :: Experimental :: @@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap) /** * :: Experimental :: @@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -614,7 +615,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. */ - def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) + def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap()) + /** * Pass each value in the key-value pair RDD through a map function without changing the keys; diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 545bc0e9e99ed..efb8978f7ce12 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,15 +21,18 @@ import java.util.{Comparator, List => JList, Iterator => JIterator} import java.lang.{Iterable => JIterable, Long => JLong} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext} +import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -293,8 +296,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to all elements of this RDD. */ def foreach(f: VoidFunction[T]) { - val cleanF = rdd.context.clean((x: T) => f.call(x)) - rdd.foreach(cleanF) + rdd.foreach(x => f.call(x)) } /** @@ -390,7 +392,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). 
@@ -399,13 +401,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { timeout: Long, confidence: Double ): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * (Experimental) Approximate version of countByValue(). */ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout).map(mapAsJavaMap) + rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap) /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so @@ -575,16 +577,44 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def name(): String = rdd.name /** - * :: Experimental :: - * The asynchronous version of the foreach action. - * - * @param f the function to apply to all the elements of the RDD - * @return a FutureAction for the action + * The asynchronous version of `count`, which returns a + * future for counting the number of elements in this RDD. */ - @Experimental - def foreachAsync(f: VoidFunction[T]): FutureAction[Unit] = { - import org.apache.spark.SparkContext._ - rdd.foreachAsync(x => f.call(x)) + def countAsync(): JavaFutureAction[JLong] = { + new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf) + } + + /** + * The asynchronous version of `collect`, which returns a future for + * retrieving an array containing all of the elements in this RDD. + */ + def collectAsync(): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava) + } + + /** + * The asynchronous version of the `take` action, which returns a + * future for retrieving the first `num` elements of this RDD. + */ + def takeAsync(num: Int): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava) } + /** + * The asynchronous version of the `foreach` action, which + * applies a function f to all the elements of this RDD. + */ + def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } + + /** + * The asynchronous version of the `foreachPartition` action, which + * applies a function f to each partition of this RDD. 
+ */ + def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 22810cb1c662d..b52d0a5028e84 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -19,10 +19,20 @@ package org.apache.spark.api.java import com.google.common.base.Optional +import scala.collection.convert.Wrappers.MapWrapper + private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = option match { case Some(value) => Optional.of(value) case None => Optional.absent() } + + // Workaround for SPARK-3926 / SI-8911 + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = + new SerializableMapWrapper(underlying) + + class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) + extends MapWrapper(underlying) with java.io.Serializable + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 42d58682a1e23..99af2e9608ea7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.io.ByteArrayChunkOutputStream @@ -46,14 +47,12 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream * This prevents the driver from being the bottleneck in sending out multiple copies of the * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. * + * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. + * * @param obj object to broadcast - * @param isLocal whether Spark is running in local mode (single JVM process). * @param id A unique identifier for the broadcast variable. */ -private[spark] class TorrentBroadcast[T: ClassTag]( - obj : T, - @transient private val isLocal: Boolean, - id: Long) +private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) extends Broadcast[T](id) with Logging with Serializable { /** @@ -62,6 +61,20 @@ private[spark] class TorrentBroadcast[T: ClassTag]( * blocks from the driver and/or other executors. */ @transient private var _value: T = obj + /** The compression codec to use, or None if compression is disabled */ + @transient private var compressionCodec: Option[CompressionCodec] = _ + /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. 
*/ + @transient private var blockSize: Int = _ + + private def setConf(conf: SparkConf) { + compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) { + Some(CompressionCodec.createCodec(conf)) + } else { + None + } + blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + } + setConf(SparkEnv.get.conf) private val broadcastId = BroadcastBlockId(id) @@ -76,23 +89,20 @@ private[spark] class TorrentBroadcast[T: ClassTag]( * @return number of blocks this broadcast variable is divided into */ private def writeBlocks(): Int = { - // For local mode, just put the object in the BlockManager so we can find it later. - SparkEnv.get.blockManager.putSingle( - broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - - if (!isLocal) { - val blocks = TorrentBroadcast.blockifyObject(_value) - blocks.zipWithIndex.foreach { case (block, i) => - SparkEnv.get.blockManager.putBytes( - BroadcastBlockId(id, "piece" + i), - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) - } - blocks.length - } else { - 0 + // Store a copy of the broadcast variable in the driver so that tasks run on the driver + // do not create a duplicate copy of the broadcast variable's value. + SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK, + tellMaster = false) + val blocks = + TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec) + blocks.zipWithIndex.foreach { case (block, i) => + SparkEnv.get.blockManager.putBytes( + BroadcastBlockId(id, "piece" + i), + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) } + blocks.length } /** Fetch torrent blocks from the driver and/or other executors. */ @@ -104,29 +114,24 @@ private[spark] class TorrentBroadcast[T: ClassTag]( for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { val pieceId = BroadcastBlockId(id, "piece" + pid) - - // First try getLocalBytes because there is a chance that previous attempts to fetch the + logDebug(s"Reading piece $pieceId of $broadcastId") + // First try getLocalBytes because there is a chance that previous attempts to fetch the // broadcast blocks have already fetched some of the blocks. In that case, some blocks // would be available locally (on this executor). - var blockOpt = bm.getLocalBytes(pieceId) - if (!blockOpt.isDefined) { - blockOpt = bm.getRemoteBytes(pieceId) - blockOpt match { - case Some(block) => - // If we found the block from remote executors/driver's BlockManager, put the block - // in this executor's BlockManager. - SparkEnv.get.blockManager.putBytes( - pieceId, - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) - - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) - } + def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId) + def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block => + // If we found the block from remote executors/driver's BlockManager, put the block + // in this executor's BlockManager. + SparkEnv.get.blockManager.putBytes( + pieceId, + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + block } - // If we get here, the option is defined. 
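The replacement lines just below collapse the old `var blockOpt` branching into an `Option` chain. A minimal, self-contained sketch of that local-then-remote lookup idiom, with illustrative `localStore`/`remoteStore` stand-ins rather than Spark's `BlockManager` API:

```scala
import scala.collection.mutable

object LocalThenRemoteFetch {
  // Illustrative stand-ins for BlockManager.getLocalBytes / getRemoteBytes; not Spark's API.
  private val localStore = mutable.Map.empty[String, Array[Byte]]
  private val remoteStore = Map("piece0" -> Array[Byte](1, 2, 3))

  private def fetchLocal(id: String): Option[Array[Byte]] = localStore.get(id)

  private def fetchRemote(id: String): Option[Array[Byte]] =
    remoteStore.get(id).map { block =>
      localStore(id) = block // cache the remote block locally, as readBlocks does
      block
    }

  def fetch(id: String): Array[Byte] =
    fetchLocal(id).orElse(fetchRemote(id)).getOrElse(
      throw new RuntimeException(s"Failed to get $id"))

  def main(args: Array[String]): Unit = {
    println(fetch("piece0").toSeq) // misses locally, fetched remotely and cached
    println(fetch("piece0").toSeq) // now served from the local cache
  }
}
```

The first lookup misses locally and falls through to the remote store; because `fetchRemote` caches what it finds, later lookups are served locally, which is exactly why `readBlocks` tries `getLocalBytes` first.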
- blocks(pid) = blockOpt.get + val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse( + throw new SparkException(s"Failed to get $pieceId of $broadcastId")) + blocks(pid) = block } blocks } @@ -156,6 +161,7 @@ private[spark] class TorrentBroadcast[T: ClassTag]( private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { + setConf(SparkEnv.get.conf) SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => _value = x.asInstanceOf[T] @@ -167,7 +173,8 @@ private[spark] class TorrentBroadcast[T: ClassTag]( val time = (System.nanoTime() - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") - _value = TorrentBroadcast.unBlockifyObject[T](blocks) + _value = + TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. SparkEnv.get.blockManager.putSingle( @@ -179,43 +186,29 @@ private[spark] class TorrentBroadcast[T: ClassTag]( private object TorrentBroadcast extends Logging { - /** Size of each block. Default value is 4MB. */ - private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - private var initialized = false - private var conf: SparkConf = null - private var compress: Boolean = false - private var compressionCodec: CompressionCodec = null - - def initialize(_isDriver: Boolean, conf: SparkConf) { - TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests - synchronized { - if (!initialized) { - compress = conf.getBoolean("spark.broadcast.compress", true) - compressionCodec = CompressionCodec.createCodec(conf) - initialized = true - } - } - } - def stop() { - initialized = false - } - - def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = { - val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE) - val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos - val ser = SparkEnv.get.serializer.newInstance() + def blockifyObject[T: ClassTag]( + obj: T, + blockSize: Int, + serializer: Serializer, + compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = { + val bos = new ByteArrayChunkOutputStream(blockSize) + val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos) + val ser = serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() bos.toArrays.map(ByteBuffer.wrap) } - def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = { + def unBlockifyObject[T: ClassTag]( + blocks: Array[ByteBuffer], + serializer: Serializer, + compressionCodec: Option[CompressionCodec]): T = { + require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) - val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is - - val ser = SparkEnv.get.serializer.newInstance() + val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) + val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) val obj = serIn.readObject[T]() serIn.close() @@ -227,6 +220,7 @@ private object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. 
*/ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index ad0f701d7a98f..fb024c12094f2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -28,14 +28,13 @@ import org.apache.spark.{SecurityManager, SparkConf} */ class TorrentBroadcastFactory extends BroadcastFactory { - override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - TorrentBroadcast.initialize(isDriver, conf) - } + override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = { + new TorrentBroadcast[T](value_, id) + } - override def stop() { TorrentBroadcast.stop() } + override def stop() { } /** * Remove all persisted state associated with the torrent broadcast with the given ID. diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index a7368f9f3dfbe..b9dd8557ee904 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -71,6 +71,8 @@ private[deploy] object DeployMessages { case class RegisterWorkerFailed(message: String) extends DeployMessage + case class ReconnectWorker(masterUrl: String) extends DeployMessage + case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage case class LaunchExecutor( diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f98b531316a3d..3b6bb9fe128a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -341,7 +341,14 @@ private[spark] class Master( case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() case None => - logWarning("Got heartbeat from unregistered worker " + workerId) + if (workers.map(_.id).contains(workerId)) { + logWarning(s"Got heartbeat from unregistered worker $workerId." + + " Asking it to re-register.") + sender ! ReconnectWorker(masterUrl) + } else { + logWarning(s"Got heartbeat from unregistered worker $workerId." 
+ + " This worker was never registered, so ignoring the heartbeat.") + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 9b52cb06fb6fa..c4a8ec2e5e7b0 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,12 +20,14 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{UUID, Date} +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Random import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} @@ -64,8 +66,22 @@ private[spark] class Worker( // Send a heartbeat every (heartbeat timeout) / 4 milliseconds val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 - val REGISTRATION_TIMEOUT = 20.seconds - val REGISTRATION_RETRIES = 3 + // Model retries to connect to the master, after Hadoop's model. + // The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds) + // Afterwards, the next 10 attempts are between 30 and 90 seconds. + // A bit of randomness is introduced so that not all of the workers attempt to reconnect at + // the same time. + val INITIAL_REGISTRATION_RETRIES = 6 + val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10 + val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 + val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { + val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) + randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND + } + val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders @@ -103,6 +119,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 + var connectionAttemptCount = 0 val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) @@ -158,7 +175,7 @@ private[spark] class Worker( connected = true } - def tryRegisterAllMasters() { + private def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) @@ -166,26 +183,47 @@ private[spark] class Worker( } } - def registerWithMaster() { - tryRegisterAllMasters() - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { - Utils.tryOrExit { - retries += 1 - if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { - logError("All masters are unresponsive! 
Giving up.") - System.exit(1) - } else { - tryRegisterAllMasters() + private def retryConnectToMaster() { + Utils.tryOrExit { + connectionAttemptCount += 1 + logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount") + if (registered) { + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = None + } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { + tryRegisterAllMasters() + if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = Some { + context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, + PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) } } + } else { + logError("All masters are unresponsive! Giving up.") + System.exit(1) } } } + def registerWithMaster() { + // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // if there are outstanding registration attempts scheduled. + registrationRetryTimer match { + case None => + registered = false + tryRegisterAllMasters() + connectionAttemptCount = 0 + registrationRetryTimer = Some { + context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, + INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) + } + case Some(_) => + logInfo("Not spawning another attempt to register with the master, since there is an" + + " attempt scheduled already.") + } + } + override def receiveWithLogging = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) @@ -243,6 +281,10 @@ private[spark] class Worker( System.exit(1) } + case ReconnectWorker(masterUrl) => + logInfo(s"Master with url $masterUrl requested this worker to reconnect.") + registerWithMaster() + case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) => if (masterUrl != activeMasterUrl) { logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.") @@ -362,9 +404,10 @@ private[spark] class Worker( } } - def masterDisconnected() { + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false + registerWithMaster() } def generateWorkerId(): String = { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ede5568493cc0..9f9f10b7ebc3a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -24,14 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} -import org.apache.spark.annotation.Experimental /** - * :: Experimental :: * A set of asynchronous RDD actions available through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. 
*/ -@Experimental class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8010dd90082f8..775141775e06c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -132,27 +132,47 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() + private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value - if (conf.isInstanceOf[JobConf]) { - // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it. - conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - // getJobConf() has been called previously, so there is already a local cache of the JobConf - // needed by this RDD. - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] - } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456). + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546). This problem occurs + // somewhat rarely because most jobs treat the configuration as though it's immutable. One + // solution, implemented here, is to clone the Configuration object. Unfortunately, this + // clone can be very expensive. To avoid unexpected performance regressions for workloads and + // Hadoop versions that do not suffer from these thread-safety issues, this cloning is + // disabled by default. HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.map(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + if (!conf.isInstanceOf[JobConf]) { + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + } newJobConf } + } else { + if (conf.isInstanceOf[JobConf]) { + logDebug("Re-using user-broadcasted JobConf") + conf.asInstanceOf[JobConf] + } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { + logDebug("Re-using cached JobConf") + HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] + } else { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the + // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. + // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). 
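The comments above explain that cloning the Hadoop `Configuration` sidesteps its thread-safety problems (SPARK-2546) at the cost of an expensive copy, so the clone path is opt-in; the synchronized clone itself continues just below. A hedged sketch of how a job might enable it (the property name comes from this patch; the app name, master, and input path are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CloneConfOptIn {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("clone-conf-example")
      .setMaster("local[2]")
      // Opt in to cloning the Hadoop Configuration instead of sharing a cached copy.
      // Leave it at the default (false) unless you hit Configuration thread-safety issues.
      .set("spark.hadoop.cloneConf", "true")
    val sc = new SparkContext(conf)
    try {
      // Any Hadoop-backed input now takes the cloning path in HadoopRDD.getJobConf().
      println(sc.textFile("README.md").count())
    } finally {
      sc.stop()
    }
  }
}
```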
+        HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+          logDebug("Creating new JobConf and caching it for later re-use")
+          val newJobConf = new JobConf(conf)
+          initLocalJobConfFuncOpt.map(f => f(newJobConf))
+          HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+          newJobConf
+        }
+      }
     }
   }
 
@@ -276,7 +296,10 @@ class HadoopRDD[K, V](
 }
 
 private[spark] object HadoopRDD extends Logging {
-  /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
+  /**
+   * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456).
+   * Therefore, we synchronize on this lock before calling new JobConf() or new Configuration().
+   */
   val CONFIGURATION_INSTANTIATION_LOCK = new Object()
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index ac96de86dd6d4..da89f634abaea 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -315,8 +315,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
   @deprecated("Use reduceByKeyLocally", "1.0.0")
   def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
 
-  /** Count the number of elements for each key, and return the result to the master as a Map. */
-  def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+  /**
+   * Count the number of elements for each key, collecting the results to a local Map.
+   *
+   * Note that this method should only be used if the resulting map is expected to be small, as
+   * the whole thing is loaded into the driver's memory.
+   * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which
+   * returns an RDD[(K, Long)] instead of a map.
+   */
+  def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap
 
   /**
    * :: Experimental ::
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 71cabf61d4ee0..b7f125d01dfaf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -927,32 +927,15 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
-   * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
-   * combine step happens locally on the master, equivalent to running a single reduce task.
+   * Return the count of each unique value in this RDD as a local map of (value, count) pairs.
+   *
+   * Note that this method should only be used if the resulting map is expected to be small, as
+   * the whole thing is loaded into the driver's memory.
+   * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
+   * returns an RDD[(T, Long)] instead of a map.
    */
   def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
-    if (elementClassTag.runtimeClass.isArray) {
-      throw new SparkException("countByValue() does not support arrays")
-    }
-    // TODO: This should perhaps be distributed by default.
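The rewritten `countByKey`/`countByValue` docs above both point to a `reduceByKey`-based aggregation when the result is too large to collect into the driver. A short sketch of that pattern (the local-mode setup is illustrative only):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // brings reduceByKey on pair RDDs into scope

object DistributedCounts {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("counts").setMaster("local[2]"))
    try {
      val words = sc.parallelize(Seq("a", "b", "a", "c", "a", "b"))
      // Distributed equivalent of countByValue(): the per-value counts stay in an
      // RDD[(String, Long)] instead of being collected into the driver's memory.
      val counts = words.map(x => (x, 1L)).reduceByKey(_ + _)
      counts.collect().sorted.foreach(println) // tiny here, so collecting is safe
    } finally {
      sc.stop()
    }
  }
}
```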
- val countPartition = (iter: Iterator[T]) => { - val map = new OpenHashMap[T,Long] - iter.foreach { - t => map.changeValue(t, 1L, _ + 1L) - } - Iterator(map) - }: Iterator[OpenHashMap[T,Long]] - val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => { - m2.foreach { case (key, value) => - m1.changeValue(key, value, _ + value) - } - m1 - }: OpenHashMap[T,Long] - val myResult = mapPartitions(countPartition).reduce(mergeMaps) - // Convert to a Scala mutable map - val mutableResult = scala.collection.mutable.Map[T,Long]() - myResult.foreach { case (k, v) => mutableResult.put(k, v) } - mutableResult + map(value => (value, null)).countByKey() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 6d697e3d003f6..2b39c7fc872da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -221,6 +221,7 @@ private[spark] class TaskSchedulerImpl( var newExecAvail = false for (o <- offers) { executorIdToHost(o.executorId) = o.host + activeExecutorIds += o.executorId if (!executorsByHost.contains(o.host)) { executorsByHost(o.host) = new HashSet[String]() executorAdded(o.executorId, o.host) @@ -261,7 +262,6 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetId(tid) = taskSet.taskSet.id taskIdToExecutorId(tid) = execId - activeExecutorIds += execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d6386f8c06fff..621a951c27d07 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -53,7 +53,18 @@ class KryoSerializer(conf: SparkConf) private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) - private val registrator = conf.getOption("spark.kryo.registrator") + private val userRegistrator = conf.getOption("spark.kryo.registrator") + private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") + .split(',') + .filter(!_.isEmpty) + .map { className => + try { + Class.forName(className) + } catch { + case e: Exception => + throw new SparkException("Failed to load class to register with Kryo", e) + } + } def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) @@ -80,22 +91,20 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) - // Allow the user to register their own classes by setting spark.kryo.registrator - for (regCls <- registrator) { - logDebug("Running user registrator: " + regCls) - try { - val reg = Class.forName(regCls, true, classLoader).newInstance() - .asInstanceOf[KryoRegistrator] - - // Use the default classloader when calling the user registrator. 
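For classes that need more than a plain `kryo.register(clazz)` call, the registration flow in the hunk below still supports `spark.kryo.registrator`. A minimal sketch of such a registrator; the class and serializer choices here are illustrative assumptions, not part of this patch:

```scala
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.serializers.JavaSerializer
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Illustrative registrator for a class that needs a custom Kryo serializer,
// something a bare entry in spark.kryo.classesToRegister cannot express.
class BitSetRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[java.util.BitSet], new JavaSerializer())
  }
}

object RegistratorWiring {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrator", classOf[BitSetRegistrator].getName)
    // KryoSerializer instantiates the registrator reflectively and calls
    // registerClasses on every Kryo instance it creates.
    println(conf.get("spark.kryo.registrator"))
  }
}
```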
-        Thread.currentThread.setContextClassLoader(classLoader)
-        reg.registerClasses(kryo)
-      } catch {
-        case e: Exception =>
-          throw new SparkException(s"Failed to invoke $regCls", e)
-      } finally {
-        Thread.currentThread.setContextClassLoader(oldClassLoader)
-      }
+    try {
+      // Use the default classloader when calling the user registrator.
+      Thread.currentThread.setContextClassLoader(classLoader)
+      // Register classes given through spark.kryo.classesToRegister.
+      classesToRegister.foreach { clazz => kryo.register(clazz) }
+      // Allow the user to register their own classes by setting spark.kryo.registrator.
+      userRegistrator
+        .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
+        .foreach { reg => reg.registerClasses(kryo) }
+    } catch {
+      case e: Exception =>
+        throw new SparkException(s"Failed to register classes with Kryo", e)
+    } finally {
+      Thread.currentThread.setContextClassLoader(oldClassLoader)
     }
 
     // Register Chill's classes; we do this after our ranges and the user's own classes to let
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 3f5d06e1aeee7..0ce2a3f631b15 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -870,7 +870,7 @@ private[spark] class BlockManager(
         logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
         blockTransferService.uploadBlockSync(
           peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
-        logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms"
+        logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %d ms"
           .format((System.currentTimeMillis - onePeerStartTime)))
         peersReplicatedTo += peer
         peersForReplication -= peer
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
index a82f71ed08475..1e02f1225d344 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
@@ -29,7 +29,7 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
   private val live = parent.live
   private val sc = parent.sc
   private val listener = parent.listener
-  private lazy val isFairScheduler = parent.isFairScheduler
+  private def isFairScheduler = parent.isFairScheduler
 
   def render(request: HttpServletRequest): Seq[Node] = {
     listener.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 53a7512edd852..0aeff6455b3fe 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -269,23 +269,44 @@ private[spark] object Utils extends Logging {
     dir
   }
 
-  /** Copy all data from an InputStream to an OutputStream */
+  /** Copy all data from an InputStream to an OutputStream. The NIO transferTo path for
+    * file-to-file copies is disabled by default; it is used only when transferToEnabled is true,
+    * and callers should derive that flag from the spark.file.transferTo setting ([true|false]).
+    */
   def copyStream(in: InputStream,
                  out: OutputStream,
-                 closeStreams: Boolean = false): Long =
+                 closeStreams: Boolean = false,
+                 transferToEnabled: Boolean = false): Long =
   {
     var count = 0L
     try {
-      if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]) {
+      if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
+        && transferToEnabled) {
         // When both streams are File stream, use transferTo to improve copy performance.
         val inChannel = in.asInstanceOf[FileInputStream].getChannel()
         val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
+        val initialPos = outChannel.position()
         val size = inChannel.size()
 
         // In case the transferTo method transferred less data than we required.
         while (count < size) {
           count += inChannel.transferTo(count, size - count, outChannel)
         }
+
+        // Check the position after the transferTo loop to verify that it advanced to the
+        // expected offset, and inform the user if it did not.
+        // The position will not be advanced to the expected length after calling transferTo
+        // on kernel version 2.6.32; this issue can be seen in
+        // https://bugs.openjdk.java.net/browse/JDK-7052359
+        // It leads to stream corruption when using the sort-based shuffle (SPARK-3948).
+        val finalPos = outChannel.position()
+        assert(finalPos == initialPos + size,
+          s"""
+             |Current position $finalPos does not equal the expected position ${initialPos + size}
+             |after transferTo. Please check whether your kernel version is 2.6.32, as it has a
+             |kernel bug that leads to unexpected behavior when using transferTo.
+             |You can set spark.file.transferTo = false to disable this NIO feature.
+           """.stripMargin)
       } else {
         val buf = new Array[Byte](8192)
         var n = 0
@@ -727,7 +748,7 @@ private[spark] object Utils extends Logging {
 
   /**
    * Determines if a directory contains any files newer than cutoff seconds.
-   * 
+   *
    * @param dir must be the path to a directory, or IllegalArgumentException is thrown
   * @param cutoff measured in seconds. Returns true if there are any files or directories in the
   *               given directory whose last modified time is later than this many seconds ago
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 644fa36818647..d1b06d14acbd2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -93,6 +93,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val conf = SparkEnv.get.conf
   private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
   private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
+  private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
 
   // Size of object batches when reading/writing from serializers.
  //
@@ -705,10 +706,10 @@
     var out: FileOutputStream = null
     var in: FileInputStream = null
     try {
-      out = new FileOutputStream(outputFile)
+      out = new FileOutputStream(outputFile, true)
       for (i <- 0 until numPartitions) {
         in = new FileInputStream(partitionWriters(i).fileSegment().file)
-        val size = org.apache.spark.util.Utils.copyStream(in, out, false)
+        val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled)
         in.close()
         in = null
         lengths(i) = size
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index b8fa822ae4bd8..814e40c4f77cc 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -20,6 +20,7 @@
 import java.io.*;
 import java.net.URI;
 import java.util.*;
+import java.util.concurrent.*;
 
 import scala.Tuple2;
 import scala.Tuple3;
@@ -29,6 +30,7 @@
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import com.google.common.base.Throwables;
 import com.google.common.base.Optional;
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;
@@ -43,10 +45,7 @@
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaDoubleRDD;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.*;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.partial.BoundedDouble;
@@ -1308,6 +1307,92 @@ public void collectUnderlyingScalaRDD() {
     Assert.assertEquals(data.size(), collected.length);
   }
 
+  private static final class BuggyMapFunction<T> implements Function<T, T> {
+
+    @Override
+    public T call(T x) throws Exception {
+      throw new IllegalStateException("Custom exception!");
+    }
+  }
+
+  @Test
+  public void collectAsync() throws Exception {
+    List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
+    JavaRDD<Integer> rdd = sc.parallelize(data, 1);
+    JavaFutureAction<List<Integer>> future = rdd.collectAsync();
+    List<Integer> result = future.get();
+    Assert.assertEquals(data, result);
+    Assert.assertFalse(future.isCancelled());
+    Assert.assertTrue(future.isDone());
+    Assert.assertEquals(1, future.jobIds().size());
+  }
+
+  @Test
+  public void foreachAsync() throws Exception {
+    List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
+    JavaRDD<Integer> rdd = sc.parallelize(data, 1);
+    JavaFutureAction<Void> future = rdd.foreachAsync(
+      new VoidFunction<Integer>() {
+        @Override
+        public void call(Integer integer) throws Exception {
+          // intentionally left blank.
+        }
+      }
+    );
+    future.get();
+    Assert.assertFalse(future.isCancelled());
+    Assert.assertTrue(future.isDone());
+    Assert.assertEquals(1, future.jobIds().size());
+  }
+
+  @Test
+  public void countAsync() throws Exception {
+    List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
+    JavaRDD<Integer> rdd = sc.parallelize(data, 1);
+    JavaFutureAction<Long> future = rdd.countAsync();
+    long count = future.get();
+    Assert.assertEquals(data.size(), count);
+    Assert.assertFalse(future.isCancelled());
+    Assert.assertTrue(future.isDone());
+    Assert.assertEquals(1, future.jobIds().size());
+  }
+
+  @Test
+  public void testAsyncActionCancellation() throws Exception {
+    List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
+    JavaRDD<Integer> rdd = sc.parallelize(data, 1);
+    JavaFutureAction<Void> future = rdd.foreachAsync(new VoidFunction<Integer>() {
+      @Override
+      public void call(Integer integer) throws Exception {
+        Thread.sleep(10000);  // To ensure that the job won't finish before it's cancelled.
+      }
+    });
+    future.cancel(true);
+    Assert.assertTrue(future.isCancelled());
+    Assert.assertTrue(future.isDone());
+    try {
+      future.get(2000, TimeUnit.MILLISECONDS);
+      Assert.fail("Expected future.get() for cancelled job to throw CancellationException");
+    } catch (CancellationException ignored) {
+      // pass
+    }
+  }
+
+  @Test
+  public void testAsyncActionErrorWrapping() throws Exception {
+    List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
+    JavaRDD<Integer> rdd = sc.parallelize(data, 1);
+    JavaFutureAction<Long> future = rdd.map(new BuggyMapFunction<Integer>()).countAsync();
+    try {
+      future.get(2, TimeUnit.SECONDS);
+      Assert.fail("Expected future.get() for failed job to throw ExecutionException");
+    } catch (ExecutionException ee) {
+      Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
+    }
+    Assert.assertTrue(future.isDone());
+  }
+
+
   /**
    * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue,
    * since that's the only artifact where Guava classes have been relocated.
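The Java tests above drive the new `JavaFutureAction` wrappers; the underlying Scala API behaves the same way. A hedged sketch using the Scala-side async actions (local mode and the timeout are illustrative):

```scala
import scala.concurrent.Await
import scala.concurrent.duration._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // brings the async actions into scope

object AsyncActions {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("async").setMaster("local[2]"))
    try {
      val future = sc.parallelize(1 to 5, 1).countAsync() // a FutureAction[Long]
      // FutureAction extends scala.concurrent.Future, so blocking and combinators both
      // work; cancel(), isCancelled, and jobIds come from this patch.
      val count = Await.result(future, 30.seconds)
      println(s"count = $count, job ids = ${future.jobIds}")
    } finally {
      sc.stop()
    }
  }
}
```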
@@ -1333,4 +1418,16 @@ public Optional<Integer> call(Integer i) {
     }
   }
 
+  static class Class1 {}
+  static class Class2 {}
+
+  @Test
+  public void testRegisterKryoClasses() {
+    SparkConf conf = new SparkConf();
+    conf.registerKryoClasses(new Class<?>[]{ Class1.class, Class2.class });
+    Assert.assertEquals(
+      Class1.class.getName() + "," + Class2.class.getName(),
+      conf.get("spark.kryo.classesToRegister"));
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 87e9012622456..5d018ea9868a7 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark
 
 import org.scalatest.FunSuite
+import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer}
+import com.esotericsoftware.kryo.Kryo
 
 class SparkConfSuite extends FunSuite with LocalSparkContext {
   test("loading from system properties") {
@@ -133,4 +135,64 @@ class SparkConfSuite extends FunSuite with LocalSparkContext {
       System.clearProperty("spark.test.a.b.c")
     }
   }
+
+  test("register kryo classes through registerKryoClasses") {
+    val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+
+    conf.registerKryoClasses(Array(classOf[Class1], classOf[Class2]))
+    assert(conf.get("spark.kryo.classesToRegister") ===
+      classOf[Class1].getName + "," + classOf[Class2].getName)
+
+    conf.registerKryoClasses(Array(classOf[Class3]))
+    assert(conf.get("spark.kryo.classesToRegister") ===
+      classOf[Class1].getName + "," + classOf[Class2].getName + "," + classOf[Class3].getName)
+
+    conf.registerKryoClasses(Array(classOf[Class2]))
+    assert(conf.get("spark.kryo.classesToRegister") ===
+      classOf[Class1].getName + "," + classOf[Class2].getName + "," + classOf[Class3].getName)
+
+    // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+    // blow up.
+    val serializer = new KryoSerializer(conf)
+    serializer.newInstance().serialize(new Class1())
+    serializer.newInstance().serialize(new Class2())
+    serializer.newInstance().serialize(new Class3())
+  }
+
+  test("register kryo classes through registerKryoClasses and custom registrator") {
+    val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+
+    conf.registerKryoClasses(Array(classOf[Class1]))
+    assert(conf.get("spark.kryo.classesToRegister") === classOf[Class1].getName)
+
+    conf.set("spark.kryo.registrator", classOf[CustomRegistrator].getName)
+
+    // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+    // blow up.
+    val serializer = new KryoSerializer(conf)
+    serializer.newInstance().serialize(new Class1())
+    serializer.newInstance().serialize(new Class2())
+  }
+
+  test("register kryo classes through conf") {
+    val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+    conf.set("spark.kryo.classesToRegister", "java.lang.StringBuffer")
+    conf.set("spark.serializer", classOf[KryoSerializer].getName)
+
+    // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+    // blow up.
+ val serializer = new KryoSerializer(conf) + serializer.newInstance().serialize(new StringBuffer()) + } + +} + +class Class1 {} +class Class2 {} +class Class3 {} + +class CustomRegistrator extends KryoRegistrator { + def registerClasses(kryo: Kryo) { + kryo.register(classOf[Class2]) + } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index acaf321de52fb..e096c8c3e9b46 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.broadcast +import scala.util.Random + import org.scalatest.FunSuite import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} +import org.apache.spark.io.SnappyCompressionCodec +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ - class BroadcastSuite extends FunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") @@ -84,6 +87,24 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } + test("TorrentBroadcast's blockifyObject and unblockifyObject are inverses") { + import org.apache.spark.broadcast.TorrentBroadcast._ + val blockSize = 1024 + val conf = new SparkConf() + val compressionCodec = Some(new SnappyCompressionCodec(conf)) + val serializer = new JavaSerializer(conf) + val seed = 42 + val rand = new Random(seed) + for (trial <- 1 to 100) { + val size = 1 + rand.nextInt(1024 * 10) + val data: Array[Byte] = new Array[Byte](size) + rand.nextBytes(data) + val blocks = blockifyObject(data, blockSize, serializer, compressionCodec) + val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) + assert(unblockified === data) + } + } + test("Unpersisting HttpBroadcast on executors only in local mode") { testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) } @@ -193,26 +214,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { blockId = BroadcastBlockId(broadcastId, "piece0") statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === (if (distributed) 1 else 0)) + assert(statuses.size === 1) } // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { var blockId = BroadcastBlockId(broadcastId) - var statuses = bmm.getBlockStatus(blockId, askSlaves = true) - if (distributed) { - assert(statuses.size === numSlaves + 1) - } else { - assert(statuses.size === 1) - } + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === numSlaves + 1) blockId = BroadcastBlockId(broadcastId, "piece0") - statuses = bmm.getBlockStatus(blockId, askSlaves = true) - if (distributed) { - assert(statuses.size === numSlaves + 1) - } else { - assert(statuses.size === 0) - } + assert(statuses.size === numSlaves + 1) } // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver @@ -224,7 +236,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(statuses.size === expectedNumBlocks) blockId = BroadcastBlockId(broadcastId, "piece0") - expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1 + expectedNumBlocks = if (removeFromDriver) 0 else 1 statuses = bmm.getBlockStatus(blockId, askSlaves = true) 
assert(statuses.size === expectedNumBlocks) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index e1e35b688d581..64ac6d2d920d2 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -210,13 +210,13 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("kryo with nonexistent custom registrator should fail") { - import org.apache.spark.{SparkConf, SparkException} + import org.apache.spark.SparkException val conf = new SparkConf(false) conf.set("spark.kryo.registrator", "this.class.does.not.exist") - + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) - assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) + assert(thrown.getMessage.contains("Failed to register classes with Kryo")) } test("default class loader can be set by a different thread") { diff --git a/docs/README.md b/docs/README.md index 0facecdd5f767..d2d58e435d4c4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -25,8 +25,7 @@ installing via the Ruby Gem dependency manager. Since the exact HTML output varies between versions of Jekyll and its dependencies, we list specific versions here in some cases: - $ sudo gem install jekyll -v 1.4.3 - $ sudo gem uninstall kramdown -v 1.4.1 + $ sudo gem install jekyll $ sudo gem install jekyll-redirect-from Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory diff --git a/docs/configuration.md b/docs/configuration.md index f0204c640bc89..66738d3ca754e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -124,12 +124,23 @@ of the most common options to set are: org.apache.spark.Serializer. + + spark.kryo.classesToRegister + (none) + + If you use Kryo serialization, give a comma-separated list of custom class names to register + with Kryo. + See the tuning guide for more details. + + spark.kryo.registrator (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. - It should be set to a class that extends + If you use Kryo serialization, set this class to register your custom classes with Kryo. This + property is useful if you need to register your classes in a custom way, e.g. to specify a custom + field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be + set to a class that extends KryoRegistrator. See the tuning guide for more details. @@ -619,6 +630,15 @@ Apart from these, the following properties are also available, and may be useful output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + + spark.hadoop.cloneConf + false + If set to true, clones a new Hadoop Configuration object for each task. This + option should be enabled to work around Configuration thread-safety issues (see + SPARK-2546 for more details). + This is disabled by default in order to avoid unexpected performance regressions for jobs that + are not affected by these issues. 
+ spark.executor.heartbeatInterval 10000 diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 738309c668387..8bbba88b31978 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -212,6 +212,67 @@ The complete code can be found in the Spark Streaming example [JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
+ +
+First, we import StreamingContext, which is the main entry point for all streaming functionality. We create a local StreamingContext with two execution threads and a batch interval of 1 second.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+# Create a local StreamingContext with two working threads and a batch interval of 1 second
+sc = SparkContext("local[2]", "NetworkWordCount")
+ssc = StreamingContext(sc, 1)
+{% endhighlight %}
+
+Using this context, we can create a DStream that represents streaming data from a TCP
+source hostname, e.g. `localhost`, and port, e.g. `9999`.
+
+{% highlight python %}
+# Create a DStream that will connect to hostname:port, like localhost:9999
+lines = ssc.socketTextStream("localhost", 9999)
+{% endhighlight %}
+
+This `lines` DStream represents the stream of data that will be received from the data
+server. Each record in this DStream is a line of text. Next, we want to split the lines by
+space into words.
+
+{% highlight python %}
+# Split each line into words
+words = lines.flatMap(lambda line: line.split(" "))
+{% endhighlight %}
+
+`flatMap` is a one-to-many DStream operation that creates a new DStream by
+generating multiple new records from each record in the source DStream. In this case,
+each line will be split into multiple words and the stream of words is represented as the
+`words` DStream. Next, we want to count these words.
+
+{% highlight python %}
+# Count each word in each batch
+pairs = words.map(lambda word: (word, 1))
+wordCounts = pairs.reduceByKey(lambda x, y: x + y)
+
+# Print the first ten elements of each RDD generated in this DStream to the console
+wordCounts.pprint()
+{% endhighlight %}
+
+The `words` DStream is further mapped (one-to-one transformation) to a DStream of
+`(word, 1)` pairs, which is then reduced to get the frequency of words in each batch of data.
+Finally, `wordCounts.pprint()` will print a few of the counts generated every second.
+
+Note that when these lines are executed, Spark Streaming only sets up the computation it
+will perform when it is started, and no real processing has started yet. To start the processing
+after all the transformations have been set up, we finally call
+
+{% highlight python %}
+ssc.start()             # Start the computation
+ssc.awaitTermination()  # Wait for the computation to terminate
+{% endhighlight %}
+
+The complete code can be found in the Spark Streaming example
+[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py).
+
@@ -236,6 +297,11 @@ $ ./bin/run-example streaming.NetworkWordCount localhost 9999 $ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 {% endhighlight %} +
+{% highlight bash %} +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +{% endhighlight %} +
@@ -259,8 +325,11 @@ hello world +
+ +
{% highlight bash %} -# TERMINAL 2: RUNNING NetworkWordCount or JavaNetworkWordCount +# TERMINAL 2: RUNNING NetworkWordCount $ ./bin/run-example streaming.NetworkWordCount localhost 9999 ... @@ -271,6 +340,37 @@ Time: 1357008430000 ms (world,1) ... {% endhighlight %} +
+ +
+{% highlight bash %} +# TERMINAL 2: RUNNING JavaNetworkWordCount + +$ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 +... +------------------------------------------- +Time: 1357008430000 ms +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
+
+{% highlight bash %} +# TERMINAL 2: RUNNING network_wordcount.py + +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +... +------------------------------------------- +Time: 2014-10-14 15:25:21 +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
+
@@ -398,9 +498,34 @@ JavaSparkContext sc = ... //existing JavaSparkContext JavaStreamingContext ssc = new JavaStreamingContext(sc, new Duration(1000)); {% endhighlight %} +
+ +A [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) object can be created from a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +sc = SparkContext(master, appName) +ssc = StreamingContext(sc, 1) +{% endhighlight %} + +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming +in-process (detects the number of cores in the local system). + +The batch interval must be set based on the latency requirements of your application +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +section for more details. +
After a context is defined, you have to do the following steps.
+
1. Define the input sources.
1. Set up the streaming computations.
1. Start the receiving and processing of data using `streamingContext.start()`.
@@ -483,6 +608,9 @@ methods for creating DStreams from files and Akka actors as input sources.
streamingContext.fileStream(dataDirectory);
+
+ streamingContext.textFileStream(dataDirectory) +
Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that @@ -684,13 +812,30 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} + +
+ +{% highlight python %} +def updateFunction(newValues, runningCount): + if runningCount is None: + runningCount = 0 + return sum(newValues, runningCount) # add the new values with the previous running count to get the new count +{% endhighlight %} + +This is applied on a DStream containing words (say, the `pairs` DStream containing `(word, +1)` pairs in the [earlier example](#a-quick-example)). + +{% highlight python %} +runningCounts = pairs.updateStateByKey(updateFunction) +{% endhighlight %} +
The update function will be called for each word, with `newValues` having a sequence of 1's (from
the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
Python code, take a look at the example
-[StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala).
+[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py).

#### Transform Operation
{:.no_toc}
@@ -732,6 +877,15 @@ JavaPairDStream cleanedDStream = wordCounts.transform(
 });
{% endhighlight %}
+
+
+ +{% highlight python %} +spamInfoRDD = sc.pickleFile(...) # RDD containing spam information + +# join data stream with spam information to do data cleaning +cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(...)) +{% endhighlight %}
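The `...` placeholders in the transform example above are left open by the guide. One hypothetical way to complete the join-and-filter step, sketched in Scala under the assumption that `spamInfoRDD` holds `(word, isSpam)` pairs and `wordCounts` is a DStream of `(word, count)` pairs:

{% highlight scala %}
// Hypothetical completion of the example above, not the guide's own code.
// Assumes spamInfoRDD: RDD[(String, Boolean)] and wordCounts: DStream[(String, Int)].
val cleanedDStream = wordCounts.transform { rdd =>
  rdd.join(spamInfoRDD)                            // (word, (count, isSpam))
    .filter { case (_, (_, isSpam)) => !isSpam }   // drop words flagged as spam
    .mapValues { case (count, _) => count }        // back to (word, count)
}
{% endhighlight %}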
@@ -793,6 +947,14 @@ Function2 reduceFunc = new Function2 windowedWordCounts = pairs.reduceByKeyAndWindow(reduceFunc, new Duration(30000), new Duration(10000)); {% endhighlight %} + +
+ +{% highlight python %} +# Reduce last 30 seconds of data, every 10 seconds +windowedWordCounts = pairs.reduceByKeyAndWindow(lambda x, y: x + y, lambda x, y: x - y, 30, 10) +{% endhighlight %} +
@@ -860,6 +1022,7 @@ see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). +For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) *** @@ -872,9 +1035,12 @@ Currently, the following output operations are defined: - + + This is useful for development and debugging. +
+ Note: this is called pprint() in the Python API.
+
@@ -915,17 +1081,41 @@ For this purpose, a developer may inadvertently try creating a connection object at the Spark
driver, but try to use it in a Spark worker to save records in the RDDs. For example (in
Scala),
+
+ +{% highlight scala %} dstream.foreachRDD(rdd => { val connection = createNewConnection() // executed at the driver rdd.foreach(record => { connection.send(record) // executed at the worker }) }) +{% endhighlight %} + +
+
+ +{% highlight python %} +def sendRecord(rdd): + connection = createNewConnection() # executed at the driver + rdd.foreach(lambda record: connection.send(record)) + connection.close() + +dstream.foreachRDD(sendRecord) +{% endhighlight %} + +
+
- This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker.
+ This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker.

- However, this can lead to another common mistake - creating a new connection for every record. For example,
+
+
+ +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreach(record => { val connection = createNewConnection() @@ -933,9 +1123,28 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} - Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. +
+
+ +{% highlight python %} +def sendRecord(record): + connection = createNewConnection() + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) +{% endhighlight %} +
+
+
+ Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in an RDD partition using that connection.
+
+
+
+{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { val connection = createNewConnection() @@ -943,13 +1152,31 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} +
+ +
+{% highlight python %} +def sendPartition(iter): + connection = createNewConnection() + for record in iter: + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
+
- This amortizes the connection creation overheads over many records.
+ This amortizes the connection creation overheads over many records.

- Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches. One can maintain a static pool of connection objects that can be reused as RDDs of multiple batches are pushed to the external system, thus further reducing the overheads (a minimal sketch of such a pool follows the examples below).
-
+
+
+{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { // ConnectionPool is a static, lazily initialized pool of connections @@ -958,8 +1185,25 @@ For example (in Scala), ConnectionPool.returnConnection(connection) // return to the pool for future reuse }) }) +{% endhighlight %} +
- Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. +
+{% highlight python %} +def sendPartition(iter): + # ConnectionPool is a static, lazily initialized pool of connections + connection = ConnectionPool.getConnection() + for record in iter: + connection.send(record) + # return to the pool for future reuse + ConnectionPool.returnConnection(connection) + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
+
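The `ConnectionPool` called in the snippets above is left abstract by the guide. A minimal Scala sketch of what such a pool might look like (hypothetical: `Connection` stands in for the real client type, `createNewConnection()` must be wired to an actual library, and a production pool would also bound its size and expire idle connections):

{% highlight scala %}
import java.util.concurrent.ConcurrentLinkedQueue

// Hypothetical sketch only; `Connection` is a stand-in for a real client type.
trait Connection {
  def send(record: String): Unit
  def close(): Unit
}

object ConnectionPool {
  private val pool = new ConcurrentLinkedQueue[Connection]()

  // Wire this to the actual client library in a real application.
  private def createNewConnection(): Connection =
    sys.error("replace with a real connection factory")

  // Connections are created lazily, only when the pool has none to hand out.
  def getConnection(): Connection =
    Option(pool.poll()).getOrElse(createNewConnection())

  // Return a connection for reuse by later batches instead of closing it.
  def returnConnection(conn: Connection): Unit = {
    pool.offer(conn)
  }
}
{% endhighlight %}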
+ +Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. ##### Other points to remember: @@ -1376,6 +1620,44 @@ You can also explicitly create a `JavaStreamingContext` from the checkpoint data the computation by using `new JavaStreamingContext(checkpointDirectory)`. +
+
+This behavior is made simple by using `StreamingContext.getOrCreate`. This is used as follows.
+
+{% highlight python %}
+# Function to create and set up a new StreamingContext
+def functionToCreateContext():
+    sc = SparkContext(...)  # new context
+    ssc = StreamingContext(...)
+    lines = ssc.socketTextStream(...)  # create DStreams
+    ...
+    ssc.checkpoint(checkpointDirectory)  # set checkpoint directory
+    return ssc
+
+# Get StreamingContext from checkpoint data or create a new one
+context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext)
+
+# Do additional setup on context that needs to be done,
+# irrespective of whether it is being started or restarted
+context. ...
+
+# Start the context
+context.start()
+context.awaitTermination()
+{% endhighlight %}
+
+If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data.
+If the directory does not exist (i.e., running for the first time),
+then the function `functionToCreateContext` will be called to create a new
+context and set up the DStreams. See the Python example
+[recoverable_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming/recoverable_network_wordcount.py).
+This example appends the word counts of network data into a file.
+
+You can also explicitly create a `StreamingContext` from the checkpoint data and start the
+ computation by using `StreamingContext.getOrCreate(checkpointDirectory, None)`.
+
+
**Note**: If Spark Streaming and/or the Spark Streaming program is recompiled,
@@ -1572,7 +1854,11 @@ package and renamed for better clarity.
  [TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html),
  [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and
  [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html)
+  - Python docs
+    * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext)
+    * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream)
* More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming)
  and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming)
+  and [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming)
* [Paper](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) and [video](http://youtu.be/g171ndOHgJ0) describing Spark Streaming.
diff --git a/docs/tuning.md b/docs/tuning.md
index 8fb2a0433b1a8..9b5c9adac6a4f 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -47,24 +47,11 @@ registration requirement, but we recommend trying it in any network-intensive application.
Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered
in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library.

-To register your own custom classes with Kryo, create a public class that extends
-[`org.apache.spark.serializer.KryoRegistrator`](api/scala/index.html#org.apache.spark.serializer.KryoRegistrator) and set the
-`spark.kryo.registrator` config property to point to it, as follows:
+To register your own custom classes with Kryo, use the `registerKryoClasses` method.

{% highlight scala %}
-import com.esotericsoftware.kryo.Kryo
-import org.apache.spark.serializer.KryoRegistrator
-
-class MyRegistrator extends KryoRegistrator {
-  override def registerClasses(kryo: Kryo) {
-    kryo.register(classOf[MyClass1])
-    kryo.register(classOf[MyClass2])
-  }
-}
-
 val conf = new SparkConf().setMaster(...).setAppName(...)
-conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-conf.set("spark.kryo.registrator", "mypackage.MyRegistrator")
+conf.registerKryoClasses(Array(classOf[MyClass1], classOf[MyClass2]))
 val sc = new SparkContext(conf)
{% endhighlight %}
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
index 5622df5ce03ff..981bc4f0613a9 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
@@ -57,7 +57,7 @@ public class JavaCustomReceiver extends Receiver {
   public static void main(String[] args) {
     if (args.length < 2) {
-      System.err.println("Usage: JavaNetworkWordCount <hostname> <port>");
+      System.err.println("Usage: JavaCustomReceiver <hostname> <port>");
       System.exit(1);
     }
diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py
new file mode 100644
index 0000000000000..fc6827c82bf9b
--- /dev/null
+++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.
See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in text encoded with UTF8 received from the network every second.
+
+ Usage: recoverable_network_wordcount.py <hostname> <port> <checkpoint-directory> <output-file>
+   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
+   data. <checkpoint-directory> is a directory on an HDFS-compatible file system to which the
+   checkpoint data will be saved, and <output-file> is the file to which the word counts will
+   be appended.
+
+ To run this on your local machine, you need to first run a Netcat server
+    `$ nc -lk 9999`
+
+ and then run the example
+    `$ bin/spark-submit examples/src/main/python/streaming/recoverable_network_wordcount.py \
+       localhost 9999 ~/checkpoint/ ~/out`
+
+ If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create
+ a new StreamingContext (will print "Creating new context" to the console). Otherwise, if
+ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
+ the checkpoint data.
+"""
+
+import os
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+
+def createContext(host, port, outputPath):
+    # If you do not see this printed, that means the StreamingContext has been loaded
+    # from the new checkpoint
+    print "Creating new context"
+    if os.path.exists(outputPath):
+        os.remove(outputPath)
+    sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount")
+    ssc = StreamingContext(sc, 1)
+
+    # Create a socket stream on target ip:port and count the
+    # words in input stream of \n delimited text (e.g.
generated by 'nc')
+    lines = ssc.socketTextStream(host, port)
+    words = lines.flatMap(lambda line: line.split(" "))
+    wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
+
+    def echo(time, rdd):
+        counts = "Counts at time %s %s" % (time, rdd.collect())
+        print counts
+        print "Appending to " + os.path.abspath(outputPath)
+        with open(outputPath, 'a') as f:
+            f.write(counts + "\n")
+
+    wordCounts.foreachRDD(echo)
+    return ssc
+
+if __name__ == "__main__":
+    if len(sys.argv) != 5:
+        print >> sys.stderr, "Usage: recoverable_network_wordcount.py <hostname> <port> "\
+                             "<checkpoint-directory> <output-file>"
+        exit(-1)
+    host, port, checkpoint, output = sys.argv[1:]
+    ssc = StreamingContext.getOrCreate(checkpoint,
+                                       lambda: createContext(host, int(port), output))
+    ssc.start()
+    ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
index e06f4dcd54442..e322d4ce5a745 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
@@ -18,17 +18,7 @@ package org.apache.spark.examples.bagel
 import org.apache.spark._
-import org.apache.spark.SparkContext._
-import org.apache.spark.serializer.KryoRegistrator
-
 import org.apache.spark.bagel._
-import org.apache.spark.bagel.Bagel._
-
-import scala.collection.mutable.ArrayBuffer
-
-import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
-
-import com.esotericsoftware.kryo._
 class PageRankUtils extends Serializable {
   def computeWithCombiner(numVertices: Long, epsilon: Double)(
@@ -99,13 +89,6 @@ class PRMessage() extends Message[String] with Serializable {
   }
 }
-class PRKryoRegistrator extends KryoRegistrator {
-  def registerClasses(kryo: Kryo) {
-    kryo.register(classOf[PRVertex])
-    kryo.register(classOf[PRMessage])
-  }
-}
-
 class CustomPartitioner(partitions: Int) extends Partitioner {
   def numPartitions = partitions
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
index e4db3ec51313d..859abedf2a55e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
@@ -38,8 +38,7 @@ object WikipediaPageRank {
     }
     val sparkConf = new SparkConf()
     sparkConf.setAppName("WikipediaPageRank")
-    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-    sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+    sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage]))
     val inputFile = args(0)
     val threshold = args(1).toDouble
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index 45527d9382fd0..d70d93608a57c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -46,10 +46,8 @@ object Analytics extends Logging {
     }
     val options = mutable.Map(optionsList: _*)
-    val conf = new SparkConf()
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
-      .set("spark.locality.wait", "100000")
+    val conf = new
SparkConf().set("spark.locality.wait", "100000") + GraphXUtils.registerKryoClasses(conf) val numEPart = options.remove("numEPart").map(_.toInt).getOrElse { println("Set the number of edge partitions using --numEPart.") diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 5f35a5836462e..05676021718d9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ -import org.apache.spark.graphx.PartitionStrategy +import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy} import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.graphx.util.GraphGenerators import java.io.{PrintWriter, FileOutputStream} @@ -80,8 +80,7 @@ object SynthBenchmark { val conf = new SparkConf() .setAppName(s"GraphX Synth Benchmark (nverts = $numVertices, app = $app)") - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") + GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index fc6678013b932..8796c28db8a66 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples.mllib import scala.collection.mutable -import com.esotericsoftware.kryo.Kryo import org.apache.log4j.{Level, Logger} import scopt.OptionParser @@ -27,7 +26,6 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating} import org.apache.spark.rdd.RDD -import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator} /** * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). 
@@ -40,13 +38,6 @@ import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator} */ object MovieLensALS { - class ALSRegistrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[Rating]) - kryo.register(classOf[mutable.BitSet]) - } - } - case class Params( input: String = null, kryo: Boolean = false, @@ -108,8 +99,7 @@ object MovieLensALS { def run(params: Params) { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { - conf.set("spark.serializer", classOf[KryoSerializer].getName) - .set("spark.kryo.registrator", classOf[ALSRegistrator].getName) + conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) .set("spark.kryoserializer.buffer.mb", "8") } val sc = new SparkContext(conf) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala index 1948c978c30bf..563c948957ecf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -27,10 +27,10 @@ import org.apache.spark.graphx.impl._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.OpenHashSet - /** * Registers GraphX classes with Kryo for improved performance. */ +@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0") class GraphKryoRegistrator extends KryoRegistrator { def registerClasses(kryo: Kryo) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala new file mode 100644 index 0000000000000..2cb07937eaa2a --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +import org.apache.spark.SparkConf + +import org.apache.spark.graphx.impl._ +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap + +import org.apache.spark.util.collection.{OpenHashSet, BitSet} +import org.apache.spark.util.BoundedPriorityQueue + +object GraphXUtils { + /** + * Registers classes that GraphX uses with Kryo. 
+ */ + def registerKryoClasses(conf: SparkConf) { + conf.registerKryoClasses(Array( + classOf[Edge[Object]], + classOf[(VertexId, Object)], + classOf[EdgePartition[Object, Object]], + classOf[BitSet], + classOf[VertexIdToIndexMap], + classOf[VertexAttributeBlock[Object]], + classOf[PartitionStrategy], + classOf[BoundedPriorityQueue[Object]], + classOf[EdgeDirection], + classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]], + classOf[OpenHashSet[Int]], + classOf[OpenHashSet[Long]])) + } +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index 47594a800a3b1..a3e28efc75a98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -17,9 +17,6 @@ package org.apache.spark.graphx -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterEach - import org.apache.spark.SparkConf import org.apache.spark.SparkContext @@ -31,8 +28,7 @@ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */ def withSpark[T](f: SparkContext => T) = { val conf = new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") + GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) try { f(sc) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 9d00f76327e4c..db1dac6160080 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -129,9 +129,9 @@ class EdgePartitionSuite extends FunSuite { val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) val a: EdgePartition[Int, Int] = makeEdgePartition(aList) val javaSer = new JavaSerializer(new SparkConf()) - val kryoSer = new KryoSerializer(new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")) + val conf = new SparkConf() + GraphXUtils.registerKryoClasses(conf) + val kryoSer = new KryoSerializer(conf) for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index f9e771a900013..fe8304c1cdc32 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -125,9 +125,9 @@ class VertexPartitionSuite extends FunSuite { val verts = Set((0L, 1), (1L, 1), (2L, 1)) val vp = VertexPartition(verts.iterator) val javaSer = new JavaSerializer(new SparkConf()) - val kryoSer = new KryoSerializer(new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")) + val conf = new SparkConf() + GraphXUtils.registerKryoClasses(conf) + val kryoSer = new KryoSerializer(conf) for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val vpSer: VertexPartition[Int] = 
s.deserialize(s.serialize(vp))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9a100170b75c6..b478c21537c2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -673,6 +673,11 @@ private[spark] object SerDe extends Serializable {
     rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
   }
+  /* convert RDD[Tuple2[Any, Any]] to RDD[Array[Any]] */
+  def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = {
+    rdd.map(x => Array(x._1, x._2))
+  }
+
   /**
    * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
    * PySpark.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000000..93a7353e2c070
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+
+/**
+ * ::Experimental::
+ * Evaluator for ranking algorithms.
+ *
+ * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
+ */
+@Experimental
+class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
+  extends Logging with Serializable {
+
+  /**
+   * Compute the average precision of all the queries, truncated at ranking position k.
+   *
+   * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
+   * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
+   * ground truth set is less than k.
+   *
+   * If a query has an empty ground truth set, zero will be used as precision together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J.
Kekalainen
+   *
+   * @param k the position to compute the truncated precision, must be positive
+   * @return the average precision at the first k ranking positions
+   */
+  def precisionAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      val labSet = lab.toSet
+
+      if (labSet.nonEmpty) {
+        val n = math.min(pred.length, k)
+        var i = 0
+        var cnt = 0
+        while (i < n) {
+          if (labSet.contains(pred(i))) {
+            cnt += 1
+          }
+          i += 1
+        }
+        cnt.toDouble / k
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }.mean
+  }
+
+  /**
+   * Returns the mean average precision (MAP) of all the queries.
+   * If a query has an empty ground truth set, the average precision will be zero and a log
+   * warning is generated.
+   */
+  lazy val meanAveragePrecision: Double = {
+    predictionAndLabels.map { case (pred, lab) =>
+      val labSet = lab.toSet
+
+      if (labSet.nonEmpty) {
+        var i = 0
+        var cnt = 0
+        var precSum = 0.0
+        val n = pred.length
+        while (i < n) {
+          if (labSet.contains(pred(i))) {
+            cnt += 1
+            precSum += cnt.toDouble / (i + 1)
+          }
+          i += 1
+        }
+        precSum / labSet.size
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }.mean
+  }
+
+  /**
+   * Compute the average NDCG value of all the queries, truncated at ranking position k.
+   * The discounted cumulative gain at position k is computed as:
+   *    sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+   * and the NDCG is obtained by dividing the DCG value by the maximum attainable DCG on the
+   * ground truth set. In the current implementation, the relevance value is binary.
+   *
+   * If a query has an empty ground truth set, zero will be used as ndcg together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated ndcg, must be positive
+   * @return the average ndcg at the first k ranking positions
+   */
+  def ndcgAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      val labSet = lab.toSet
+
+      if (labSet.nonEmpty) {
+        val labSetSize = labSet.size
+        val n = math.min(math.max(pred.length, labSetSize), k)
+        var maxDcg = 0.0
+        var dcg = 0.0
+        var i = 0
+        while (i < n) {
+          val gain = 1.0 / math.log(i + 2)
+          if (labSet.contains(pred(i))) {
+            dcg += gain
+          }
+          if (i < labSetSize) {
+            maxDcg += gain
+          }
+          i += 1
+        }
+        dcg / maxDcg
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }.mean
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 03eeaa707715b..6737a2f4176c2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
@@ -909,32 +911,39 @@ object DecisionTree extends Serializable with Logging {
     // Iterate over all features.
var featureIndex = 0 while (featureIndex < numFeatures) { - val numSplits = metadata.numSplits(featureIndex) - val numBins = metadata.numBins(featureIndex) if (metadata.isContinuous(featureIndex)) { - val numSamples = sampledInput.length + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)) + val featureSplits = findSplitsForContinuousFeature(featureSamples, + metadata, featureIndex) + + val numSplits = featureSplits.length + val numBins = numSplits + 1 + logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits") splits(featureIndex) = new Array[Split](numSplits) bins(featureIndex) = new Array[Bin](numBins) - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) - logDebug("stride = " + stride) - for (splitIndex <- 0 until numSplits) { - val sampleIndex = splitIndex * stride.toInt - // Set threshold halfway in between 2 samples. - val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + + var splitIndex = 0 + while (splitIndex < numSplits) { + val threshold = featureSplits(splitIndex) splits(featureIndex)(splitIndex) = new Split(featureIndex, threshold, Continuous, List()) + splitIndex += 1 } bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) - for (splitIndex <- 1 until numSplits) { + + splitIndex = 1 + while (splitIndex < numSplits) { bins(featureIndex)(splitIndex) = new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), Continuous, Double.MinValue) + splitIndex += 1 } bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } else { + val numSplits = metadata.numSplits(featureIndex) + val numBins = metadata.numBins(featureIndex) // Categorical feature val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { @@ -1011,4 +1020,77 @@ object DecisionTree extends Serializable with Logging { categories } + /** + * Find splits for a continuous feature + * NOTE: Returned number of splits is set based on `featureSamples` and + * could be different from the specified `numSplits`. + * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. 
+   * @param featureSamples feature values of each sample
+   * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numBins` will be changed accordingly
+   *                 if there are not enough splits to be found
+   * @param featureIndex feature index to find splits
+   * @return array of splits
+   */
+  private[tree] def findSplitsForContinuousFeature(
+      featureSamples: Array[Double],
+      metadata: DecisionTreeMetadata,
+      featureIndex: Int): Array[Double] = {
+    require(metadata.isContinuous(featureIndex),
+      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+    val splits = {
+      val numSplits = metadata.numSplits(featureIndex)
+
+      // get count for each distinct value
+      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+        m + ((x, m.getOrElse(x, 0) + 1))
+      }
+      // sort distinct values
+      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+      // if the number of possible splits is not greater than numSplits, return all of them
+      val possibleSplits = valueCounts.length
+      if (possibleSplits <= numSplits) {
+        valueCounts.map(_._1)
+      } else {
+        // stride between splits
+        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        logDebug("stride = " + stride)
+
+        // iterate `valueCount` to find splits
+        val splits = new ArrayBuffer[Double]
+        var index = 1
+        // currentCount: sum of counts of values that have been visited
+        var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        // If `currentCount` is the closest value to `targetCount`,
+        // then the current value is a split threshold.
+        // After finding a split threshold, `targetCount` is increased by stride.
+        var targetCount = stride
+        while (index < valueCounts.length) {
+          val previousCount = currentCount
+          currentCount += valueCounts(index)._2
+          val previousGap = math.abs(previousCount - targetCount)
+          val currentGap = math.abs(currentCount - targetCount)
+          // If adding the count of the current value to currentCount
+          // makes the gap between currentCount and targetCount smaller,
+          // the previous value is a split threshold.
+          if (previousGap < currentGap) {
+            splits.append(valueCounts(index - 1)._1)
+            targetCount += stride
+          }
+          index += 1
+        }
+
+        splits.toArray
+      }
+    }
+
+    assert(splits.length > 0)
+    // set number of splits accordingly
+    metadata.setNumSplits(featureIndex, splits.length)
+
+    splits
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 772c02670e541..5bc0f2635c6b1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -76,6 +76,17 @@ private[tree] class DecisionTreeMetadata(
     numBins(featureIndex) - 1
   }
+
+  /**
+   * Set number of splits for a continuous feature.
+   * For a continuous feature, number of bins is number of splits plus 1.
+   */
+  def setNumSplits(featureIndex: Int, numSplits: Int) {
+    require(isContinuous(featureIndex),
+      "Only the number of bins for a continuous feature can be set.")
+    numBins(featureIndex) = numSplits + 1
+  }
+
   /**
    * Indicates if feature subsampling is being used.
*/ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala new file mode 100644 index 0000000000000..a2d4bb41484b8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.LocalSparkContext + +class RankingMetricsSuite extends FunSuite with LocalSparkContext { + test("Ranking metrics: map, ndcg") { + val predictionAndLabels = sc.parallelize( + Seq( + (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)), + (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)), + (Array[Int](1, 2, 3, 4, 5), Array[Int]()) + ), 2) + val eps: Double = 1E-5 + + val metrics = new RankingMetrics(predictionAndLabels) + val map = metrics.meanAveragePrecision + + assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps) + assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps) + assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps) + assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps) + + assert(map ~== 0.355026 absTol eps) + + assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps) + assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps) + assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps) + assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps) + + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 98a72b0c4d750..8fc5e111bbc17 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} @@ -102,6 +102,72 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { 
assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } + test("find splits for a continuous feature") { + // find splits for normal case + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(6), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array.fill(200000)(math.random) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 5) + assert(fakeMetadata.numSplits(0) === 5) + assert(fakeMetadata.numBins(0) === 6) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits should not return identical splits + // when there are not enough split candidates, reduce the number of splits in metadata + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(5), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 3) + assert(fakeMetadata.numSplits(0) === 3) + assert(fakeMetadata.numBins(0) === 4) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits when most samples close to the minimum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 2) + assert(fakeMetadata.numSplits(0) === 2) + assert(fakeMetadata.numBins(0) === 3) + assert(splits(0) === 2.0) + assert(splits(1) === 3.0) + } + + // find splits when most samples close to the maximum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 1) + assert(fakeMetadata.numSplits(0) === 1) + assert(fakeMetadata.numBins(0) === 2) + assert(splits(0) === 1.0) + } + } + test("Multiclass classification with unordered categorical features:" + " split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index fb44ceb0f57ee..6b13765b98f41 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -93,8 +93,9 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { val categoricalFeaturesInfo = Map.empty[Int, Int] val numTrees = 1 - val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo) val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) diff --git 
a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 350aad47735e4..c58666af84f24 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -54,7 +54,18 @@ object MimaExcludes { // TaskContext was promoted to Abstract class ProblemFilters.exclude[AbstractClassProblem]( "org.apache.spark.TaskContext") - + ) ++ Seq( + // Adding new methods to the JavaRDDLike trait: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.takeAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.collectAsync") ) case v if v.startsWith("1.1") => diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst new file mode 100644 index 0000000000000..5024d694b668f --- /dev/null +++ b/python/docs/pyspark.streaming.rst @@ -0,0 +1,10 @@ +pyspark.streaming module +======================== + +Module contents +--------------- + +.. automodule:: pyspark.streaming + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 17f96b8700bd7..22872dbbe3b55 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -53,6 +53,23 @@ class MatrixFactorizationModel(object): >>> model = ALS.train(ratings, 1) >>> model.predictAll(testset).count() == 2 True + + >>> model = ALS.train(ratings, 4) + >>> model.userFeatures().count() == 2 + True + + >>> first_user = model.userFeatures().take(1)[0] + >>> latents = first_user[1] + >>> len(latents) == 4 + True + + >>> model.productFeatures().count() == 2 + True + + >>> first_product = model.productFeatures().take(1)[0] + >>> latents = first_product[1] + >>> len(latents) == 4 + True """ def __init__(self, sc, java_model): @@ -83,6 +100,20 @@ def predictAll(self, user_product): return RDD(sc._jvm.SerDe.javaToPython(jresult), sc, AutoBatchedSerializer(PickleSerializer())) + def userFeatures(self): + sc = self._context + juf = self._java_model.userFeatures() + juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD() + return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc, + AutoBatchedSerializer(PickleSerializer())) + + def productFeatures(self): + sc = self._context + jpf = self._java_model.productFeatures() + jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD() + return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc, + AutoBatchedSerializer(PickleSerializer())) + class ALS(object): diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index a6019dadf781c..84baf12b906df 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,7 +22,7 @@ from functools import wraps from pyspark import PickleSerializer -from pyspark.mllib.linalg import _to_java_object_rdd +from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd __all__ = ['MultivariateStatisticalSummary', 'Statistics'] @@ -107,7 +107,7 @@ def colStats(rdd): array([ 2., 0., 0., -2.]) """ sc = rdd.ctx - jrdd = _to_java_object_rdd(rdd) + jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector)) cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) return MultivariateStatisticalSummary(sc, cStats) @@ -163,14 +163,15 @@ def corr(x, y=None,
method=None): if type(y) == str: raise TypeError("Use 'method=' to specify method name.") - jx = _to_java_object_rdd(x) if not y: + jx = _to_java_object_rdd(x.map(_convert_to_vector)) resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) bytes = sc._jvm.SerDe.dumps(resultMat) ser = PickleSerializer() return ser.loads(str(bytes)).toArray() else: - jy = _to_java_object_rdd(y) + jx = _to_java_object_rdd(x.map(float)) + jy = _to_java_object_rdd(y.map(float)) return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 463faf7b6f520..d6fb87b378b4a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -36,6 +36,8 @@ from pyspark.serializers import PickleSerializer from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.stat import Statistics from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -202,6 +204,23 @@ def test_regression(self): self.assertTrue(dt_model.predict(features[3]) > 0) +class StatTests(PySparkTestCase): + # SPARK-4023 + def test_col_with_different_rdds(self): + # numpy + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(1000, summary.count()) + # list + data = self.sc.parallelize([range(10)] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + # array.array + data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 0938eebd3a548..64ee79d83e849 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -153,9 +153,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, DecisionTreeModel classifier of depth 1 with 3 nodes >>> print model.toDebugString(), # it already has newline DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.5) + If (feature 0 <= 0.0) Predict: 0.0 - Else (feature 0 > 0.5) + Else (feature 0 > 0.0) Predict: 1.0 >>> model.predict(array([1.0])) > 0 True diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 55e247da0e4dc..528a181e8905a 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -31,7 +31,7 @@ def __init__(self, withReplacement, seed=None): "Falling back to default random generator for sampling.") self._use_numpy = False - self._seed = seed if seed is not None else random.randint(0, sys.maxint) + self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1) self._withReplacement = withReplacement self._random = None self._split = None @@ -47,7 +47,7 @@ def initRandomGenerator(self, split): for _ in range(0, split): # discard the next few values in the sequence to have a # different seed for the different splits - self._random.randint(0, sys.maxint) + self._random.randint(0, 2 ** 32 - 1) self._split = split self._rand_initialized = True diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index dc9dc41121935..2f53fbd27b17a 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -79,7 +79,7 @@ class StreamingContext(object): L{DStream} various input sources.
It can be from an existing L{SparkContext}. After creating and transforming DStreams, the streaming computation can be started and stopped using `context.start()` and `context.stop()`, - respectively. `context.awaitTransformation()` allows the current thread + respectively. `context.awaitTermination()` allows the current thread to wait for the termination of the context by `stop()` or by an exception. """ _transformerSerializer = None diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 5ae5cf07f0137..0826ddc56e844 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -441,9 +441,11 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio if `invReduceFunc` is not None, the reduction is done incrementally using the old window's reduced value : - 1. reduce the new values that entered the window (e.g., adding new counts) - 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - This is more efficient than `invReduceFunc` is None. + + 1. reduce the new values that entered the window (e.g., adding new counts) + + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than the full window reduction that happens when `invReduceFunc` is None. @param reduceFunc: associative reduce function @param invReduceFunc: inverse reduce function of `reduceFunc` diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f5ccf31abb3fa..1a8e4150e63c3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -433,6 +433,12 @@ def test_deleting_input_files(self): os.unlink(tempFile.name) self.assertRaises(Exception, lambda: filtered_data.count()) + def test_sampling_default_seed(self): + # Test for SPARK-3995 (default seed setting) + data = self.sc.parallelize(range(1000), 1) + subset = data.takeSample(False, 10) + self.assertEqual(len(subset), 10) + def testAggregateByKey(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b4d606d37e732..a277684f6327c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -181,9 +181,11 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.? ^^ { - case r1 ~ jt ~ r2 ~ cond => - Join(r1, r2, joinType = jt.getOrElse(Inner), cond) + relationFactor ~ rep1(joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.?)
^^ { + case r1 ~ joins => + joins.foldLeft(r1) { case (lhs, jt ~ rhs ~ cond) => + Join(lhs, rhs, joinType = jt.getOrElse(Inner), cond) + } } protected lazy val joinConditions: Parser[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 82553063145b8..a448c794213ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -60,6 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: + TrimAliases :: typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, @@ -89,6 +90,23 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * Removes no-op Alias expressions from the plan. + */ + object TrimAliases extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Aggregate(groups, aggs, child) => + Aggregate( + groups.map { + _ transform { + case Alias(c, _) => c + } + }, + aggs, + child) + } + } + /** * Checks for non-aggregated attributes with aggregation */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8e5ee12e314bf..8e5baf0eb82d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -32,6 +32,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case (StringType, DateType) => true + case (_: NumericType, DateType) => true + case (BooleanType, DateType) => true + case (DateType, _: NumericType) => true + case (DateType, BooleanType) => true case _ => child.nullable } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3693b41404fd6..9ce7c78195830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -object Optimizer extends RuleExecutor[LogicalPlan] { +abstract class Optimizer extends RuleExecutor[LogicalPlan] + +object DefaultOptimizer extends Optimizer { val batches = Batch("Combine Limits", FixedPoint(100), CombineLimits) :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 245a2e148030c..ef3114fd4dbab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -15,9 +15,8 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 887aabb1d5fb4..275ea2627ebcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -15,9 +15,8 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index 890d6289b9dfb..ae99a3f9ba287 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -30,7 +30,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, NoRelation) - val optimizedPlan = Optimizer(plan) + val optimizedPlan = DefaultOptimizer(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 23e7b2d270777..0e4a9ca60b00d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.types.DataType @@ -68,7 +68,7 @@ class SQLContext(@transient val sparkContext: SparkContext) new Analyzer(catalog, functionRegistry, caseSensitive = true) @transient - protected[sql] val optimizer = Optimizer + protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient protected[sql] val sqlParser = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index e9d04ce7aae4c..df01411f60a05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, 
JMapWrapper} import scala.collection.JavaConversions import scala.math.BigDecimal +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -114,7 +115,7 @@ object Row { // they are actually accessed. case row: ScalaRow => new Row(row) case map: scala.collection.Map[_, _] => - JavaConversions.mapAsJavaMap( + mapAsSerializableJavaMap( map.map { case (key, value) => (toJavaValue(key), toJavaValue(value)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 15f6ba4f72bbd..3959925a2e529 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -43,6 +43,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { TimeZone.setDefault(origZone) } + test("grouping on nested fields") { + jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + .registerTempTable("rows") + + checkAnswer( + sql( + """ + |select attribute, sum(cnt) + |from ( + | select nested.attribute, count(*) as cnt + | from rows + | group by nested.attribute) a + |group by attribute + """.stripMargin), + Row(1, 1) :: Nil) + } + test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), @@ -720,4 +737,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) } + + test("Multiple join") { + checkAnswer( + sql( + """SELECT a.key, b.key, c.key + |FROM testData a + |JOIN testData b ON a.key = b.key + |JOIN testData c ON a.key = c.key + """.stripMargin), + (1 to 100).map(i => Seq(i, i, i))) + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 7463df1f47d43..a5c457c677564 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -62,7 +62,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo } catch { case cause: Throwable => logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) + new CommandProcessorResponse(0, ExceptionUtils.getFullStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 582264eb59f83..2136a2ea63543 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -39,7 +39,11 @@ private[hive] object SparkSQLEnv extends Logging { sparkContext.addSparkListener(new StatsReportListener()) hiveContext = new HiveContext(sparkContext) { - @transient override lazy val sessionState = SessionState.get() + @transient override lazy val sessionState = { + val state = SessionState.get() + setConf(state.getConf.getAllProperties) + state + } @transient override lazy val hiveconf = 
sessionState.getConf } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 0de29d5cffd0e..fd4f65e488259 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -67,10 +67,6 @@ class HadoopTableReader( private val _broadcastedHiveConf = sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) - def broadcastedHiveConf = _broadcastedHiveConf - - def hiveConf = _broadcastedHiveConf.value.value - override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala index 7931425675128..4b2a8511e8e49 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala @@ -188,10 +188,11 @@ private[sql] case class InsertIntoOrcTable( val fieldOIs = standardOI .getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray val outputData = new Array[Any](fieldOIs.length) + val wrappers = fieldOIs.map(HadoopTypeConverter.wrapperFor) iter.map { row => var i = 0 while (i < row.length) { - outputData(i) = HadoopTypeConverter.wrap((row(i), fieldOIs(i))) + outputData(i) = wrappers(i)(row(i)) i += 1 } orcSerde.serialize(outputData, standardOI) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala index 14fe51b92d4b1..81eea055e8b60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -45,32 +45,36 @@ package object orc { // TypeConverter for InsertIntoOrcTable object HadoopTypeConverter extends HiveInspectors { - def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) + def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) + case _: JavaHiveDecimalObjectInspector => + (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying()) - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct } - struct - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) - case (m: Map[_, _], oi: MapObjectInspector) => - val keyOi = oi.getMapKeyObjectInspector - val valueOi = oi.getMapValueObjectInspector - val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> 
wrap(value, valueOi) } - mapAsJavaMap(wrappedMap) + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - case (obj, _) => - obj + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) + + case _ => + identity[Any] } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 5a8eef1372e23..23d6d1c5e50fa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -47,7 +47,7 @@ import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab * The associated SparkContext can be accessed using `context.sparkContext`. After * creating and transforming DStreams, the streaming computation can be started and stopped * using `context.start()` and `context.stop()`, respectively. - * `context.awaitTransformation()` allows the current thread to wait for the termination + * `context.awaitTermination()` allows the current thread to wait for the termination * of the context by `stop()` or by an exception. */ class StreamingContext private[streaming] ( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 9dc26dc6b32a1..7db66c69a6d73 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -46,7 +46,7 @@ import org.apache.spark.streaming.receiver.Receiver * org.apache.spark.api.java.JavaSparkContext (see core Spark documentation) can be accessed * using `context.sparkContext`. After creating and transforming DStreams, the streaming * computation can be started and stopped using `context.start()` and `context.stop()`, - * respectively. `context.awaitTransformation()` allows the current thread to wait for the + * respectively. `context.awaitTermination()` allows the current thread to wait for the * termination of a context by `stop()` or by an exception. */ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
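The orc wrapperFor change above replaces per-value pattern matching (the old wrap) with wrapper closures that are resolved once per ObjectInspector and then applied to every row, moving type dispatch out of the row loop. A minimal sketch of that precompute-a-closure pattern; the Inspector types here are hypothetical stand-ins for Hive's ObjectInspector hierarchy, not real classes:

    // Hypothetical inspector hierarchy, for shape only.
    sealed trait Inspector
    case object VarcharInspector extends Inspector
    case class ListInspector(element: Inspector) extends Inspector

    def wrapperFor(oi: Inspector): Any => Any = oi match {
      case VarcharInspector =>
        // Dispatch happens here, once per column, not once per value.
        (o: Any) => { val s = o.asInstanceOf[String]; (s, s.length) } // stand-in for new HiveVarchar(s, s.size)
      case ListInspector(elem) =>
        val elemWrapper = wrapperFor(elem) // nested wrappers are also resolved once
        (o: Any) => o.asInstanceOf[Seq[Any]].map(elemWrapper)
    }

InsertIntoOrcTable then builds val wrappers = fieldOIs.map(HadoopTypeConverter.wrapperFor) once and calls wrappers(i)(row(i)) inside the row loop, as the diff above shows.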
Output Operation | Meaning
print() | Prints first ten elements of every batch of data in a DStream on the driver. This is useful for development and debugging.
saveAsObjectFiles(prefix, [suffix]) |
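A minimal, illustrative sketch of wiring these output operations into a StreamingContext, together with the awaitTermination() call documented above; the host, port, and "batches" prefix are placeholder values:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // Illustrative wiring only: a socket text stream with two output operations.
    val conf = new SparkConf().setAppName("OutputOperationsExample")
    val ssc = new StreamingContext(conf, Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)
    lines.print()                      // first ten elements of each batch, on the driver
    lines.saveAsObjectFiles("batches") // one object file of serialized elements per batch interval
    ssc.start()
    ssc.awaitTermination()             // returns on stop() or on an error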