diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e132955f0f85..1e6008f0f6bf 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1344,6 +1344,10 @@ object SparkContext extends Logging {
 
   // TODO: Add AccumulatorParams for other types, e.g. lists and strings
 
+  implicit def rddToCascadeRDDFunctions[T: ClassTag](rdd: RDD[T]) = new CascadeRDDFunctions(rdd)
+
+  implicit def rddToScanRDDFunctions[T: ClassTag](rdd: RDD[T]) = new ScanRDDFunctions(rdd)
+
   implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
       (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
     new PairRDDFunctions(rdd)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CascadeRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/CascadeRDDFunctions.scala
new file mode 100644
index 000000000000..82c7391d717a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/CascadeRDDFunctions.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkContext, Logging, Partition, TaskContext}
+import org.apache.spark.{Dependency, NarrowDependency, OneToOneDependency}
+
+
+private[spark]
+class CascadeDep[T: ClassTag](rdd: RDD[T], pid: Int) extends NarrowDependency[T](rdd) {
+  // each cascaded dependency is one particular partition in the given rdd
+  override def getParents(unused: Int) = List(pid)
+}
+
+
+private[spark]
+class CascadePartition extends Partition {
+  // each CascadeRDD has one partition
+  override def index = 0
+}
+
+
+private[spark]
+class CascadeRDD[T: ClassTag, U: ClassTag]
+    (rdd: RDD[T], pid: Int, cascade: Option[RDD[U]],
+     f: => ((Iterator[T], Option[Iterator[U]]) => Iterator[U]))
+  extends RDD[U](rdd.context,
+    cascade match {
+      case None => List(new CascadeDep(rdd, pid))
+      case Some(crdd) => List(new CascadeDep(rdd, pid), new CascadeDep(crdd, 0))
+    }) {
+
+  val rddPartition = rdd.partitions(pid)
+
+  override def getPartitions: Array[Partition] = Array(new CascadePartition)
+
+  override def compute(unused: Partition, ctx: TaskContext): Iterator[U] = {
+    f(rdd.iterator(rddPartition, ctx), cascade.map(_.iterator(new CascadePartition, ctx)))
+  }
+}
+
+
+class CascadeRDDFunctions[T: ClassTag](self: RDD[T]) extends Logging with Serializable {
+
+  /**
+   * Applies a "cascading" function to the input RDD, such that each output partition is
+   * a function of the corresponding input partition and the previous output partition.
+   */
+  def cascadePartitions[U: ClassTag]
+      (f: => ((Iterator[T], Option[Iterator[U]]) => Iterator[U])): RDD[U] = {
+    if (self.partitions.length <= 0) return self.context.emptyRDD[U]
+
+    val fclean = self.context.clean(f)
+
+    val cascade = ArrayBuffer[RDD[U]](new CascadeRDD(self, 0, None, fclean))
+
+    for (j <- 1 until self.partitions.length) {
+      val prev = cascade.last
+      cascade += new CascadeRDD(self, j, Some(prev), fclean)
+    }
+
+    new UnionRDD(self.context, cascade)
+  }
+
+}
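For orientation, here is a minimal usage sketch of the cascadePartitions primitive added above (illustrative only, not part of the patch). It assumes an existing SparkContext named sc and that the new implicit conversion is in scope via import org.apache.spark.SparkContext._; the running-sum function mirrors how scanLeft in ScanRDDFunctions.scala drives this primitive.

import org.apache.spark.SparkContext._

// 4 partitions holding (1,2), (3,4), (5,6), (7,8)
val rdd = sc.parallelize(1 to 8, 4)

// Each output partition re-scans its input partition, seeded with the last
// value of the previous output partition (None for partition 0).
val runningSums = rdd.cascadePartitions[Int] { (input, prev) =>
  val seed = prev.map(_.toSeq.last).getOrElse(0)
  input.scanLeft(seed)(_ + _).drop(1)
}

runningSums.collect  // Array(1, 3, 6, 10, 15, 21, 28, 36)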
diff --git a/core/src/main/scala/org/apache/spark/rdd/ScanRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/ScanRDDFunctions.scala
new file mode 100644
index 000000000000..b83ba59711ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ScanRDDFunctions.scala
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkContext, Logging, Partition, TaskContext}
+import org.apache.spark.{Dependency, NarrowDependency, OneToOneDependency}
+
+import scala.language.implicitConversions
+import org.apache.spark.SparkContext.rddToCascadeRDDFunctions
+
+
+private[spark]
+class ScanPlyPartition[U: ClassTag](idx: Int, cur: (Int, Partition), prv: (Int, Partition))
+  extends Partition {
+  override def index = idx
+  def get: ((Int, Partition), (Int, Partition)) = (cur, prv)
+}
+
+
+private[spark]
+class ScanPlyRangeDep[U: ClassTag](rdd: RDD[U], kL: Int, kU: Int) extends NarrowDependency(rdd) {
+  override def getParents(pid: Int) = if (pid >= kL && pid < kU) List(pid) else Nil
+}
+
+
+private[spark]
+class ScanPlyOffsetDep[U: ClassTag](rdd: RDD[U], b: Int) extends NarrowDependency(rdd) {
+  override def getParents(pid: Int) = if (pid >= b) List(pid - b, pid) else Nil
+}
+
+
+private[spark]
+class ScanPly0RDD[U: ClassTag](rdd: RDD[U]) extends RDD[U](rdd) {
+  override def getPartitions = rdd.partitions.take(rdd.partitions.length - 1)
+  override def compute(split: Partition, ctx: TaskContext) =
+    List(rdd.iterator(split, ctx).toSeq.last).iterator
+}
+
+
+private[spark]
+class ScanPlyRDD[U: ClassTag](f: (U, U) => U, plies: Seq[RDD[U]])
+  extends RDD[U](plies(0).context, Nil) {
+  val ply: Array[RDD[U]] = plies.toArray
+  val n = ply(0).partitions.length
+
+  override def getPartitions = {
+    val plist: ArrayBuffer[Partition] = ArrayBuffer.empty
+    for (j <- 0 until ply.length) {
+      val (kL, kU) = if (j <= 0) (0, 1) else (math.pow(2, j - 1).toInt, math.pow(2, j).toInt)
+      for (k <- kL until kU) {
+        plist += new ScanPlyPartition(k, (j, ply(j).partitions(k)), null)
+      }
+    }
+
+    val jj = ply.length - 1
+    val b = math.pow(2, jj).toInt
+
+    for (k <- b until n) {
+      plist += new ScanPlyPartition(k, (jj, ply(jj).partitions(k)), (jj, ply(jj).partitions(k - b)))
+    }
+
+    plist.toArray
+  }
+
+  override def getDependencies = {
+    val dlist: ArrayBuffer[Dependency[U]] = ArrayBuffer.empty
+    for (j <- 0 until ply.length) {
+      val (kL, kU) = if (j <= 0) (0, 1) else (math.pow(2, j - 1).toInt, math.pow(2, j).toInt)
+      dlist += new ScanPlyRangeDep(ply(j), kL, kU)
+    }
+    dlist += new ScanPlyOffsetDep(ply.last, math.pow(2, ply.length - 1).toInt)
+
+    dlist
+  }
+
+  override def compute(split: Partition, ctx: TaskContext): Iterator[U] = {
+    val p = split.asInstanceOf[ScanPlyPartition[U]]
+    val (cur, prv) = p.get
+    val iter = parent[U](cur._1).iterator(cur._2, ctx)
+    if (prv == null) iter else {
+      val x = parent[U](prv._1).iterator(prv._2, ctx).next
+      List(f(x, iter.next)).iterator
+    }
+  }
+}
+
+private[spark]
+class ScanOutputPartition(s: Partition, o: Partition) extends Partition {
+  val scan = s
+  val offset = o
+  override def index = scan.index
+}
+
+private[spark]
+class ScanOutputRDD[U: ClassTag](scans: RDD[U], offsets: RDD[U], f: (U, U) => U)
+  extends RDD[U](scans.context, Nil) {
+  override def getDependencies = {
+    List(new OneToOneDependency(scans),
+      new NarrowDependency(offsets) {
+        override def getParents(pid: Int) = if (pid < 1) Nil else List(pid - 1)
+      })
+  }
+
+  override def getPartitions = {
+    Array(new ScanOutputPartition(scans.partitions.head, null)) ++
+      scans.partitions.tail.zip(offsets.partitions).map(x => new ScanOutputPartition(x._1, x._2))
+  }
+
+  override def compute(split: Partition, ctx: TaskContext) = {
+    val p = split.asInstanceOf[ScanOutputPartition]
+    val iter = scans.iterator(p.scan, ctx)
+    if (split.index == 0) iter else {
+      val z = offsets.iterator(p.offset, ctx).next
+      iter.drop(1).map(f(z, _))
+    }
+  }
+}
+
+
+class ScanRDDFunctions[T: ClassTag](self: RDD[T]) extends Logging with Serializable {
+
+  /**
+   * Sequential-only prefix scan. Analogous to scanLeft on Scala sequences.
+   */
+  def scanLeft[U: ClassTag](z: U)(f: (U, T) => U): RDD[U] = {
+    if (self.partitions.length <= 0) return self.context.parallelize(Array(z), 1)
+
+    val g = self.context.clean((input: Iterator[T], cascade: Option[Iterator[U]]) => {
+      val zz: U = cascade.map(_.toSeq.last).getOrElse(z)
+      input.scanLeft(zz)(f)
+    })
+
+    self.cascadePartitions(g).mapPartitionsWithIndex((j: Int, input: Iterator[U]) => {
+      if (j == 0) input else input.drop(1)
+    })
+  }
+
+
+  /**
+   * Parallel prefix scan. Analogous to scan on Scala sequences.
+   */
+  def scan[U >: T : ClassTag](z: U)(f: (U, U) => U): RDD[U] = {
+    if (self.partitions.length <= 0) return self.context.parallelize(Array(z), 1)
+
+    val fclean = self.context.clean(f)
+
+    // Compute a prefix scan on each partition
+    val pps = self.mapPartitions(_.toSeq.scan(z)(fclean).iterator)
+
+    // Extract the last row of each scanned partition. This is ply(0).
+    val ply: ArrayBuffer[RDD[U]] = ArrayBuffer(new ScanPly0RDD(pps))
+
+    // Compute the prefix scan of the partition last-rows to obtain the offsets
+    // for the output partitions. Each partition of each ply holds one row.
+    // There are 1 + ceil(log2(n - 1)) plies, where n is the number of input
+    // partitions, so the total number of ply partitions is O(n log n).
+    var b = 1
+    while (b < ply(0).partitions.length) {
+      val nxt = new ScanPlyRDD(fclean, ply)
+      ply += nxt
+      b = 2 * b
+    }
+
+    // Add the offset for each partition (ply.last) to the per-partition scans
+    new ScanOutputRDD(pps, ply.last, fclean)
+  }
+}
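A quick usage sketch of the two new operations (illustrative only, not part of the patch), again assuming an existing SparkContext named sc and the implicit conversions from import org.apache.spark.SparkContext._; the expected results follow the semantics exercised by the test suite below.

import org.apache.spark.SparkContext._

val rdd = sc.parallelize(Array(1, 2, 3, 4), 2)

// Sequential prefix scan: any seed value is allowed.
rdd.scanLeft(0)(_ + _).collect   // Array(0, 1, 3, 6, 10)
rdd.scanLeft(77)(_ + _).collect  // Array(77, 78, 80, 83, 87)

// Parallel prefix scan: the seed should be the identity element of the operator.
rdd.scan(0)(_ + _).collect       // Array(0, 1, 3, 6, 10)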
diff --git a/core/src/test/scala/org/apache/spark/rdd/ScanRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ScanRDDFunctionsSuite.scala
new file mode 100644
index 000000000000..d5c6730f3a5f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/ScanRDDFunctionsSuite.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers
+
+import org.apache.spark.{Logging, SharedSparkContext}
+import org.apache.spark.SparkContext._
+
+class ScanRDDFunctionsSuite extends FunSuite with SharedSparkContext with Matchers with Logging {
+
+  test("empty") {
+    val e = sc.emptyRDD[Int]
+
+    // scanLeft is stable with any initial value
+    assert(e.scanLeft(0)(_ + _).collect === Array(0))
+    assert(e.scanLeft(77)(_ + _).collect === Array(77))
+
+    // parallel scan requires an initial value that is
+    // the identity element of the given scanning function
+    assert(e.scan(0)(_ + _).collect === Array(0))
+  }
+
+  test("simple") {
+    val rdd = sc.parallelize(Array(1, 2, 3, 4), 2)
+
+    assert(rdd.scanLeft(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert(rdd.scanLeft(1)(_ + _).collect === Array(1, 2, 4, 7, 11))
+
+    assert(rdd.scan(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+  }
+
+  test("multiply") {
+    val rdd = sc.parallelize(Array(1, 2, 3, 4), 2)
+
+    assert(rdd.scanLeft(1)(_ * _).collect === Array(1, 1, 2, 6, 24))
+    assert(rdd.scanLeft(2)(_ * _).collect === Array(2, 2, 4, 12, 48))
+    assert(rdd.scanLeft(0)(_ * _).collect === Array(0, 0, 0, 0, 0))
+
+    assert(rdd.scan(1)(_ * _).collect === Array(1, 1, 2, 6, 24))
+  }
+
+  test("concatenate") {
+    val rdd = sc.parallelize(Array("a", "b", "c", "d"), 4)
+
+    assert(rdd.scanLeft("")(_ ++ _).collect === Array("", "a", "ab", "abc", "abcd"))
+    assert(rdd.scanLeft("x")(_ ++ _).collect === Array("x", "xa", "xab", "xabc", "xabcd"))
+
+    assert(rdd.scan("")(_ ++ _).collect === Array("", "a", "ab", "abc", "abcd"))
+  }
+
+  test("cumulative with raw") {
+    val rdd = sc.parallelize(1 to 4, 2)
+    val f = (x: (Int, Int), y: Int) => (y, y + x._2)
+    assert(rdd.scanLeft((0, 0))(f).collect === Array((0, 0), (1, 1), (2, 3), (3, 6), (4, 10)))
+  }
+
+  test("empty partitions") {
+    val pe = sc.emptyRDD[Int]
+    val p1 = sc.parallelize(Array(1, 2), 1)
+    val p2 = sc.parallelize(Array(3, 4), 1)
+
+    assert((pe ++ p1 ++ p2).scanLeft(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ pe ++ p2).scanLeft(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ p2 ++ pe).scanLeft(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+
+    assert((pe ++ pe ++ p1 ++ p2).scanLeft(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ pe ++ pe ++ p2).scanLeft(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ p2 ++ pe ++ pe).scanLeft(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+
+    assert((pe).scanLeft(0)(_ + _).collect === Array(0))
+    assert((pe ++ pe).scanLeft(0)(_ + _).collect === Array(0))
+    assert((pe ++ pe ++ pe).scanLeft(0)(_ + _).collect === Array(0))
+
+    assert((pe ++ p1 ++ p2).scan(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ pe ++ p2).scan(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ p2 ++ pe).scan(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+
+    assert((pe ++ pe ++ p1 ++ p2).scan(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ pe ++ pe ++ p2).scan(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+    assert((p1 ++ p2 ++ pe ++ pe).scan(0)(_ + _).collect === Array(0, 1, 3, 6, 10))
+
+    assert((pe).scan(0)(_ + _).collect === Array(0))
+    assert((pe ++ pe).scan(0)(_ + _).collect === Array(0))
+    assert((pe ++ pe ++ pe).scan(0)(_ + _).collect === Array(0))
+  }
+
+  test("randomized small") {
+    val n = 20
+    val rng = new scala.util.Random()
+
+    for (unused <- 1 to 50) {
+      val data = Array.fill(n) { rng.nextInt(100) }
+      val rdd = sc.parallelize(data, 1 + rng.nextInt(n))
+      assert(rdd.scanLeft(0)(_ + _).collect === data.scanLeft(0)(_ + _))
+      assert(rdd.scan(0)(_ + _).collect === data.scan(0)(_ + _))
+    }
+  }
+
+  test("randomized large") {
+    val n = 1000
+    val pmin = 32
+    val pmax = 128
+    val rng = new scala.util.Random()
+
+    for (unused <- 1 to 50) {
+      val data = Array.fill(n) { rng.nextInt(100) }
+      val rdd = sc.parallelize(data, pmin + rng.nextInt(pmax - pmin))
+      assert(rdd.scanLeft(0)(_ + _).collect === data.scanLeft(0)(_ + _))
+      assert(rdd.scan(0)(_ + _).collect === data.scan(0)(_ + _))
+    }
+  }
+}
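To illustrate the doubling ("ply") recurrence that ScanPly0RDD and ScanPlyRDD apply to the per-partition totals, here is a small local model on a plain array (a sketch with hypothetical names, not part of the patch). Round j combines position k with the value 2^(j-1) slots earlier, so after roughly log2(m) rounds position k holds the inclusive prefix over totals(0..k); ScanOutputRDD then applies offset k - 1 to the scan of output partition k.

import scala.reflect.ClassTag

object PlyScanModel {
  // totals(k) models the last row of the per-partition prefix scan of input partition k.
  // The repeated-doubling loop mirrors the ply ladder that scan() builds from ScanPlyRDD.
  def plyScan[U: ClassTag](totals: Array[U])(f: (U, U) => U): Array[U] = {
    var ply = totals.clone
    var b = 1
    while (b < ply.length) {
      // next ply: position k (for k >= b) combines with the value b = 2^(j-1) slots earlier
      ply = Array.tabulate(ply.length) { k =>
        if (k >= b) f(ply(k - b), ply(k)) else ply(k)
      }
      b *= 2
    }
    ply  // ply(k) is now the inclusive prefix combination of totals(0..k)
  }

  def main(args: Array[String]): Unit = {
    // per-partition sums 1, 2, 3, 4 become cumulative offsets 1, 3, 6, 10
    println(plyScan(Array(1, 2, 3, 4))(_ + _).mkString(", "))
  }
}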