diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 47265df4831df..7368a6c9e1d64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -194,10 +194,20 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
   }
 
   test("Sort metrics") {
-    // Assume the execution plan is
-    // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
-    val ds = spark.range(10).sort('id)
-    testSparkPlanMetrics(ds.toDF(), 2, Map.empty)
+    // Assume the execution plan with node id is
+    // Sort(nodeId = 0)
+    //   Exchange(nodeId = 1)
+    //     Project(nodeId = 2)
+    //       LocalTableScan(nodeId = 3)
+    // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core,
+    // so Project here is not collapsed into LocalTableScan.
+    val df = Seq(1, 3, 2).toDF("id").sort('id)
+    testSparkPlanMetricsWithPredicates(df, 2, Map(
+      0L -> (("Sort", Map(
+        "sort time total (min, med, max)" -> {_.toString.matches(timingMetricPattern)},
+        "peak memory total (min, med, max)" -> {_.toString.matches(sizeMetricPattern)},
+        "spill size total (min, med, max)" -> {_.toString.matches(sizeMetricPattern)})))
+    ))
   }
 
   test("SortMergeJoin metrics") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
index dcc540fc4f109..2d245d2ba1e35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
@@ -40,6 +40,18 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
 
   protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore
 
+  // Pattern of size SQLMetric value, e.g. "\n96.2 MiB (32.1 MiB, 32.1 MiB, 32.1 MiB)"
+  protected val sizeMetricPattern = {
+    val bytes = "([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)"
+    s"\\n$bytes \\($bytes, $bytes, $bytes\\)"
+  }
+
+  // Pattern of timing SQLMetric value, e.g. "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)"
+  protected val timingMetricPattern = {
+    val duration = "([0-9]+(\\.[0-9]+)?) (ms|s|m|h)"
+    s"\\n$duration \\($duration, $duration, $duration\\)"
+  }
+
   /**
    * Get execution metrics for the SQL execution and verify metrics values.
    *
@@ -185,15 +197,34 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
       df: DataFrame,
       expectedNumOfJobs: Int,
       expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
-    val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet)
+    val expectedMetricsPredicates = expectedMetrics.mapValues { case (nodeName, nodeMetrics) =>
+      (nodeName, nodeMetrics.mapValues(expectedMetricValue =>
+        (actualMetricValue: Any) => expectedMetricValue.toString === actualMetricValue))
+    }
+    testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates)
+  }
+
+  /**
+   * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates.
+   * @param df `DataFrame` to run
+   * @param expectedNumOfJobs number of jobs that will run
+   * @param expectedMetricsPredicates the expected metrics predicates. The format is
+   *                                  `nodeId -> (operatorName, metric name -> metric predicate)`.
+   */
+  protected def testSparkPlanMetricsWithPredicates(
+      df: DataFrame,
+      expectedNumOfJobs: Int,
+      expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])]): Unit = {
+    val optActualMetrics =
+      getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet)
     optActualMetrics.foreach { actualMetrics =>
-      assert(expectedMetrics.keySet === actualMetrics.keySet)
-      for (nodeId <- expectedMetrics.keySet) {
-        val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId)
+      assert(expectedMetricsPredicates.keySet === actualMetrics.keySet)
+      for ((nodeId, (expectedNodeName, expectedMetricsPredicatesMap))
+          <- expectedMetricsPredicates) {
        val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
         assert(expectedNodeName === actualNodeName)
-        for (metricName <- expectedMetricsMap.keySet) {
-          assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName))
+        for ((metricName, metricPredicate) <- expectedMetricsPredicatesMap) {
+          assert(metricPredicate(actualMetricsMap(metricName)))
         }
       }
     }
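For context, a minimal sketch of how another test could use the new predicate-based helper, combining an exact-value check (what testSparkPlanMetrics still provides) with the shared regex patterns. This is illustrative and not part of the patch: the test name, node id 0, the "HashAggregate" operator, the metric names, and the expected row count are assumptions that would have to match the plan actually produced inside SQLMetricsSuite, which already mixes in SQLMetricsTestUtils and the toDF implicits used below.

  // Hypothetical test body inside SQLMetricsSuite (sketch only).
  test("Aggregate metrics via predicates (sketch)") {
    val df = Seq(1, 2, 2).toDF("id").groupBy("id").count()
    testSparkPlanMetricsWithPredicates(df, 2, Map(
      0L -> (("HashAggregate", Map(
        // Exact-value predicate; assumes the final aggregate emits 2 rows for this input.
        "number of output rows" -> { (v: Any) => v.toString == "2" },
        // Regex predicate reusing the shared size pattern added by this patch.
        "peak memory total (min, med, max)" -> { (v: Any) => v.toString.matches(sizeMetricPattern) })))
    ))
  }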