diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index 45526bf062fab..311f1c8885455 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -55,6 +55,11 @@ abstract class EdgeRDD[ED](
     }
   }
 
+  override def computeOrReadCheckpoint(part: Partition, context: TaskContext)
+    : Iterator[Edge[ED]] = {
+    compute(part, context)
+  }
+
   /**
    * Map the values in an edge partitioning preserving the structure but changing the values.
    *
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index 35577d9e2fc6f..4aa1f9e93b96e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -69,6 +69,11 @@ abstract class VertexRDD[VD](
     firstParent[ShippableVertexPartition[VD]].iterator(part, context).next().iterator
   }
 
+  override def computeOrReadCheckpoint(part: Partition, context: TaskContext)
+    : Iterator[(VertexId, VD)] = {
+    compute(part, context)
+  }
+
   /**
    * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting
    * VertexRDD will be based on a different index and can no longer be quickly joined with this
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
index f1ecc9e2219d1..e7b9834cddb47 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.graphx
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext {
 
@@ -33,4 +34,24 @@ class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("checkpointed transformations") {
+    withSpark { sc =>
+      val tempDir = Utils.createTempDir()
+      val path = tempDir.toURI.toString
+
+      sc.setCheckpointDir(path)
+
+      val edges = Array(Edge(1L, 2L, 1), Edge(2L, 1L, 2), Edge(3L, 4L, 3), Edge(4L, 3L, 4))
+
+      val edgeRDD = EdgeRDD.fromEdges(sc.parallelize(edges))
+
+      edgeRDD.checkpoint()
+      edgeRDD.count()
+
+      assert(edgeRDD.collect().toSet == edges.toSet)
+
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
index 0bb9e0a3ea180..53f37fb4816fa 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.graphx
 import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 class VertexRDDSuite extends SparkFunSuite with LocalSparkContext {
 
@@ -197,4 +198,21 @@ class VertexRDDSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("checkpointed transformations") {
+    withSpark { sc =>
+      val tempDir = Utils.createTempDir()
+      val path = tempDir.toURI.toString
+
+      sc.setCheckpointDir(path)
+
+      val verts = vertices(sc, 5)
+      verts.checkpoint()
+      verts.count()
+
+      assert(verts.collect().toSet == Set(0L -> 0, 1L -> 1, 2L -> 2, 3L -> 3, 4L -> 4, 5L -> 5))
+
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
 }
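
A note on why the overrides are needed, as I read the surrounding GraphX code (the patch itself does not spell this out): `VertexRDD.compute` and `EdgeRDD.compute` unwrap a `ShippableVertexPartition` / `EdgePartition` from their first parent, and the concrete implementations checkpoint that underlying `partitionsRDD` rather than the outer RDD. Once the RDD reports itself checkpointed, the default `RDD.computeOrReadCheckpoint` would read elements straight from the parent, yielding partition objects instead of `(VertexId, VD)` tuples or `Edge[ED]` values. Delegating back to `compute` keeps the unwrapping step on the read path.

Below is a minimal standalone sketch of the scenario the new tests exercise, runnable outside the suite. The object name, master setting, and checkpoint path are illustrative placeholders, not part of the patch; `Utils.createTempDir` is avoided here since it is `private[spark]`.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, EdgeRDD}

object EdgeCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("graphx-checkpoint-sketch"))
    sc.setCheckpointDir("/tmp/graphx-checkpoint") // any writable directory

    val edges = Array(Edge(1L, 2L, 1), Edge(2L, 1L, 2))
    val edgeRDD = EdgeRDD.fromEdges[Int, Int](sc.parallelize(edges))

    edgeRDD.checkpoint() // checkpoints the underlying partitionsRDD
    edgeRDD.count()      // materializes the checkpoint

    // With the computeOrReadCheckpoint override in this patch, collect()
    // goes back through compute() and unwraps edge partitions correctly,
    // rather than treating checkpointed partition data as Edge elements.
    assert(edgeRDD.collect().toSet == edges.toSet)
    sc.stop()
  }
}
```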