@@ -250,7 +250,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val bos = new ByteArrayOutputStream()
       val out = new DataOutputStream(codec.compressedOutputStream(bos))
-      while (iter.hasNext && (n < 0 || count < n)) {
+      // `iter.hasNext` may produce one row and buffer it; we should only call it when the
+      // limit is not hit.
+      while ((n < 0 || count < n) && iter.hasNext) {

Contributor: nice catch, this one!

         val row = iter.next().asInstanceOf[UnsafeRow]
         out.writeInt(row.getSizeInBytes)
         row.writeToStream(out, buffer)
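
Why the operand order matters: for iterators like the ones Spark's codegen produces, `hasNext` can compute the next row and buffer it as a side effect, so evaluating `iter.hasNext` first pulls one extra row from upstream once the limit is already hit. Below is a minimal, self-contained sketch of the effect; `BufferingIterator`, `drain`, and the element counts are illustrative only, not Spark code:

// Illustrative stand-in for an iterator whose hasNext produces and buffers
// the next element as a side effect (like a codegen'd row iterator).
class BufferingIterator[T](upstream: Iterator[T]) extends Iterator[T] {
  private var buffered: Option[T] = None
  var pulled = 0 // how many elements were consumed from upstream

  override def hasNext: Boolean = {
    if (buffered.isEmpty && upstream.hasNext) {
      buffered = Some(upstream.next()) // the side effect: one element is pulled
      pulled += 1
    }
    buffered.nonEmpty
  }

  override def next(): T = {
    val t = buffered.get
    buffered = None
    t
  }
}

object OperandOrderDemo extends App {
  // Drain up to n elements; the only difference between the two modes is
  // which operand of && is evaluated first, mirroring the diff above.
  def drain(iter: BufferingIterator[Int], n: Int, limitFirst: Boolean): Int = {
    var count = 0
    def limitNotHit: Boolean = n < 0 || count < n
    def cond(): Boolean =
      if (limitFirst) limitNotHit && iter.hasNext // new order: short-circuits
      else iter.hasNext && limitNotHit            // old order: always buffers
    while (cond()) {
      iter.next()
      count += 1
    }
    iter.pulled
  }

  // Old order: hasNext runs once more after the limit is hit, pulling a 3rd element.
  println(drain(new BufferingIterator(Iterator(1, 2, 3, 4)), n = 2, limitFirst = false)) // 3
  // New order: the limit check short-circuits, so exactly 2 elements are pulled.
  println(drain(new BufferingIterator(Iterator(1, 2, 3, 4)), n = 2, limitFirst = true))  // 2
}
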
@@ -24,7 +24,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.ui.SQLAppStatusStore
+import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -517,4 +517,57 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext {
test("writing data out metrics with dynamic partition: parquet") {
testMetricsDynamicPartition("parquet", "parquet", "t1")
}

test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") {
def checkFilterAndRangeMetrics(
df: DataFrame,
filterNumOutputs: Int,
rangeNumOutputs: Int): Unit = {
var filter: FilterExec = null

Contributor: what about something like this:

    def collectExecNode[T](pf: PartialFunction[SparkPlan, T]): PartialFunction[SparkPlan, T] = {
      pf.orElse {
        case w: WholeStageCodegenExec =>
          w.child.collect(pf).head
      }
    }
    val range = df.queryExecution.executedPlan.collectFirst(
      collectExecNode { case r: RangeExec => r })
    val filter = df.queryExecution.executedPlan.collectFirst(
      collectExecNode { case f: FilterExec => f })

Contributor (Author): In the future, if we need to catch more nodes, we should abstract it. But for now it's only Range and Filter, so I think it's OK.

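For reference, a minimal sketch (not part of the PR) of what the suggested `collectFirst`-based version could look like, combining the reviewer's helper with the assertions from the merged test that follows:

// Hedged sketch only (not the merged code): how the reviewer's helper could
// replace the mutable-variable traversal below. It assumes the Filter and
// Range nodes live in the first whole-stage-codegen subtree encountered.
def collectExecNode[T](pf: PartialFunction[SparkPlan, T]): PartialFunction[SparkPlan, T] = {
  pf.orElse {
    case w: WholeStageCodegenExec => w.child.collect(pf).head
  }
}

def checkFilterAndRangeMetrics(
    df: DataFrame,
    filterNumOutputs: Int,
    rangeNumOutputs: Int): Unit = {
  val plan = df.queryExecution.executedPlan
  val filter = plan.collectFirst(collectExecNode { case f: FilterExec => f })
  val range = plan.collectFirst(collectExecNode { case r: RangeExec => r })
  assert(filter.isDefined && range.isDefined, "the query doesn't have Filter and Range")
  assert(filter.get.metrics("numOutputRows").value == filterNumOutputs)
  assert(range.get.metrics("numOutputRows").value == rangeNumOutputs)
}
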
+      var range: RangeExec = null
+      val collectFilterAndRange: SparkPlan => Unit = {
+        case f: FilterExec =>
+          assert(filter == null, "the query should only have one Filter")
+          filter = f
+        case r: RangeExec =>
+          assert(range == null, "the query should only have one Range")
+          range = r
+        case _ =>
+      }
+      if (SQLConf.get.wholeStageEnabled) {
+        df.queryExecution.executedPlan.foreach {
+          case w: WholeStageCodegenExec =>
+            w.child.foreach(collectFilterAndRange)
+          case _ =>
+        }
+      } else {
+        df.queryExecution.executedPlan.foreach(collectFilterAndRange)
+      }
+
+      assert(filter != null && range != null, "the query doesn't have Filter and Range")
+      assert(filter.metrics("numOutputRows").value == filterNumOutputs)
+      assert(range.metrics("numOutputRows").value == rangeNumOutputs)
+    }
+
+    val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0)
+    val df2 = df.limit(2)
+    Seq(true, false).foreach { wholeStageEnabled =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString) {
+        df.collect()
+        checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000)
+
+        df.queryExecution.executedPlan.foreach(_.resetMetrics())
+        // For each partition, we take 2 rows. The Filter should then produce 2 rows per
+        // partition, and Range should produce 1000 rows (one batch) per partition. In total,
+        // the Filter produces 4 rows and the Range produces 2000 rows.
+        df.queryExecution.toRdd.mapPartitions(_.take(2)).collect()
+        checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000)
+
+        // The top-most limit will call `CollectLimitExec.executeCollect`, which only runs the
+        // first task, so in total the Filter produces 2 rows and the Range produces 1000 rows
+        // (one batch).
+        df2.collect()
+        checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000)
+      }
+    }
+  }
 }
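
To make the per-partition arithmetic in the test concrete, here is a standalone sketch using plain Scala collections (no Spark; the two "partitions" and the counter standing in for the Filter's numOutputRows metric are illustrative). Note that Scala's `Iterator.take` checks its own remaining count before calling the underlying `hasNext`, the same operand order the fix introduces, which is why exactly 2 filtered rows per partition are produced:

object TakePerPartitionDemo extends App {
  var filterOutputRows = 0 // stands in for the Filter's numOutputRows metric

  // Two "partitions" of ids, shaped like spark.range(0, 3000, 1, 2).
  val partitions = Seq(0 until 1500, 1500 until 3000)

  val collected = partitions.flatMap { part =>
    // Lazy filter: rows are only produced as the downstream take(2) pulls them.
    val filtered = part.iterator.filter { id =>
      val keep = id % 3 == 0
      if (keep) filterOutputRows += 1
      keep
    }
    filtered.take(2).toList // keep 2 rows per partition, like mapPartitions(_.take(2))
  }

  assert(collected.toList == List(0, 3, 1500, 1503))
  assert(filterOutputRows == 4) // 2 per partition, matching the test's expectation
}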