diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 9864dc98c1f33..75395a754a831 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -368,6 +368,20 @@ def checkpoint(self, eager=True):
         jdf = self._jdf.checkpoint(eager)
         return DataFrame(jdf, self.sql_ctx)
 
+    @since(2.3)
+    def localCheckpoint(self, eager=True):
+        """Returns a locally checkpointed version of this DataFrame. Checkpointing can be used to
+        truncate the logical plan of this DataFrame, which is especially useful in iterative
+        algorithms where the plan may grow exponentially. Local checkpoints are stored in the
+        executors using the caching subsystem and therefore they are not reliable.
+
+        :param eager: Whether to checkpoint this DataFrame immediately
+
+        .. note:: Experimental
+        """
+        jdf = self._jdf.localCheckpoint(eager)
+        return DataFrame(jdf, self.sql_ctx)
+
     @since(2.1)
     def withWatermark(self, eventTime, delayThreshold):
         """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c34cf0a7a7718..ef00562672a7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -527,7 +527,7 @@ class Dataset[T] private[sql](
    */
   @Experimental
   @InterfaceStability.Evolving
-  def checkpoint(): Dataset[T] = checkpoint(eager = true)
+  def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
 
   /**
    * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
@@ -540,9 +540,52 @@
    */
   @Experimental
   @InterfaceStability.Evolving
-  def checkpoint(eager: Boolean): Dataset[T] = {
+  def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true)
+
+  /**
+   * Eagerly locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be
+   * used to truncate the logical plan of this Dataset, which is especially useful in iterative
+   * algorithms where the plan may grow exponentially. Local checkpoints are written to executor
+   * storage and, while potentially faster, they are unreliable and may compromise job completion.
+   *
+   * @group basic
+   * @since 2.3.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
+
+  /**
+   * Locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used to
+   * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+   * where the plan may grow exponentially. Local checkpoints are written to executor storage and,
+   * while potentially faster, they are unreliable and may compromise job completion.
+   *
+   * @group basic
+   * @since 2.3.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(
+    eager = eager,
+    reliableCheckpoint = false
+  )
+
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager Whether to checkpoint this Dataset immediately
+   * @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the
+   *                           checkpoint directory. If false, creates a local checkpoint using
+   *                           the caching subsystem.
+   */
+  private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
     val internalRdd = queryExecution.toRdd.map(_.copy())
-    internalRdd.checkpoint()
+    if (reliableCheckpoint) {
+      internalRdd.checkpoint()
+    } else {
+      internalRdd.localCheckpoint()
+    }
 
     if (eager) {
       internalRdd.count()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b02db7721aa7f..bd1e7adefc7a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1156,67 +1156,82 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 
   Seq(true, false).foreach { eager =>
-    def testCheckpointing(testName: String)(f: => Unit): Unit = {
-      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
-        withTempDir { dir =>
-          val originalCheckpointDir = spark.sparkContext.checkpointDir
-
-          try {
-            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+    Seq(true, false).foreach { reliable =>
+      def testCheckpointing(testName: String)(f: => Unit): Unit = {
+        test(s"Dataset.checkpoint() - $testName (eager = $eager, reliable = $reliable)") {
+          if (reliable) {
+            withTempDir { dir =>
+              val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+              try {
+                spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+                f
+              } finally {
+                // Since the original checkpointDir can be None, we need
+                // to set the variable directly.
+                spark.sparkContext.checkpointDir = originalCheckpointDir
+              }
+            }
+          } else {
+            // Local checkpoints don't require a checkpoint directory.
             f
-          } finally {
-            // Since the original checkpointDir can be None, we need
-            // to set the variable directly.
-            spark.sparkContext.checkpointDir = originalCheckpointDir
           }
         }
       }
-    }
 
-    testCheckpointing("basic") {
-      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
-      val cp = ds.checkpoint(eager)
+      testCheckpointing("basic") {
+        val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+        val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
 
-      val logicalRDD = cp.logicalPlan match {
-        case plan: LogicalRDD => plan
-        case _ =>
-          val treeString = cp.logicalPlan.treeString(verbose = true)
-          fail(s"Expecting a LogicalRDD, but got\n$treeString")
-      }
+        val logicalRDD = cp.logicalPlan match {
+          case plan: LogicalRDD => plan
+          case _ =>
+            val treeString = cp.logicalPlan.treeString(verbose = true)
+            fail(s"Expecting a LogicalRDD, but got\n$treeString")
+        }
 
-      val dsPhysicalPlan = ds.queryExecution.executedPlan
-      val cpPhysicalPlan = cp.queryExecution.executedPlan
+        val dsPhysicalPlan = ds.queryExecution.executedPlan
+        val cpPhysicalPlan = cp.queryExecution.executedPlan
 
-      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
-      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+        assertResult(dsPhysicalPlan.outputPartitioning) {
+          logicalRDD.outputPartitioning
+        }
+        assertResult(dsPhysicalPlan.outputOrdering) {
+          logicalRDD.outputOrdering
+        }
 
-      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
-      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+        assertResult(dsPhysicalPlan.outputPartitioning) {
+          cpPhysicalPlan.outputPartitioning
+        }
+        assertResult(dsPhysicalPlan.outputOrdering) {
+          cpPhysicalPlan.outputOrdering
+        }
 
-      // For a lazy checkpoint() call, the first check also materializes the checkpoint.
-      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+        // For a lazy checkpoint() call, the first check also materializes the checkpoint.
+        checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
 
-      // Reads back from checkpointed data and check again.
-      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
-    }
+        // Reads back from checkpointed data and check again.
+        checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+      }
 
-    testCheckpointing("should preserve partitioning information") {
-      val ds = spark.range(10).repartition('id % 2)
-      val cp = ds.checkpoint(eager)
+      testCheckpointing("should preserve partitioning information") {
+        val ds = spark.range(10).repartition('id % 2)
+        val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
 
-      val agg = cp.groupBy('id % 2).agg(count('id))
+        val agg = cp.groupBy('id % 2).agg(count('id))
 
-      agg.queryExecution.executedPlan.collectFirst {
-        case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
-        case BroadcastExchangeExec(_, _: RDDScanExec) =>
-      }.foreach { _ =>
-        fail(
-          "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
-            "preserves partitioning information:\n\n" + agg.queryExecution
-        )
-      }
+        agg.queryExecution.executedPlan.collectFirst {
+          case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
+          case BroadcastExchangeExec(_, _: RDDScanExec) =>
+        }.foreach { _ =>
+          fail(
+            "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
+              "preserves partitioning information:\n\n" + agg.queryExecution
+          )
+        }
 
-      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+        checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+      }
     }
   }
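
Usage sketch for reviewers (illustrative only, not part of the patch): a rough end-to-end example of how the new Scala API might be called from user code. The application name, column name, and iteration counts below are made up, and it assumes a build that includes this patch.

// Illustrative sketch: an iterative job that truncates its growing logical plan
// with the new Dataset.localCheckpoint(). Unlike checkpoint(), no
// spark.sparkContext.setCheckpointDir(...) call is needed, because the checkpoint
// is kept in executor storage via the caching subsystem.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, sum}

object LocalCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("localCheckpoint-sketch").getOrCreate()

    // Hypothetical data and iteration counts, chosen only for illustration.
    var df = spark.range(0, 100000).toDF("value")
    for (i <- 1 to 30) {
      df = df.withColumn("value", col("value") + 1)
      if (i % 10 == 0) {
        // Cuts the lineage; the no-arg overload is eager, mirroring checkpoint().
        df = df.localCheckpoint()
      }
    }
    println(df.agg(sum("value")).first().getLong(0))
    spark.stop()
  }
}

As the new Javadoc notes, this trades reliability for speed: the checkpointed data lives on the executors, so it can be lost and may compromise job completion.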