4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1364,6 +1364,10 @@ object SparkContext extends Logging {
implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))

implicit def rddToPromiseRDDFunctions[T: ClassTag](rdd: RDD[T]) = new PromiseRDDFunctions(rdd)

implicit def rddToDropRDDFunctions[T: ClassTag](rdd: RDD[T]) = new DropRDDFunctions(rdd)

// Implicit conversions to common Writable types, for saveAsSequenceFile

implicit def intToIntWritable(i: Int) = new IntWritable(i)
172 changes: 172 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/DropRDDFunctions.scala
@@ -0,0 +1,172 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{SparkContext, Logging, Partition, TaskContext}
import org.apache.spark.{Dependency, NarrowDependency, OneToOneDependency}

import org.apache.spark.SparkContext.rddToPromiseRDDFunctions


private[spark]
class FanInDep[T: ClassTag](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
// Assumes the parent RDD has exactly one partition, so every child partition
// depends on parent partition 0
override def getParents(pid: Int) = List(0)
}


/**
* Extra functions available on RDDs, providing RDD analogs of Scala's drop,
* dropRight and dropWhile, each of which returns an RDD as its result
*/
class DropRDDFunctions[T : ClassTag](self: RDD[T]) extends Logging with Serializable {

/**
* Return a new RDD formed by dropping the first (n) elements of the input RDD
*/
def drop(n: Int):RDD[T] = {
if (n <= 0) return self

// locate partition that includes the nth element
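// returns (index of the first partition to keep elements from,
//          number of elements to drop within that partition)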
val locate = (partitions: Array[Partition], input: RDD[T], ctx: TaskContext) => {
var rem = n
var p = 0
var np = 0
while (rem > 0 && p < partitions.length) {
np = input.iterator(partitions(p), ctx).length
rem -= np
p += 1
}

if (rem > 0 || (rem == 0 && p >= partitions.length)) {
// all elements were dropped
(p, 0)
} else {
// (if we get here, note that rem <= 0)
(p - 1, np + rem)
}
}

val locRDD = self.promiseFromPartitionArray(locate)

new RDD[T](self.context, List(new OneToOneDependency(self), new FanInDep(locRDD))) {
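// each output partition wraps the corresponding input partition together with
// locRDD, so the promised (partition index, drop count) pair is available at
// compute time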
override def getPartitions: Array[Partition] =
self.partitions.map(p => new PromiseArgPartition(p, List(locRDD)))

override val partitioner = self.partitioner

override def compute(split: Partition, ctx: TaskContext):Iterator[T] = {
val dp = split.asInstanceOf[PromiseArgPartition]
val (pFirst, pDrop) = dp.arg[(Int,Int)](0, ctx)
val input = firstParent[T]
if (dp.index > pFirst) return input.iterator(dp.partition, ctx)
if (dp.index == pFirst) return input.iterator(dp.partition, ctx).drop(pDrop)
Iterator.empty
}
}
}


/**
* Return a new RDD formed by dropping the last (n) elements of the input RDD
*/
def dropRight(n: Int):RDD[T] = {
if (n <= 0) return self

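// locate the partition containing the nth element from the end;
// returns (index of the last partition to keep elements from,
//          number of leading elements to keep within that partition)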
val locate = (partitions: Array[Partition], input: RDD[T], ctx: TaskContext) => {
var rem = n
var p = partitions.length-1
var np = 0
while (rem > 0 && p >= 0) {
np = input.iterator(partitions(p), ctx).length
rem -= np
p -= 1
}

if (rem > 0 || (rem == 0 && p < 0)) {
// all elements were dropped
(p, 0)
} else {
// (if we get here, note that rem <= 0)
(p + 1, -rem)
}
}

val locRDD = self.promiseFromPartitionArray(locate)

new RDD[T](self.context, List(new OneToOneDependency(self), new FanInDep(locRDD))) {
override def getPartitions: Array[Partition] =
self.partitions.map(p => new PromiseArgPartition(p, List(locRDD)))

override val partitioner = self.partitioner

override def compute(split: Partition, ctx: TaskContext):Iterator[T] = {
val dp = split.asInstanceOf[PromiseArgPartition]
val (pFirst, pTake) = dp.arg[(Int,Int)](0, ctx)
val input = firstParent[T]
if (dp.index < pFirst) return input.iterator(dp.partition, ctx)
if (dp.index == pFirst) return input.iterator(dp.partition, ctx).take(pTake)
Iterator.empty
}
}
}


/**
* Return a new RDD formed by dropping leading elements until predicate function (f) returns false
*/
def dropWhile(f: T=>Boolean):RDD[T] = {

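// locate the first partition that retains any elements after dropping;
// returns its index (or partitions.length if every element is dropped)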
val locate = (partitions: Array[Partition], input: RDD[T], ctx: TaskContext) => {
var p = 0
var np = 0
while (np <= 0 && p < partitions.length) {
np = input.iterator(partitions(p), ctx).dropWhile(f).length
p += 1
}

if (np <= 0 && p >= partitions.length) {
// all elements were dropped
p
} else {
p - 1
}
}

val locRDD = self.promiseFromPartitionArray(locate)

new RDD[T](self.context, List(new OneToOneDependency(self), new FanInDep(locRDD))) {
override def getPartitions: Array[Partition] =
self.partitions.map(p => new PromiseArgPartition(p, List(locRDD)))

override val partitioner = self.partitioner

override def compute(split: Partition, ctx: TaskContext):Iterator[T] = {
val dp = split.asInstanceOf[PromiseArgPartition]
val pFirst = dp.arg[Int](0, ctx)
val input = firstParent[T]
if (dp.index > pFirst) return input.iterator(dp.partition, ctx)
if (dp.index == pFirst) return input.iterator(dp.partition, ctx).dropWhile(f)
Iterator.empty
}
}
}

}
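For context, a minimal usage sketch (not part of this diff; assumes a live SparkContext sc, and the sample data is illustrative): with the new implicit conversions in scope via import org.apache.spark.SparkContext._, the drop operations compose like any other lazy RDD transformation.

import org.apache.spark.SparkContext._

val rdd = sc.parallelize(1 to 10, 3)
rdd.drop(3).collect()           // Array(4, 5, 6, 7, 8, 9, 10)
rdd.dropRight(3).collect()      // Array(1, 2, 3, 4, 5, 6, 7)
rdd.dropWhile(_ < 5).collect()  // Array(5, 6, 7, 8, 9, 10)

Note that each operation first evaluates its 'locate' function through a PromiseRDD (counting or scanning partitions) before the resulting RDD's partitions can be computed.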
111 changes: 111 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/PromiseRDDFunctions.scala
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{SparkContext, Logging, Partition, TaskContext,
Dependency, NarrowDependency}


private[spark]
class FanOutDep[T: ClassTag](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
// Assumes the child RDD has exactly one partition, which depends on every
// partition of the parent
override def getParents(pid: Int) = (0 until rdd.partitions.length)
}


private[spark]
class PromisePartition extends Partition {
// A PromiseRDD has exactly one partition, by construction:
override def index = 0
}


/**
* Represents a promised (deferred) expression as an RDD, so that it can
* participate naturally in Spark's lazy-transformation model
*/
private[spark]
class PromiseRDD[V: ClassTag](expr: => (TaskContext => V),
context: SparkContext, deps: Seq[Dependency[_]])
extends RDD[V](context, deps) {

// This RDD has exactly one partition by definition, since it will contain
// a single row holding the 'promised' result of evaluating 'expr'
override def getPartitions = Array(new PromisePartition)

// compute evaluates 'expr', yielding an iterator over a sequence of length 1:
override def compute(p: Partition, ctx: TaskContext) = List(expr(ctx)).iterator
}


/**
* A partition that augments a standard RDD partition with a list of PromiseRDD arguments,
* so that they are available at partition compute time
*/
private[spark]
class PromiseArgPartition(p: Partition, argv: Seq[PromiseRDD[_]]) extends Partition {
override def index = p.index

/**
* obtain the underlying partition
*/
def partition: Partition = p

/**
* Compute the nth PromiseRDD argument's expression and return its value.
* The return type V must be provided explicitly, and must be compatible with
* the actual type of the PromiseRDD.
*/
def arg[V](n: Int, ctx: TaskContext): V =
argv(n).iterator(new PromisePartition, ctx).next.asInstanceOf[V]
}


/**
* Extra functions available on RDDs for constructing PromiseRDDs, which hold
* a single value computed lazily from this RDD's partitions
*/
class PromiseRDDFunctions[T : ClassTag](self: RDD[T]) extends Logging with Serializable {

/**
* Return a PromiseRDD by applying function 'f' to the partitions of this RDD
*/
def promiseFromPartitions[V: ClassTag](f: Seq[Iterator[T]] => V): PromiseRDD[V] = {
val rdd = self
val plist = rdd.partitions
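// clean the closure for serialization; 'expr' is evaluated lazily, inside the
// task that computes the PromiseRDD's single partition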
val expr = self.context.clean((ctx: TaskContext) => f(plist.map(s => rdd.iterator(s, ctx))))
new PromiseRDD[V](expr, rdd.context, List(new FanOutDep(rdd)))
}

/**
* Return a PromiseRDD by applying function 'f' to this RDD's partition array.
* This can be more efficient than promiseFromPartitions(), since it does not
* force a call to iterator() over the entire partition list when 'f' does not
* require it
*/
private[spark]
def promiseFromPartitionArray[V: ClassTag](f: (Array[Partition],
RDD[T], TaskContext) => V): PromiseRDD[V] = {
val rdd = self
val plist = rdd.partitions
val expr = self.context.clean((ctx: TaskContext) => f(plist, rdd, ctx))
new PromiseRDD[V](expr, rdd.context, List(new FanOutDep(rdd)))
}

}
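For context, a sketch of how promiseFromPartitions might be used (not part of this diff; assumes a live SparkContext sc, and since PromiseRDD is private[spark] it assumes code living inside the org.apache.spark package, e.g. a test suite). The returned PromiseRDD is a one-partition, one-row RDD whose value is computed lazily:

import org.apache.spark.SparkContext._

val rdd = sc.parallelize(1 to 100, 4)
// a single deferred value: the total element count across all partitions
val countPromise = rdd.promiseFromPartitions(iters => iters.map(_.length).sum)
countPromise.collect()  // Array(100)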
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import org.scalatest.FunSuite

import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.util.Utils

class DropRDDFunctionsSuite extends FunSuite with SharedSparkContext {

test("drop") {
val rdd = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
assert(rdd.drop(0).collect() === Array(1, 2, 3, 4, 5, 6))
assert(rdd.drop(1).collect() === Array(2, 3, 4, 5, 6))
assert(rdd.drop(2).collect() === Array(3, 4, 5, 6))
assert(rdd.drop(3).collect() === Array(4, 5, 6))
assert(rdd.drop(4).collect() === Array(5, 6))
assert(rdd.drop(5).collect() === Array(6))
assert(rdd.drop(6).collect() === Array())
assert(rdd.drop(7).collect() === Array())
}

test("dropRight") {
val rdd = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
assert(rdd.dropRight(0).collect() === Array(1, 2, 3, 4, 5, 6))
assert(rdd.dropRight(1).collect() === Array(1, 2, 3, 4, 5))
assert(rdd.dropRight(2).collect() === Array(1, 2, 3, 4))
assert(rdd.dropRight(3).collect() === Array(1, 2, 3))
assert(rdd.dropRight(4).collect() === Array(1, 2))
assert(rdd.dropRight(5).collect() === Array(1))
assert(rdd.dropRight(6).collect() === Array())
assert(rdd.dropRight(7).collect() === Array())
}

test("dropWhile") {
val rdd = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
assert(rdd.dropWhile(_ <= 0).collect() === Array(1, 2, 3, 4, 5, 6))
assert(rdd.dropWhile(_ <= 1).collect() === Array(2, 3, 4, 5, 6))
assert(rdd.dropWhile(_ <= 2).collect() === Array(3, 4, 5, 6))
assert(rdd.dropWhile(_ <= 3).collect() === Array(4, 5, 6))
assert(rdd.dropWhile(_ <= 4).collect() === Array(5, 6))
assert(rdd.dropWhile(_ <= 5).collect() === Array(6))
assert(rdd.dropWhile(_ <= 6).collect() === Array())
assert(rdd.dropWhile(_ <= 7).collect() === Array())
}

test("empty input RDD") {
val rdd = sc.emptyRDD[Int]

assert(rdd.drop(0).collect() === Array())
assert(rdd.drop(1).collect() === Array())

assert(rdd.dropRight(0).collect() === Array())
assert(rdd.dropRight(1).collect() === Array())

assert(rdd.dropWhile((x:Int)=>false).collect() === Array())
assert(rdd.dropWhile((x:Int)=>true).collect() === Array())
}

test("filtered and unioned input") {
val consecutive = sc.makeRDD(Array(0, 1, 2, 3, 4, 5, 6, 7, 8), 3)
val rdd0 = consecutive.filter((x:Int)=>(x % 3)==0)
val rdd1 = consecutive.filter((x:Int)=>(x % 3)==1)
val rdd2 = consecutive.filter((x:Int)=>(x % 3)==2)

// input RDD: 0, 3, 6, 1, 4, 7, 2, 5, 8
assert((rdd0 ++ rdd1 ++ rdd2).drop(6).collect() === Array(2, 5, 8))
assert((rdd0 ++ rdd1 ++ rdd2).dropRight(6).collect() === Array(0, 3, 6))
assert((rdd0 ++ rdd1 ++ rdd2).dropWhile(_ < 7).collect() === Array(7, 2, 5, 8))
}
}