From d40109fa50c286b4e469dc6ffe953c914ba559e0 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sat, 8 Dec 2018 11:58:21 +0800 Subject: [PATCH 01/11] [SPARK-23375][SQL][FOLLOWUP][TEST] Test Sort metrics while Sort is missing --- .../sql/execution/metric/SQLMetricsSuite.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 47265df4831d..46022e6d896e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -194,10 +194,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared } test("Sort metrics") { - // Assume the execution plan is - // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) - val ds = spark.range(10).sort('id) - testSparkPlanMetrics(ds.toDF(), 2, Map.empty) + // Assume the execution plan with node id is + // Sort(nodeId = 0) + // Exchange(nodeId = 1) + // Range(nodeId = 2) + val df = spark.range(9, -1, -1).sort('id).toDF() + testSparkPlanMetrics(df, 2, Map.empty) + df.queryExecution.executedPlan.find(_.isInstanceOf[SortExec]).getOrElse(assert(false)) } test("SortMergeJoin metrics") { From 42c77f24504a92bad9ef13664f15acb0d3347547 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sat, 8 Dec 2018 14:35:14 +0800 Subject: [PATCH 02/11] fix according to comments: make query more readable --- .../spark/sql/execution/metric/SQLMetricsSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 46022e6d896e..d6ccefe47303 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.{FilterExec, RangeExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -197,10 +197,10 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Assume the execution plan with node id is // Sort(nodeId = 0) // Exchange(nodeId = 1) - // Range(nodeId = 2) - val df = spark.range(9, -1, -1).sort('id).toDF() + // LocalTableScan(nodeId = 2) + val df = Seq(1, 3, 2).toDF("id").sort('id) testSparkPlanMetrics(df, 2, Map.empty) - 
df.queryExecution.executedPlan.find(_.isInstanceOf[SortExec]).getOrElse(assert(false)) + assert(df.queryExecution.executedPlan.find(_.isInstanceOf[SortExec]).isDefined) } test("SortMergeJoin metrics") { From 68dbdd7b51a5e2d4bbd0abd807fe832ba9e46a95 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 25 Dec 2018 14:32:56 +0800 Subject: [PATCH 03/11] check Sort metrics values --- .../execution/metric/SQLMetricsSuite.scala | 19 ++++++-- .../metric/SQLMetricsTestUtils.scala | 46 +++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d6ccefe47303..36fbfff2ec9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -197,10 +197,23 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Assume the execution plan with node id is // Sort(nodeId = 0) // Exchange(nodeId = 1) - // LocalTableScan(nodeId = 2) + // Project(nodeId = 2) + // LocalTableScan(nodeId = 3) + // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core, + // so Project here is not collapsed into LocalTableScan. val df = Seq(1, 3, 2).toDF("id").sort('id) - testSparkPlanMetrics(df, 2, Map.empty) - assert(df.queryExecution.executedPlan.find(_.isInstanceOf[SortExec]).isDefined) + val metrics = getSparkPlanMetrics(df, 2, Set(0)) + val sortMetrics = metrics.get.get(0).get + // Check node 0 is Sort node + val operatorName = sortMetrics._1 + assert(operatorName == "Sort") + // Check metrics values + val sortTimeStr = sortMetrics._2.get("sort time total (min, med, max)").get.toString + timingMetricStats(sortTimeStr).foreach { case (sortTime, _) => assert(sortTime >= 0) } + val peakMemoryStr = sortMetrics._2.get("peak memory total (min, med, max)").get.toString + sizeMetricStats(peakMemoryStr).foreach { case (peakMemory, _) => assert(peakMemory > 0) } + val spillSizeStr = sortMetrics._2.get("spill size total (min, med, max)").get.toString + sizeMetricStats(spillSizeStr).foreach { case (spillSize, _) => assert(spillSize >= 0) } } test("SortMergeJoin metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index dcc540fc4f10..c1c54816bff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.metric import java.io.File +import java.util.regex.Pattern import scala.collection.mutable.HashMap @@ -198,6 +199,51 @@ trait SQLMetricsTestUtils extends SQLTestUtils { } } } + + private def metricStats(metricStr: String): Seq[String] = { + val sum = metricStr.substring(0, metricStr.indexOf("(")).stripPrefix("\n").stripSuffix(" ") + val minMedMax = metricStr.substring(metricStr.indexOf("(") + 1, metricStr.indexOf(")")) + .split(", ").toSeq + (sum +: minMedMax) + } + + private def stringToBytes(str: String): (Float, String) = { + val matcher = Pattern.compile("([0-9]+(\\.[0-9]+)?) 
(EB|PB|TB|GB|MB|KB|B)").matcher(str) + if (matcher.matches()) { + (matcher.group(1).toFloat, matcher.group(3)) + } else { + throw new NumberFormatException("Failed to parse byte string: " + str) + } + } + + private def stringToDuration(str: String): (Float, String) = { + val matcher = Pattern.compile("([0-9]+(\\.[0-9]+)?) (ms|s|m|h)").matcher(str) + if (matcher.matches()) { + (matcher.group(1).toFloat, matcher.group(3)) + } else { + throw new NumberFormatException("Failed to parse time string: " + str) + } + } + + /** + * Convert a size metric string to a sequence of stats, including sum, min, med and max in order, + * each a tuple of (value, unit). + * @param metricStr size metric string, e.g. "\n96.2 MB (32.1 MB, 32.1 MB, 32.1 MB)" + * @return A sequence of stats, e.g. ((96.2,MB), (32.1,MB), (32.1,MB), (32.1,MB)) + */ + protected def sizeMetricStats(metricStr: String): Seq[(Float, String)] = { + metricStats(metricStr).map(stringToBytes) + } + + /** + * Convert a timing metric string to a sequence of stats, including sum, min, med and max in + * order, each a tuple of (value, unit). + * @param metricStr timing metric string, e.g. "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)" + * @return A sequence of stats, e.g. ((2.0,ms), (1.0,ms), (1.0,ms), (1.0,ms)) + */ + protected def timingMetricStats(metricStr: String): Seq[(Float, String)] = { + metricStats(metricStr).map(stringToDuration) + } } From 4ee2c8d7de4024dae9605e4ad18a8f541a234625 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 25 Dec 2018 15:34:27 +0800 Subject: [PATCH 04/11] MB to MiB per SPARK-25696 --- .../spark/sql/execution/metric/SQLMetricsTestUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index c1c54816bff2..148366171024 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -208,7 +208,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils { } private def stringToBytes(str: String): (Float, String) = { - val matcher = Pattern.compile("([0-9]+(\\.[0-9]+)?) (EB|PB|TB|GB|MB|KB|B)").matcher(str) + val matcher = + Pattern.compile("([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)").matcher(str) if (matcher.matches()) { (matcher.group(1).toFloat, matcher.group(3)) } else { From 1e55f31e382cf67fd38ea8001d0b1d6b3bdcc586 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Thu, 27 Dec 2018 13:59:20 +0800 Subject: [PATCH 05/11] fix according to comments --- .../spark/sql/execution/metric/SQLMetricsSuite.scala | 7 ++++--- .../spark/sql/execution/metric/SQLMetricsTestUtils.scala | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 36fbfff2ec9e..9f56416d3c79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -203,17 +203,18 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // so Project here is not collapsed into LocalTableScan. 
val df = Seq(1, 3, 2).toDF("id").sort('id) val metrics = getSparkPlanMetrics(df, 2, Set(0)) + assert(metrics.isDefined) val sortMetrics = metrics.get.get(0).get // Check node 0 is Sort node val operatorName = sortMetrics._1 assert(operatorName == "Sort") // Check metrics values val sortTimeStr = sortMetrics._2.get("sort time total (min, med, max)").get.toString - timingMetricStats(sortTimeStr).foreach { case (sortTime, _) => assert(sortTime >= 0) } + assert(timingMetricStats(sortTimeStr).forall { case (sortTime, _) => sortTime >= 0 }) val peakMemoryStr = sortMetrics._2.get("peak memory total (min, med, max)").get.toString - sizeMetricStats(peakMemoryStr).foreach { case (peakMemory, _) => assert(peakMemory > 0) } + assert(sizeMetricStats(peakMemoryStr).forall { case (peakMemory, _) => peakMemory > 0 }) val spillSizeStr = sortMetrics._2.get("spill size total (min, med, max)").get.toString - sizeMetricStats(spillSizeStr).foreach { case (spillSize, _) => assert(spillSize >= 0) } + assert(sizeMetricStats(spillSizeStr).forall { case (spillSize, _) => spillSize >= 0 }) } test("SortMergeJoin metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 148366171024..219dc2576695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -41,6 +41,10 @@ trait SQLMetricsTestUtils extends SQLTestUtils { protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore + protected val bytesPattern = Pattern.compile("([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)") + + protected val durationPattern = Pattern.compile("([0-9]+(\\.[0-9]+)?) (ms|s|m|h)") + /** * Get execution metrics for the SQL execution and verify metrics values. * @@ -208,8 +212,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { } private def stringToBytes(str: String): (Float, String) = { - val matcher = - Pattern.compile("([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)").matcher(str) + val matcher = bytesPattern.matcher(str) if (matcher.matches()) { (matcher.group(1).toFloat, matcher.group(3)) } else { @@ -218,7 +221,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { } private def stringToDuration(str: String): (Float, String) = { - val matcher = Pattern.compile("([0-9]+(\\.[0-9]+)?) 
(ms|s|m|h)").matcher(str) + val matcher = durationPattern.matcher(str) if (matcher.matches()) { (matcher.group(1).toFloat, matcher.group(3)) } else { From c3336d8568f06a7fde98e18cac3c01eceb496344 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Fri, 28 Dec 2018 18:21:40 +0800 Subject: [PATCH 06/11] add testSparkPlanMetricsWithPredicates and comments for sort time --- .../execution/metric/SQLMetricsSuite.scala | 21 +++----- .../metric/SQLMetricsTestUtils.scala | 53 ++++++++++++++++--- 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 9f56416d3c79..e240a1bf9338 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -202,19 +202,14 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core, // so Project here is not collapsed into LocalTableScan. val df = Seq(1, 3, 2).toDF("id").sort('id) - val metrics = getSparkPlanMetrics(df, 2, Set(0)) - assert(metrics.isDefined) - val sortMetrics = metrics.get.get(0).get - // Check node 0 is Sort node - val operatorName = sortMetrics._1 - assert(operatorName == "Sort") - // Check metrics values - val sortTimeStr = sortMetrics._2.get("sort time total (min, med, max)").get.toString - assert(timingMetricStats(sortTimeStr).forall { case (sortTime, _) => sortTime >= 0 }) - val peakMemoryStr = sortMetrics._2.get("peak memory total (min, med, max)").get.toString - assert(sizeMetricStats(peakMemoryStr).forall { case (peakMemory, _) => peakMemory > 0 }) - val spillSizeStr = sortMetrics._2.get("spill size total (min, med, max)").get.toString - assert(sizeMetricStats(spillSizeStr).forall { case (spillSize, _) => spillSize >= 0 }) + testSparkPlanMetricsWithPredicates(df, 2, Map( + 0L -> (("Sort", Map( + // In SortExec, sort time is collected as nanoseconds, but it is converted and stored as + // milliseconds. So sort time may be 0 if sort is executed very fast. 
+ "sort time total (min, med, max)" -> timingMetricAllStatsShould(_ >= 0), + "peak memory total (min, med, max)" -> sizeMetricAllStatsShould(_ > 0), + "spill size total (min, med, max)" -> sizeMetricAllStatsShould(_ >= 0)))) + )) } test("SortMergeJoin metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 219dc2576695..16d1e8933360 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -190,15 +190,34 @@ trait SQLMetricsTestUtils extends SQLTestUtils { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet) + val expectedMetricsPredicates = expectedMetrics.mapValues { case (nodeName, nodeMetrics) => + (nodeName, nodeMetrics.mapValues(expectedMetricValue => + (actualMetricValue: Any) => expectedMetricValue.toString === actualMetricValue) + )} + testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates) + } + + /** + * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates. + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetricsPredicates the expected metrics predicates. The format is + * `nodeId -> (operatorName, metric name -> metric value predicate)`. + */ + protected def testSparkPlanMetricsWithPredicates( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])]): Unit = { + val optActualMetrics = + getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet) optActualMetrics.foreach { actualMetrics => - assert(expectedMetrics.keySet === actualMetrics.keySet) - for (nodeId <- expectedMetrics.keySet) { - val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) + assert(expectedMetricsPredicates.keySet === actualMetrics.keySet) + for (nodeId <- expectedMetricsPredicates.keySet) { + val (expectedNodeName, expectedMetricsPredicatesMap) = expectedMetricsPredicates(nodeId) val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) assert(expectedNodeName === actualNodeName) - for (metricName <- expectedMetricsMap.keySet) { - assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) + for (metricName <- expectedMetricsPredicatesMap.keySet) { + assert(expectedMetricsPredicatesMap(metricName)(actualMetricsMap(metricName))) } } } @@ -248,6 +267,28 @@ trait SQLMetricsTestUtils extends SQLTestUtils { protected def timingMetricStats(metricStr: String): Seq[(Float, String)] = { metricStats(metricStr).map(stringToDuration) } + + /** + * Returns a function to check whether all stats (sum, min, med and max) of a timing metric + * satisfy the specified predicate. + * @param predicate predicate to check stats + * @return function to check all stats of a timing metric + */ + protected def timingMetricAllStatsShould(predicate: Float => Boolean): Any => Boolean = { + (timingMetric: Any) => + timingMetricStats(timingMetric.toString).forall { case (duration, _) => predicate(duration) } + } + + /** + * Returns a function to check whether all stats (sum, min, med and max) of a size metric satisfy + * the specified predicate. 
+ * @param predicate predicate to check stats + * @return function to check all stats of a size metric + */ + protected def sizeMetricAllStatsShould(predicate: Float => Boolean): Any => Boolean = { + (sizeMetric: Any) => + sizeMetricStats(sizeMetric.toString).forall { case (bytes, _) => predicate(bytes)} + } } From 75d0c082bc0f3fc6b0d8996d3b717cd17ddc4aa9 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Fri, 28 Dec 2018 23:53:52 +0800 Subject: [PATCH 07/11] fix indentation --- .../spark/sql/execution/metric/SQLMetricsTestUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 16d1e8933360..4e6a66520947 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -199,10 +199,11 @@ trait SQLMetricsTestUtils extends SQLTestUtils { /** * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates. + * * @param df `DataFrame` to run * @param expectedNumOfJobs number of jobs that will run * @param expectedMetricsPredicates the expected metrics predicates. The format is - * `nodeId -> (operatorName, metric name -> metric value predicate)`. + * `nodeId -> (operatorName, metric name -> metric value predicate)`. */ protected def testSparkPlanMetricsWithPredicates( df: DataFrame, From 386a7e5d8857caa2e494e4de38651642c8c2f0d2 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sat, 29 Dec 2018 13:42:39 +0800 Subject: [PATCH 08/11] test metrics by pattern matching --- .../execution/metric/SQLMetricsSuite.scala | 12 +- .../metric/SQLMetricsTestUtils.scala | 141 ++++++------------ 2 files changed, 52 insertions(+), 101 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index e240a1bf9338..94f50a26ed3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -202,14 +202,12 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core, // so Project here is not collapsed into LocalTableScan. val df = Seq(1, 3, 2).toDF("id").sort('id) - testSparkPlanMetricsWithPredicates(df, 2, Map( + testSparkPlanMetrics(df, 2, Map( 0L -> (("Sort", Map( - // In SortExec, sort time is collected as nanoseconds, but it is converted and stored as - // milliseconds. So sort time may be 0 if sort is executed very fast. 
- "sort time total (min, med, max)" -> timingMetricAllStatsShould(_ >= 0), - "peak memory total (min, med, max)" -> sizeMetricAllStatsShould(_ > 0), - "spill size total (min, med, max)" -> sizeMetricAllStatsShould(_ >= 0)))) - )) + "sort time total (min, med, max)" -> timingMetricPattern, + "peak memory total (min, med, max)" -> sizeMetricPattern, + "spill size total (min, med, max)" -> sizeMetricPattern)))), + ExpectedMetricsType.PATTERN) } test("SortMergeJoin metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 4e6a66520947..719cf71df8b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -41,9 +41,21 @@ trait SQLMetricsTestUtils extends SQLTestUtils { protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore - protected val bytesPattern = Pattern.compile("([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)") + protected object ExpectedMetricsType extends Enumeration { + type ExpectedMetricsType = Value + val VALUE, PATTERN, PATTERN_STRING = Value + } + + protected val bytes = "([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)" + + protected val duration = "([0-9]+(\\.[0-9]+)?) (ms|s|m|h)" + + // "\n96.2 MiB (32.1 MiB, 32.1 MiB, 32.1 MiB)" + protected val sizeMetricPattern = Pattern.compile(s"\\n$bytes \\($bytes, $bytes, $bytes\\)") - protected val durationPattern = Pattern.compile("([0-9]+(\\.[0-9]+)?) (ms|s|m|h)") + // "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)" + protected val timingMetricPattern = + Pattern.compile(s"\\n$duration \\($duration, $duration, $duration\\)") /** * Get execution metrics for the SQL execution and verify metrics values. @@ -179,116 +191,57 @@ trait SQLMetricsTestUtils extends SQLTestUtils { } /** - * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * Call `df.collect()` and verify if the collected metrics match "expectedMetrics". By + * "expectedMetrics", you can either specify exact metric value or metric value pattern. A pattern + * can be a regex string or a compiled `Pattern` object. * * @param df `DataFrame` to run * @param expectedNumOfJobs number of jobs that will run * @param expectedMetrics the expected metrics. The format is - * `nodeId -> (operatorName, metric name -> metric value)`. + * `nodeId -> (operatorName, metric name -> metric value or pattern)`. + * @param expectedMetricsType the type of the expected metrics. */ protected def testSparkPlanMetrics( df: DataFrame, expectedNumOfJobs: Int, - expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val expectedMetricsPredicates = expectedMetrics.mapValues { case (nodeName, nodeMetrics) => - (nodeName, nodeMetrics.mapValues(expectedMetricValue => - (actualMetricValue: Any) => expectedMetricValue.toString === actualMetricValue) - )} - testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates) - } - - /** - * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates. - * - * @param df `DataFrame` to run - * @param expectedNumOfJobs number of jobs that will run - * @param expectedMetricsPredicates the expected metrics predicates. The format is - * `nodeId -> (operatorName, metric name -> metric value predicate)`. 
- */ - protected def testSparkPlanMetricsWithPredicates( - df: DataFrame, - expectedNumOfJobs: Int, - expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])]): Unit = { - val optActualMetrics = - getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet) + expectedMetrics: Map[Long, (String, Map[String, Any])], + expectedMetricsType: ExpectedMetricsType.Value): Unit = { + val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet) optActualMetrics.foreach { actualMetrics => - assert(expectedMetricsPredicates.keySet === actualMetrics.keySet) - for (nodeId <- expectedMetricsPredicates.keySet) { - val (expectedNodeName, expectedMetricsPredicatesMap) = expectedMetricsPredicates(nodeId) + assert(expectedMetrics.keySet === actualMetrics.keySet) + for (nodeId <- expectedMetrics.keySet) { + val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) assert(expectedNodeName === actualNodeName) - for (metricName <- expectedMetricsPredicatesMap.keySet) { - assert(expectedMetricsPredicatesMap(metricName)(actualMetricsMap(metricName))) + for (metricName <- expectedMetricsMap.keySet) { + expectedMetricsType match { + case ExpectedMetricsType.VALUE => + assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) + case ExpectedMetricsType.PATTERN => + assert(expectedMetricsMap(metricName).asInstanceOf[Pattern].matcher( + actualMetricsMap(metricName).toString).matches()) + case ExpectedMetricsType.PATTERN_STRING => + assert(Pattern.compile(expectedMetricsMap(metricName).toString).matcher( + actualMetricsMap(metricName).toString).matches()) + } } } } } - private def metricStats(metricStr: String): Seq[String] = { - val sum = metricStr.substring(0, metricStr.indexOf("(")).stripPrefix("\n").stripSuffix(" ") - val minMedMax = metricStr.substring(metricStr.indexOf("(") + 1, metricStr.indexOf(")")) - .split(", ").toSeq - (sum +: minMedMax) - } - - private def stringToBytes(str: String): (Float, String) = { - val matcher = bytesPattern.matcher(str) - if (matcher.matches()) { - (matcher.group(1).toFloat, matcher.group(3)) - } else { - throw new NumberFormatException("Failed to parse byte string: " + str) - } - } - - private def stringToDuration(str: String): (Float, String) = { - val matcher = durationPattern.matcher(str) - if (matcher.matches()) { - (matcher.group(1).toFloat, matcher.group(3)) - } else { - throw new NumberFormatException("Failed to parse time string: " + str) - } - } - - /** - * Convert a size metric string to a sequence of stats, including sum, min, med and max in order, - * each a tuple of (value, unit). - * @param metricStr size metric string, e.g. "\n96.2 MB (32.1 MB, 32.1 MB, 32.1 MB)" - * @return A sequence of stats, e.g. ((96.2,MB), (32.1,MB), (32.1,MB), (32.1,MB)) - */ - protected def sizeMetricStats(metricStr: String): Seq[(Float, String)] = { - metricStats(metricStr).map(stringToBytes) - } - - /** - * Convert a timing metric string to a sequence of stats, including sum, min, med and max in - * order, each a tuple of (value, unit). - * @param metricStr timing metric string, e.g. "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)" - * @return A sequence of stats, e.g. 
((2.0,ms), (1.0,ms), (1.0,ms), (1.0,ms)) - */ - protected def timingMetricStats(metricStr: String): Seq[(Float, String)] = { - metricStats(metricStr).map(stringToDuration) - } - - /** - * Returns a function to check whether all stats (sum, min, med and max) of a timing metric - * satisfy the specified predicate. - * @param predicate predicate to check stats - * @return function to check all stats of a timing metric - */ - protected def timingMetricAllStatsShould(predicate: Float => Boolean): Any => Boolean = { - (timingMetric: Any) => - timingMetricStats(timingMetric.toString).forall { case (duration, _) => predicate(duration) } - } - /** - * Returns a function to check whether all stats (sum, min, med and max) of a size metric satisfy - * the specified predicate. - * @param predicate predicate to check stats - * @return function to check all stats of a size metric + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. */ - protected def sizeMetricAllStatsShould(predicate: Float => Boolean): Any => Boolean = { - (sizeMetric: Any) => - sizeMetricStats(sizeMetric.toString).forall { case (bytes, _) => predicate(bytes)} + protected def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + testSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics, ExpectedMetricsType.VALUE) } } From c496c544e27d11a083493a5d20aa82cfa4bb2a25 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sun, 30 Dec 2018 20:19:56 +0800 Subject: [PATCH 09/11] use testSparkPlanMetricsWithPredicates and checkPattern together --- .../execution/metric/SQLMetricsSuite.scala | 10 ++-- .../metric/SQLMetricsTestUtils.scala | 60 ++++++++++--------- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 94f50a26ed3f..3d4a2de5d4ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -202,12 +202,12 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core, // so Project here is not collapsed into LocalTableScan. 
val df = Seq(1, 3, 2).toDF("id").sort('id) - testSparkPlanMetrics(df, 2, Map( + testSparkPlanMetricsWithPredicates(df, 2, Map( 0L -> (("Sort", Map( - "sort time total (min, med, max)" -> timingMetricPattern, - "peak memory total (min, med, max)" -> sizeMetricPattern, - "spill size total (min, med, max)" -> sizeMetricPattern)))), - ExpectedMetricsType.PATTERN) + "sort time total (min, med, max)" -> checkPattern(timingMetricPattern), + "peak memory total (min, med, max)" -> checkPattern(sizeMetricPattern), + "spill size total (min, med, max)" -> checkPattern(sizeMetricPattern)))) + )) } test("SortMergeJoin metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 719cf71df8b9..249807dc7922 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -41,11 +41,6 @@ trait SQLMetricsTestUtils extends SQLTestUtils { protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore - protected object ExpectedMetricsType extends Enumeration { - type ExpectedMetricsType = Value - val VALUE, PATTERN, PATTERN_STRING = Value - } - protected val bytes = "([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)" protected val duration = "([0-9]+(\\.[0-9]+)?) (ms|s|m|h)" @@ -57,6 +52,15 @@ trait SQLMetricsTestUtils extends SQLTestUtils { protected val timingMetricPattern = Pattern.compile(s"\\n$duration \\($duration, $duration, $duration\\)") + /** Generate a function to check the specified pattern. + * + * @param pattern a pattern + * @return a function to check the specified pattern + */ + protected def checkPattern(pattern: Pattern): (Any => Boolean) = { + (in: Any) => pattern.matcher(in.toString).matches() + } + /** * Get execution metrics for the SQL execution and verify metrics values. * @@ -191,21 +195,17 @@ trait SQLMetricsTestUtils extends SQLTestUtils { } /** - * Call `df.collect()` and verify if the collected metrics match "expectedMetrics". By - * "expectedMetrics", you can either specify exact metric value or metric value pattern. A pattern - * can be a regex string or a compiled `Pattern` object. + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". * * @param df `DataFrame` to run * @param expectedNumOfJobs number of jobs that will run * @param expectedMetrics the expected metrics. The format is - * `nodeId -> (operatorName, metric name -> metric value or pattern)`. - * @param expectedMetricsType the type of the expected metrics. + * `nodeId -> (operatorName, metric name -> metric value)`. 
*/ protected def testSparkPlanMetrics( df: DataFrame, expectedNumOfJobs: Int, - expectedMetrics: Map[Long, (String, Map[String, Any])], - expectedMetricsType: ExpectedMetricsType.Value): Unit = { + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet) optActualMetrics.foreach { actualMetrics => assert(expectedMetrics.keySet === actualMetrics.keySet) @@ -214,34 +214,36 @@ trait SQLMetricsTestUtils extends SQLTestUtils { val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) assert(expectedNodeName === actualNodeName) for (metricName <- expectedMetricsMap.keySet) { - expectedMetricsType match { - case ExpectedMetricsType.VALUE => - assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) - case ExpectedMetricsType.PATTERN => - assert(expectedMetricsMap(metricName).asInstanceOf[Pattern].matcher( - actualMetricsMap(metricName).toString).matches()) - case ExpectedMetricsType.PATTERN_STRING => - assert(Pattern.compile(expectedMetricsMap(metricName).toString).matcher( - actualMetricsMap(metricName).toString).matches()) - } + assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) } } } } /** - * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". - * + * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates. * @param df `DataFrame` to run * @param expectedNumOfJobs number of jobs that will run - * @param expectedMetrics the expected metrics. The format is - * `nodeId -> (operatorName, metric name -> metric value)`. + * @param expectedMetricsPredicates the expected metrics predicates. The format is + * `nodeId -> (operatorName, metric name -> metric value predicate)`. 
*/ - protected def testSparkPlanMetrics( + protected def testSparkPlanMetricsWithPredicates( df: DataFrame, expectedNumOfJobs: Int, - expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - testSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics, ExpectedMetricsType.VALUE) + expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])]): Unit = { + val optActualMetrics = + getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet) + optActualMetrics.foreach { actualMetrics => + assert(expectedMetricsPredicates.keySet === actualMetrics.keySet) + for (nodeId <- expectedMetricsPredicates.keySet) { + val (expectedNodeName, expectedMetricsPredicatesMap) = expectedMetricsPredicates(nodeId) + val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) + assert(expectedNodeName === actualNodeName) + for (metricName <- expectedMetricsPredicatesMap.keySet) { + assert(expectedMetricsPredicatesMap(metricName)(actualMetricsMap(metricName))) + } + } + } } } From 3ce0e0373f431384b60d257f07d29257d8326488 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Mon, 31 Dec 2018 00:22:30 +0800 Subject: [PATCH 10/11] fix according to comments --- .../execution/metric/SQLMetricsSuite.scala | 6 +-- .../metric/SQLMetricsTestUtils.scala | 48 +++++++------------ 2 files changed, 19 insertions(+), 35 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 3d4a2de5d4ec..bd7e57197d12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -204,9 +204,9 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df = Seq(1, 3, 2).toDF("id").sort('id) testSparkPlanMetricsWithPredicates(df, 2, Map( 0L -> (("Sort", Map( - "sort time total (min, med, max)" -> checkPattern(timingMetricPattern), - "peak memory total (min, med, max)" -> checkPattern(sizeMetricPattern), - "spill size total (min, med, max)" -> checkPattern(sizeMetricPattern)))) + "sort time total (min, med, max)" -> {_.toString.matches(timingMetricPattern)}, + "peak memory total (min, med, max)" -> {_.toString.matches(sizeMetricPattern)}, + "spill size total (min, med, max)" -> {_.toString.matches(sizeMetricPattern)}))) )) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 249807dc7922..2318f4086b2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.metric import java.io.File -import java.util.regex.Pattern import scala.collection.mutable.HashMap @@ -41,24 +40,16 @@ trait SQLMetricsTestUtils extends SQLTestUtils { protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore - protected val bytes = "([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)" - - protected val duration = "([0-9]+(\\.[0-9]+)?) 
(ms|s|m|h)" - - // "\n96.2 MiB (32.1 MiB, 32.1 MiB, 32.1 MiB)" - protected val sizeMetricPattern = Pattern.compile(s"\\n$bytes \\($bytes, $bytes, $bytes\\)") - - // "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)" - protected val timingMetricPattern = - Pattern.compile(s"\\n$duration \\($duration, $duration, $duration\\)") + // Pattern of size SQLMetric value, e.g. "\n96.2 MiB (32.1 MiB, 32.1 MiB, 32.1 MiB)" + protected val sizeMetricPattern = { + val bytes = "([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)" + s"\\n$bytes \\($bytes, $bytes, $bytes\\)" + } - /** Generate a function to check the specified pattern. - * - * @param pattern a pattern - * @return a function to check the specified pattern - */ - protected def checkPattern(pattern: Pattern): (Any => Boolean) = { - (in: Any) => pattern.matcher(in.toString).matches() + // Pattern of timing SQLMetric value, e.g. "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)" + protected val timingMetricPattern = { + val duration = "([0-9]+(\\.[0-9]+)?) (ms|s|m|h)" + s"\\n$duration \\($duration, $duration, $duration\\)" } /** @@ -206,18 +197,11 @@ trait SQLMetricsTestUtils extends SQLTestUtils { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet) - optActualMetrics.foreach { actualMetrics => - assert(expectedMetrics.keySet === actualMetrics.keySet) - for (nodeId <- expectedMetrics.keySet) { - val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) - val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) - assert(expectedNodeName === actualNodeName) - for (metricName <- expectedMetricsMap.keySet) { - assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) - } - } + val expectedMetricsPredicates = expectedMetrics.mapValues { case (nodeName, nodeMetrics) => + (nodeName, nodeMetrics.mapValues(expectedMetricValue => + (actualMetricValue: Any) => expectedMetricValue.toString === actualMetricValue)) } + testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates) } /** @@ -225,7 +209,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { * @param df `DataFrame` to run * @param expectedNumOfJobs number of jobs that will run * @param expectedMetricsPredicates the expected metrics predicates. The format is - * `nodeId -> (operatorName, metric name -> metric value predicate)`. + * `nodeId -> (operatorName, metric name -> metric predicate)`. 
*/ protected def testSparkPlanMetricsWithPredicates( df: DataFrame, @@ -235,8 +219,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils { getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet) optActualMetrics.foreach { actualMetrics => assert(expectedMetricsPredicates.keySet === actualMetrics.keySet) - for (nodeId <- expectedMetricsPredicates.keySet) { - val (expectedNodeName, expectedMetricsPredicatesMap) = expectedMetricsPredicates(nodeId) + for ((nodeId, (expectedNodeName, expectedMetricsPredicatesMap)) + <- expectedMetricsPredicates) { val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) assert(expectedNodeName === actualNodeName) for (metricName <- expectedMetricsPredicatesMap.keySet) { From 5e94a3eb8bb515ee9638c7767882ddd559d7b6ac Mon Sep 17 00:00:00 2001 From: seancxmao Date: Mon, 31 Dec 2018 06:41:56 +0800 Subject: [PATCH 11/11] update according to comments --- .../apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 2 +- .../spark/sql/execution/metric/SQLMetricsTestUtils.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index bd7e57197d12..7368a6c9e1d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.{FilterExec, RangeExec, SortExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 2318f4086b2d..2d245d2ba1e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -223,8 +223,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils { <- expectedMetricsPredicates) { val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) assert(expectedNodeName === actualNodeName) - for (metricName <- expectedMetricsPredicatesMap.keySet) { - assert(expectedMetricsPredicatesMap(metricName)(actualMetricsMap(metricName))) + for ((metricName, metricPredicate) <- expectedMetricsPredicatesMap) { + assert(metricPredicate(actualMetricsMap(metricName))) } } }
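
The predicate-based check that this series finally settles on works purely on the rendered metric strings: testSparkPlanMetricsWithPredicates takes a map of the shape `nodeId -> (operatorName, metric name -> metric predicate)`, and the "Sort metrics" test supplies predicates that match each value against timingMetricPattern or sizeMetricPattern instead of asserting an exact number, since a fast in-memory sort can legitimately report 0 ms of sort time. The standalone Scala sketch below re-declares the two pattern strings exactly as the last patch leaves them in SQLMetricsTestUtils and runs such a predicate map against hand-written sample metric strings; the object name, the main method, and the values in sampleSortMetrics are illustrative assumptions, not output captured from a real Spark run.

object MetricPredicateSketch {
  // Pattern strings as they stand after patch 11 in SQLMetricsTestUtils.
  // Size SQLMetric value, e.g. "\n96.2 MiB (32.1 MiB, 32.1 MiB, 32.1 MiB)"
  val sizeMetricPattern: String = {
    val bytes = "([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)"
    s"\\n$bytes \\($bytes, $bytes, $bytes\\)"
  }

  // Timing SQLMetric value, e.g. "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)"
  val timingMetricPattern: String = {
    val duration = "([0-9]+(\\.[0-9]+)?) (ms|s|m|h)"
    s"\\n$duration \\($duration, $duration, $duration\\)"
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical rendered metrics for the Sort node (nodeId = 0), standing in for what
    // getSparkPlanMetrics would return for the query in the "Sort metrics" test.
    val sampleSortMetrics: Map[String, Any] = Map(
      "sort time total (min, med, max)" -> "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)",
      "peak memory total (min, med, max)" -> "\n96.2 MiB (32.1 MiB, 32.1 MiB, 32.1 MiB)",
      "spill size total (min, med, max)" -> "\n0.0 B (0.0 B, 0.0 B, 0.0 B)")

    // Predicate map in the shape testSparkPlanMetricsWithPredicates expects:
    // nodeId -> (operatorName, metric name -> metric predicate).
    val expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])] = Map(
      0L -> (("Sort", Map(
        "sort time total (min, med, max)" ->
          ((v: Any) => v.toString.matches(timingMetricPattern)),
        "peak memory total (min, med, max)" ->
          ((v: Any) => v.toString.matches(sizeMetricPattern)),
        "spill size total (min, med, max)" ->
          ((v: Any) => v.toString.matches(sizeMetricPattern))))))

    // Same checking loop shape as testSparkPlanMetricsWithPredicates, applied to the samples.
    val (expectedNodeName, predicates) = expectedMetricsPredicates(0L)
    assert(expectedNodeName == "Sort")
    for ((metricName, metricPredicate) <- predicates) {
      assert(metricPredicate(sampleSortMetrics(metricName)), s"$metricName did not match")
    }
    println("all sample Sort metrics matched their patterns")
  }
}

Keeping the expectations as shape-checking predicates rather than literal values is what lets the suite assert something meaningful about sort time, peak memory, and spill size without depending on machine speed or memory layout.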