Closed
14 changes: 14 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -368,6 +368,20 @@ def checkpoint(self, eager=True):
jdf = self._jdf.checkpoint(eager)
return DataFrame(jdf, self.sql_ctx)

@since(2.3)
def localCheckpoint(self, eager=True):
"""Returns a locally checkpointed version of this Dataset. Checkpointing can be used to
truncate the logical plan of this DataFrame, which is especially useful in iterative
algorithms where the plan may grow exponentially. Local checkpoints are stored in the
executors using the caching subsystem and therefore they are not reliable.

:param eager: Whether to checkpoint this DataFrame immediately

.. note:: Experimental
"""
jdf = self._jdf.localCheckpoint(eager)
return DataFrame(jdf, self.sql_ctx)

@since(2.1)
def withWatermark(self, eventTime, delayThreshold):
"""Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
49 changes: 46 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -527,7 +527,7 @@ class Dataset[T] private[sql](
*/
@Experimental
@InterfaceStability.Evolving
def checkpoint(): Dataset[T] = checkpoint(eager = true)
def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)

/**
* Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
@@ -540,9 +540,52 @@ class Dataset[T] private[sql](
*/
@Experimental
@InterfaceStability.Evolving
def checkpoint(eager: Boolean): Dataset[T] = {
Member:

Could you check the test case of def checkpoint? At least we need to add a test case.

Contributor Author:

I can try to create a test for localCheckpoint based on the one for checkpoint, but I'm not very familiar with Scala or the Spark Scala API, so I don't currently feel confident writing a meaningful test. Would anybody be up for adding one?

Contributor:

So we already test checkpoint in DatasetSuite

def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true)

/**
* Eagerly locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be
* used to truncate the logical plan of this Dataset, which is especially useful in iterative
* algorithms where the plan may grow exponentially. Local checkpoints are written to executor
* storage and, while potentially faster, they are unreliable and may compromise job completion.
*
* @group basic
Member:

add @since

* @since 2.3.0
*/
@Experimental
@InterfaceStability.Evolving
def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)

/**
* Locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used to
* truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
* where the plan may grow exponentially. Local checkpoints are written to executor storage and,
* while potentially faster, they are unreliable and may compromise job completion.
*
* @group basic
* @since 2.3.0
*/
@Experimental
@InterfaceStability.Evolving
def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(
eager = eager,
reliableCheckpoint = false
)
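
As an aside (not part of the diff), a minimal usage sketch of the new API, assuming an active SparkSession named spark such as the one provided by spark-shell; the input Dataset and partition count are made up:

// Illustrative only: not taken from the PR.
val ds = spark.range(1000000).repartition(4)

// Eager local checkpoint: materializes right away and truncates the logical plan.
val eagerCp = ds.localCheckpoint()

// Lazy local checkpoint: the returned Dataset already has a truncated plan, but the
// underlying data is only materialized and locally persisted on the first action.
val lazyCp = ds.localCheckpoint(eager = false)
lazyCp.count() // first action materializes the local checkpoint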

/**
* Returns a checkpointed version of this Dataset.
*
* @param eager Whether to checkpoint this Dataset immediately
* @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the
* checkpoint directory. If false, creates a local checkpoint using
* the caching subsystem.
*/
private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
val internalRdd = queryExecution.toRdd.map(_.copy())
internalRdd.checkpoint()
if (reliableCheckpoint) {
internalRdd.checkpoint()
} else {
internalRdd.localCheckpoint()
@gatorsmile (Member), Dec 17, 2017:

Could you also issue a logWarning message here to indicate that the checkpoint is not reliable? This call is a potential issue when users are using AWS EC2 Spot instances.

Contributor Author:

Hi, thanks for the review.
From the point of view of making the user aware that they are doing a local checkpoint, we already force them to use localCheckpoint() (the generic checkpoint is private).
If we should warn users about the potential issues with localCheckpoint(), shouldn't we do it in the RDD API, so that users are always warned?
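
For reference, a rough sketch of what the suggested warning could look like; this is not part of the change, the wording is made up, and it assumes a logWarning method is in scope on Dataset (e.g. via org.apache.spark.internal.Logging), which is an assumption:

} else {
  // Hypothetical warning suggested in review; requires a logWarning method in scope.
  logWarning("Local checkpoints are stored in executor storage and are not reliable; " +
    "checkpointed data may be lost if executors are lost (e.g., preempted Spot instances).")
  internalRdd.localCheckpoint()
}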


}

if (eager) {
internalRdd.count()
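
To illustrate the iterative-algorithm motivation mentioned in the docs above, a hedged sketch; the names, column, and iteration counts are made up, and an active SparkSession named spark is assumed:

import org.apache.spark.sql.functions.col

// Hypothetical iterative job: the logical plan grows with every iteration, so it is
// periodically truncated with a local checkpoint instead of a reliable, file-based one.
var working = spark.range(1000000).toDF("value")
for (i <- 1 to 50) {
  working = working.withColumn("value", col("value") + 1)
  if (i % 10 == 0) {
    working = working.localCheckpoint() // eager by default; unreliable but needs no checkpoint dir
  }
}
working.count()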
107 changes: 61 additions & 46 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1156,67 +1156,82 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}

Seq(true, false).foreach { eager =>
def testCheckpointing(testName: String)(f: => Unit): Unit = {
test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
withTempDir { dir =>
val originalCheckpointDir = spark.sparkContext.checkpointDir

try {
spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
Seq(true, false).foreach { reliable =>
def testCheckpointing(testName: String)(f: => Unit): Unit = {
test(s"Dataset.checkpoint() - $testName (eager = $eager, reliable = $reliable)") {
if (reliable) {
withTempDir { dir =>
val originalCheckpointDir = spark.sparkContext.checkpointDir

try {
spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
f
} finally {
// Since the original checkpointDir can be None, we need
// to set the variable directly.
spark.sparkContext.checkpointDir = originalCheckpointDir
}
}
} else {
// Local checkpoints don't require a checkpoint directory
f
} finally {
// Since the original checkpointDir can be None, we need
// to set the variable directly.
spark.sparkContext.checkpointDir = originalCheckpointDir
}
}
}
}

testCheckpointing("basic") {
val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
val cp = ds.checkpoint(eager)
testCheckpointing("basic") {
val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)

val logicalRDD = cp.logicalPlan match {
case plan: LogicalRDD => plan
case _ =>
val treeString = cp.logicalPlan.treeString(verbose = true)
fail(s"Expecting a LogicalRDD, but got\n$treeString")
}
val logicalRDD = cp.logicalPlan match {
case plan: LogicalRDD => plan
case _ =>
val treeString = cp.logicalPlan.treeString(verbose = true)
fail(s"Expecting a LogicalRDD, but got\n$treeString")
}

val dsPhysicalPlan = ds.queryExecution.executedPlan
val cpPhysicalPlan = cp.queryExecution.executedPlan
val dsPhysicalPlan = ds.queryExecution.executedPlan
val cpPhysicalPlan = cp.queryExecution.executedPlan

assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
assertResult(dsPhysicalPlan.outputPartitioning) {
logicalRDD.outputPartitioning
}
assertResult(dsPhysicalPlan.outputOrdering) {
logicalRDD.outputOrdering
}

assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
assertResult(dsPhysicalPlan.outputPartitioning) {
cpPhysicalPlan.outputPartitioning
}
assertResult(dsPhysicalPlan.outputOrdering) {
cpPhysicalPlan.outputOrdering
}

// For a lazy checkpoint() call, the first check also materializes the checkpoint.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
// For a lazy checkpoint() call, the first check also materializes the checkpoint.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)

// Reads back from checkpointed data and check again.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
}
// Reads back from checkpointed data and check again.
checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
}

testCheckpointing("should preserve partitioning information") {
val ds = spark.range(10).repartition('id % 2)
val cp = ds.checkpoint(eager)
testCheckpointing("should preserve partitioning information") {
val ds = spark.range(10).repartition('id % 2)
val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)

val agg = cp.groupBy('id % 2).agg(count('id))
val agg = cp.groupBy('id % 2).agg(count('id))

agg.queryExecution.executedPlan.collectFirst {
case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
case BroadcastExchangeExec(_, _: RDDScanExec) =>
}.foreach { _ =>
fail(
"No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
"preserves partitioning information:\n\n" + agg.queryExecution
)
}
agg.queryExecution.executedPlan.collectFirst {
case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
case BroadcastExchangeExec(_, _: RDDScanExec) =>
}.foreach { _ =>
fail(
"No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
"preserves partitioning information:\n\n" + agg.queryExecution
)
}

checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
}
}
}
