
Commit 1e55f31

fix according to comments
1 parent 4ee2c8d commit 1e55f31

2 files changed: +10 -6 lines changed


sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 4 additions & 3 deletions
@@ -203,17 +203,18 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     // so Project here is not collapsed into LocalTableScan.
     val df = Seq(1, 3, 2).toDF("id").sort('id)
     val metrics = getSparkPlanMetrics(df, 2, Set(0))
+    assert(metrics.isDefined)
     val sortMetrics = metrics.get.get(0).get
     // Check node 0 is Sort node
     val operatorName = sortMetrics._1
     assert(operatorName == "Sort")
     // Check metrics values
     val sortTimeStr = sortMetrics._2.get("sort time total (min, med, max)").get.toString
-    timingMetricStats(sortTimeStr).foreach { case (sortTime, _) => assert(sortTime >= 0) }
+    assert(timingMetricStats(sortTimeStr).forall { case (sortTime, _) => sortTime >= 0 })
     val peakMemoryStr = sortMetrics._2.get("peak memory total (min, med, max)").get.toString
-    sizeMetricStats(peakMemoryStr).foreach { case (peakMemory, _) => assert(peakMemory > 0) }
+    assert(sizeMetricStats(peakMemoryStr).forall { case (peakMemory, _) => peakMemory > 0 })
     val spillSizeStr = sortMetrics._2.get("spill size total (min, med, max)").get.toString
-    sizeMetricStats(spillSizeStr).foreach { case (spillSize, _) => assert(spillSize >= 0) }
+    assert(sizeMetricStats(spillSizeStr).forall { case (spillSize, _) => spillSize >= 0 })
   }
 
   test("SortMergeJoin metrics") {
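The change in this test replaces per-element assertions inside foreach with a single assertion over the parsed stats via forall. A minimal, self-contained sketch of the same idea follows; parseTimings is a hypothetical stand-in for timingMetricStats (whose implementation is not part of this diff), and the sample metric string is invented for illustration.

object ForallAssertSketch {
  // Hypothetical stand-in for timingMetricStats: pull every "<number> <unit>"
  // pair out of a metric string, using the same duration regex as the diff.
  def parseTimings(metricStr: String): Seq[(Float, String)] =
    "([0-9]+(\\.[0-9]+)?) (ms|s|m|h)".r
      .findAllMatchIn(metricStr)
      .map(m => (m.group(1).toFloat, m.group(3)))
      .toSeq

  def main(args: Array[String]): Unit = {
    val sortTimeStr = "3 ms (1 ms, 1 ms, 2 ms)"  // invented example value

    // Before: one assert per parsed element, executed inside foreach.
    parseTimings(sortTimeStr).foreach { case (sortTime, _) => assert(sortTime >= 0) }

    // After: a single assertion over the whole collection.
    assert(parseTimings(sortTimeStr).forall { case (sortTime, _) => sortTime >= 0 })
  }
}

Either form fails when a parsed value is negative; the forall version simply expresses the check as one assertion over the whole sequence.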

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala

Lines changed: 6 additions & 3 deletions
@@ -41,6 +41,10 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
 
   protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore
 
+  protected val bytesPattern = Pattern.compile("([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)")
+
+  protected val durationPattern = Pattern.compile("([0-9]+(\\.[0-9]+)?) (ms|s|m|h)")
+
   /**
    * Get execution metrics for the SQL execution and verify metrics values.
    *
@@ -208,8 +212,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
   }
 
   private def stringToBytes(str: String): (Float, String) = {
-    val matcher =
-      Pattern.compile("([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)").matcher(str)
+    val matcher = bytesPattern.matcher(str)
     if (matcher.matches()) {
       (matcher.group(1).toFloat, matcher.group(3))
     } else {
@@ -218,7 +221,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
   }
 
   private def stringToDuration(str: String): (Float, String) = {
-    val matcher = Pattern.compile("([0-9]+(\\.[0-9]+)?) (ms|s|m|h)").matcher(str)
+    val matcher = durationPattern.matcher(str)
     if (matcher.matches()) {
       (matcher.group(1).toFloat, matcher.group(3))
     } else {
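Taken together, the change in this file hoists the two regexes into shared, pre-compiled Pattern vals and has stringToBytes/stringToDuration reuse them through matcher(str). A minimal, self-contained sketch under that reading (outside the Spark test trait) follows; the object name and the exception thrown in the truncated else-branches are assumptions, not code from this commit.

import java.util.regex.Pattern

object MetricStringParsers {
  // Compiled once and reused, mirroring the bytesPattern/durationPattern vals in the diff.
  private val bytesPattern =
    Pattern.compile("([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)")
  private val durationPattern =
    Pattern.compile("([0-9]+(\\.[0-9]+)?) (ms|s|m|h)")

  // Parse a size string such as "2.5 MiB" into (2.5f, "MiB").
  def stringToBytes(str: String): (Float, String) = {
    val matcher = bytesPattern.matcher(str)
    if (matcher.matches()) {
      (matcher.group(1).toFloat, matcher.group(3))
    } else {
      // Assumed failure behavior; the original else-branch is cut off in the diff.
      throw new IllegalArgumentException(s"Cannot parse size string: $str")
    }
  }

  // Parse a duration string such as "15 ms" into (15.0f, "ms").
  def stringToDuration(str: String): (Float, String) = {
    val matcher = durationPattern.matcher(str)
    if (matcher.matches()) {
      (matcher.group(1).toFloat, matcher.group(3))
    } else {
      // Assumed failure behavior; the original else-branch is cut off in the diff.
      throw new IllegalArgumentException(s"Cannot parse duration string: $str")
    }
  }
}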
