diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3e6addeaf04a..f3c31f55a8df 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1340,6 +1340,10 @@ object SparkContext extends Logging {
   implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
     new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
 
+  implicit def rddToPromiseRDDFunctions[T: ClassTag](rdd: RDD[T]) = new PromiseRDDFunctions(rdd)
+
+  implicit def rddToDropRDDFunctions[T: ClassTag](rdd: RDD[T]) = new DropRDDFunctions(rdd)
+
   // Implicit conversions to common Writable types, for saveAsSequenceFile
 
   implicit def intToIntWritable(i: Int) = new IntWritable(i)
diff --git a/core/src/main/scala/org/apache/spark/rdd/DropRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DropRDDFunctions.scala
new file mode 100644
index 000000000000..48aa17999ce2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/DropRDDFunctions.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{SparkContext, Logging, Partition, TaskContext}
+import org.apache.spark.{Dependency, NarrowDependency, OneToOneDependency}
+
+import org.apache.spark.SparkContext.rddToPromiseRDDFunctions
+
+
+private[spark]
+class FanInDep[T: ClassTag](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
+  // Assumes the parent RDD has only one partition
+  override def getParents(pid: Int) = List(0)
+}
+
+
+/**
+ * Extra functions available on RDDs for providing the RDD analogs of Scala drop,
+ * dropRight and dropWhile, which return an RDD as a result
+ */
+class DropRDDFunctions[T: ClassTag](self: RDD[T]) extends Logging with Serializable {
+
+  /**
+   * Return a new RDD formed by dropping the first (n) elements of the input RDD
+   */
+  def drop(n: Int): RDD[T] = {
+    if (n <= 0) return self
+
+    // locate the partition that includes the nth element
+    val locate = (partitions: Array[Partition], parent: RDD[T], ctx: TaskContext) => {
+      var rem = n
+      var p = 0
+      var np = 0
+      while (rem > 0 && p < partitions.length) {
+        np = parent.iterator(partitions(p), ctx).length
+        rem -= np
+        p += 1
+      }
+
+      if (rem > 0 || (rem == 0 && p >= partitions.length)) {
+        // all elements were dropped
+        (p, 0)
+      } else {
+        // (if we get here, note that rem <= 0)
+        (p - 1, np + rem)
+      }
+    }
+
+    val locRDD = self.promiseFromPartitionArray(locate)
+
+    new RDD[T](self.context, List(new OneToOneDependency(self), new FanInDep(locRDD))) {
+      override def getPartitions: Array[Partition] =
+        self.partitions.map(p => new PromiseArgPartition(p, List(locRDD)))
+
+      override val partitioner = self.partitioner
+
+      override def compute(split: Partition, ctx: TaskContext): Iterator[T] = {
+        val dp = split.asInstanceOf[PromiseArgPartition]
+        val (pFirst, pDrop) = dp.arg[(Int, Int)](0, ctx)
+        val parent = firstParent[T]
+        if (dp.index > pFirst) return parent.iterator(dp.partition, ctx)
+        if (dp.index == pFirst) return parent.iterator(dp.partition, ctx).drop(pDrop)
+        Iterator.empty
+      }
+    }
+  }
+
+
+  /**
+   * Return a new RDD formed by dropping the last (n) elements of the input RDD
+   */
+  def dropRight(n: Int): RDD[T] = {
+    if (n <= 0) return self
+
+    val locate = (partitions: Array[Partition], parent: RDD[T], ctx: TaskContext) => {
+      var rem = n
+      var p = partitions.length - 1
+      var np = 0
+      while (rem > 0 && p >= 0) {
+        np = parent.iterator(partitions(p), ctx).length
+        rem -= np
+        p -= 1
+      }
+
+      if (rem > 0 || (rem == 0 && p < 0)) {
+        // all elements were dropped
+        (p, 0)
+      } else {
+        // (if we get here, note that rem <= 0)
+        (p + 1, -rem)
+      }
+    }
+
+    val locRDD = self.promiseFromPartitionArray(locate)
+
+    new RDD[T](self.context, List(new OneToOneDependency(self), new FanInDep(locRDD))) {
+      override def getPartitions: Array[Partition] =
+        self.partitions.map(p => new PromiseArgPartition(p, List(locRDD)))
+
+      override val partitioner = self.partitioner
+
+      override def compute(split: Partition, ctx: TaskContext): Iterator[T] = {
+        val dp = split.asInstanceOf[PromiseArgPartition]
+        val (pFirst, pTake) = dp.arg[(Int, Int)](0, ctx)
+        val parent = firstParent[T]
+        if (dp.index < pFirst) return parent.iterator(dp.partition, ctx)
+        if (dp.index == pFirst) return parent.iterator(dp.partition, ctx).take(pTake)
+        Iterator.empty
+      }
+    }
+  }
+
+
+  /**
+   * Return a new RDD formed by dropping leading elements until predicate function (f) returns false
+   */
+  def dropWhile(f: T => Boolean): RDD[T] = {
+
+    val locate = (partitions: Array[Partition], parent: RDD[T], ctx: TaskContext) => {
+      var p = 0
+      var np = 0
+      while (np <= 0 && p < partitions.length) {
+        np = parent.iterator(partitions(p), ctx).dropWhile(f).length
+        p += 1
+      }
+
+      if (np <= 0 && p >= partitions.length) {
+        // all elements were dropped
+        p
+      } else {
+        p - 1
+      }
+    }
+
+    val locRDD = self.promiseFromPartitionArray(locate)
+
+    new RDD[T](self) {
+      override def getPartitions: Array[Partition] =
+        self.partitions.map(p => new PromiseArgPartition(p, List(locRDD)))
+
+      override val partitioner = self.partitioner
+
+      override def compute(split: Partition, ctx: TaskContext): Iterator[T] = {
+        val dp = split.asInstanceOf[PromiseArgPartition]
+        val pFirst = dp.arg[Int](0, ctx)
+        val parent = firstParent[T]
+        if (dp.index > pFirst) return parent.iterator(dp.partition, ctx)
+        if (dp.index == pFirst) return parent.iterator(dp.partition, ctx).dropWhile(f)
+        Iterator.empty
+      }
+    }
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PromiseRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PromiseRDDFunctions.scala
new file mode 100644
index 000000000000..1c2f7c2139ed
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PromiseRDDFunctions.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{SparkContext, Logging, Partition, TaskContext,
+  Dependency, NarrowDependency}
+
+
+private[spark]
+class FanOutDep[T: ClassTag](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
+  // Assumes the child RDD has only one partition
+  override def getParents(pid: Int) = (0 until rdd.partitions.length)
+}
+
+
+private[spark]
+class PromisePartition extends Partition {
+  // A PromiseRDD has exactly one partition, by construction:
+  override def index = 0
+}
+
+
+/**
+ * A way to represent the concept of a promised expression as an RDD, so that it
+ * can operate naturally inside the lazy-transform formalism
+ */
+private[spark]
+class PromiseRDD[V: ClassTag](expr: => (TaskContext => V),
+    context: SparkContext, deps: Seq[Dependency[_]])
+  extends RDD[V](context, deps) {
+
+  // This RDD has exactly one partition by definition, since it will contain
+  // a single row holding the 'promised' result of evaluating 'expr'
+  override def getPartitions = Array(new PromisePartition)
+
+  // compute evaluates 'expr', yielding an iterator over a sequence of length 1:
+  override def compute(p: Partition, ctx: TaskContext) = List(expr(ctx)).iterator
+}
+
+
+/**
+ * A partition that augments a standard RDD partition with a list of PromiseRDD arguments,
+ * so that they are available at partition compute time
+ */
+private[spark]
+class PromiseArgPartition(p: Partition, argv: Seq[PromiseRDD[_]]) extends Partition {
+  override def index = p.index
+
+  /**
+   * Obtain the underlying partition
+   */
+  def partition: Partition = p
+
+  /**
+   * Compute the nth PromiseRDD argument's expression and return its value.
+   * The return type V must be provided explicitly, and be compatible with the
+   * actual type of the PromiseRDD.
+   */
+  def arg[V](n: Int, ctx: TaskContext): V =
+    argv(n).iterator(new PromisePartition, ctx).next.asInstanceOf[V]
+}
+
+
+/**
+ * Extra functions available on RDDs for constructing PromiseRDDs that promise
+ * a value computed from this RDD's partitions
+ */
+class PromiseRDDFunctions[T: ClassTag](self: RDD[T]) extends Logging with Serializable {
+
+  /**
+   * Return a PromiseRDD by applying function 'f' to the partitions of this RDD
+   */
+  def promiseFromPartitions[V: ClassTag](f: Seq[Iterator[T]] => V): PromiseRDD[V] = {
+    val rdd = self
+    val plist = rdd.partitions
+    val expr = (ctx: TaskContext) => f(plist.map(s => rdd.iterator(s, ctx)))
+    new PromiseRDD[V](expr, rdd.context, List(new FanOutDep(rdd)))
+  }
+
+  /**
+   * Return a PromiseRDD by applying function 'f' to a partition array.
+   * This can be more efficient than promiseFromPartitions(), since it does not force
+   * a call to iterator() over the entire partition list if 'f' does not require it
+   */
+  private[spark]
+  def promiseFromPartitionArray[V: ClassTag](f: (Array[Partition],
+      RDD[T], TaskContext) => V): PromiseRDD[V] = {
+    val rdd = self
+    val plist = rdd.partitions
+    val expr = (ctx: TaskContext) => f(plist, rdd, ctx)
+    new PromiseRDD[V](expr, rdd.context, List(new FanOutDep(rdd)))
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/DropRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DropRDDFunctionsSuite.scala
new file mode 100644
index 000000000000..1ad05ce8f538
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DropRDDFunctionsSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+import org.apache.spark.util.Utils
+
+class DropRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+
+  test("drop") {
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
+    assert(rdd.drop(0).collect() === Array(1, 2, 3, 4, 5, 6))
+    assert(rdd.drop(1).collect() === Array(2, 3, 4, 5, 6))
+    assert(rdd.drop(2).collect() === Array(3, 4, 5, 6))
+    assert(rdd.drop(3).collect() === Array(4, 5, 6))
+    assert(rdd.drop(4).collect() === Array(5, 6))
+    assert(rdd.drop(5).collect() === Array(6))
+    assert(rdd.drop(6).collect() === Array())
+    assert(rdd.drop(7).collect() === Array())
+  }
+
+  test("dropRight") {
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
+    assert(rdd.dropRight(0).collect() === Array(1, 2, 3, 4, 5, 6))
+    assert(rdd.dropRight(1).collect() === Array(1, 2, 3, 4, 5))
+    assert(rdd.dropRight(2).collect() === Array(1, 2, 3, 4))
+    assert(rdd.dropRight(3).collect() === Array(1, 2, 3))
+    assert(rdd.dropRight(4).collect() === Array(1, 2))
+    assert(rdd.dropRight(5).collect() === Array(1))
+    assert(rdd.dropRight(6).collect() === Array())
+    assert(rdd.dropRight(7).collect() === Array())
+  }
+
+  test("dropWhile") {
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
+    assert(rdd.dropWhile(_ <= 0).collect() === Array(1, 2, 3, 4, 5, 6))
+    assert(rdd.dropWhile(_ <= 1).collect() === Array(2, 3, 4, 5, 6))
+    assert(rdd.dropWhile(_ <= 2).collect() === Array(3, 4, 5, 6))
+    assert(rdd.dropWhile(_ <= 3).collect() === Array(4, 5, 6))
+    assert(rdd.dropWhile(_ <= 4).collect() === Array(5, 6))
+    assert(rdd.dropWhile(_ <= 5).collect() === Array(6))
+    assert(rdd.dropWhile(_ <= 6).collect() === Array())
+    assert(rdd.dropWhile(_ <= 7).collect() === Array())
+  }
+
+  test("empty input RDD") {
+    val rdd = sc.emptyRDD[Int]
+
+    assert(rdd.drop(0).collect() === Array())
+    assert(rdd.drop(1).collect() === Array())
+
+    assert(rdd.dropRight(0).collect() === Array())
+    assert(rdd.dropRight(1).collect() === Array())
+
+    assert(rdd.dropWhile((x: Int) => false).collect() === Array())
+    assert(rdd.dropWhile((x: Int) => true).collect() === Array())
+  }
+
+  test("filtered and unioned input") {
+    val consecutive = sc.makeRDD(Array(0, 1, 2, 3, 4, 5, 6, 7, 8), 3)
+    val rdd0 = consecutive.filter((x: Int) => (x % 3) == 0)
+    val rdd1 = consecutive.filter((x: Int) => (x % 3) == 1)
+    val rdd2 = consecutive.filter((x: Int) => (x % 3) == 2)
+
+    // input RDD: 0, 3, 6, 1, 4, 7, 2, 5, 8
+    assert((rdd0 ++ rdd1 ++ rdd2).drop(6).collect() === Array(2, 5, 8))
+    assert((rdd0 ++ rdd1 ++ rdd2).dropRight(6).collect() === Array(0, 3, 6))
+    assert((rdd0 ++ rdd1 ++ rdd2).dropWhile(_ < 7).collect() === Array(7, 2, 5, 8))
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PromiseRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PromiseRDDFunctionsSuite.scala
new file mode 100644
index 000000000000..d454a0539fc5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PromiseRDDFunctionsSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+import org.apache.spark.util.Utils
+
+class PromiseRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+
+  test("simple promise") {
+    // a promise RDD having no RDD dependencies
+    val rdd = new PromiseRDD((ctx: TaskContext) => 42, sc, Nil)
+    assert(rdd.collect() === Array(42))
+  }
+
+  test("promise a sum") {
+    val data = sc.parallelize(Array(1, 2, 3))
+    val rdd = data.promiseFromPartitions((s: Seq[Iterator[Int]]) => s.map(_.sum).sum)
+    assert(rdd.collect() === Array(6))
+  }
+}
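
Usage notes (reviewer addition, not part of the patch). A minimal sketch of how the new
operations are meant to be called, assuming a running SparkContext named sc and the implicit
conversions added to object SparkContext above; the variable names are illustrative only.

  import org.apache.spark.SparkContext._   // brings rddToDropRDDFunctions into scope

  val rdd = sc.makeRDD(1 to 10, 3)         // 10 elements spread over 3 partitions
  rdd.drop(3).collect()                    // drops the first 3 elements:      4 .. 10
  rdd.dropRight(3).collect()               // drops the last 3 elements:       1 .. 7
  rdd.dropWhile(_ < 5).collect()           // drops leading elements below 5:  5 .. 10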
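
The promise mechanism itself can be exercised through promiseFromPartitions, as the test suite
above does. A sketch under the same assumptions: the promised value materializes as a
single-partition, single-element RDD, computed only when that partition is computed.
(promiseFromPartitionArray is private[spark]; drop() uses it internally so the locate step can
stop iterating over partitions once the boundary partition has been found.)

  import org.apache.spark.SparkContext._   // brings rddToPromiseRDDFunctions into scope

  val data = sc.parallelize(1 to 100, 4)
  // The function receives one iterator per partition of 'data'.
  val promisedMax = data.promiseFromPartitions(iters => iters.map(_.max).max)
  promisedMax.collect()                    // Array(100)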
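
Finally, a local analogue (illustration only, not proposed code) of the locate step inside drop():
given per-partition element counts, it returns the (first partition to keep, elements to drop
within that partition) pair that the PromiseRDD carries to every output partition's compute().

  def locate(sizes: Seq[Int], n: Int): (Int, Int) = {
    var rem = n
    var p = 0
    var np = 0
    while (rem > 0 && p < sizes.length) {
      np = sizes(p)
      rem -= np
      p += 1
    }
    if (rem > 0 || (rem == 0 && p >= sizes.length)) (p, 0)  // every element was dropped
    else (p - 1, np + rem)                                  // drop (np + rem) elements of partition (p - 1)
  }

  locate(Seq(3, 3), 4)   // (1, 1): matches drop(4) on the 6-element, 2-partition RDD in the tests
  locate(Seq(3, 3), 7)   // (2, 0): dropping more elements than exist leaves nothing to keep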