Commit d905e85

tashoyan authored and Marcelo Vanzin committed
[SPARK-22471][SQL] SQLListener consumes much memory causing OutOfMemoryError
## What changes were proposed in this pull request?

This PR addresses the issue [SPARK-22471](https://issues.apache.org/jira/browse/SPARK-22471). The modified version of `SQLListener` respects the setting `spark.ui.retainedStages` and keeps the number of tracked stages within the specified limit. The hash map `_stageIdToStageMetrics` no longer outgrows that limit, so overall memory consumption no longer grows over time.

A 2.2-compatible fix. It may be incompatible with 2.3 due to #19681.

## How was this patch tested?

A new unit test covers this fix - see `SQLListenerMemorySuite.scala`.

Author: Arseniy Tashoyan <[email protected]>

Closes #19711 from tashoyan/SPARK-22471-branch-2.2.
1 parent af0b185 commit d905e85
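As background for how the new limit is meant to be used: both retention settings read by `SQLListener` are plain Spark configuration keys, so an application can tighten them on its `SparkConf`. A minimal sketch, assuming a local run; the app name and the value `200` are arbitrary examples, while the two property keys come straight from this patch:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object BoundedUiMemoryExample {
  def main(args: Array[String]): Unit = {
    // Illustrative values only: the patch reads "spark.ui.retainedStages"
    // (default SparkUI.DEFAULT_RETAINED_STAGES) alongside the existing
    // "spark.sql.ui.retainedExecutions" (default 1000).
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("bounded-ui-memory") // hypothetical app name
      .set("spark.sql.ui.retainedExecutions", "200") // cap SQL executions kept in memory
      .set("spark.ui.retainedStages", "200")         // cap stages tracked by SQLListener after this fix

    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      spark.range(10).count() // any query; the listener maps now stay within the caps above
    } finally {
      spark.stop()
    }
  }
}
```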

3 files changed: +119 -47 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala

Lines changed: 12 additions & 1 deletion
@@ -101,6 +101,9 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging {
 
   private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000)
 
+  private val retainedStages = conf.getInt("spark.ui.retainedStages",
+    SparkUI.DEFAULT_RETAINED_STAGES)
+
   private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]()
 
   // Old data in the following fields must be removed in "trimExecutionsIfNecessary".
@@ -113,7 +116,7 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging {
    */
  private val _jobIdToExecutionId = mutable.HashMap[Long, Long]()
 
-  private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]()
+  private val _stageIdToStageMetrics = mutable.LinkedHashMap[Long, SQLStageMetrics]()
 
   private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]()
 
@@ -207,6 +210,14 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging {
     }
   }
 
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized {
+    val extraStages = _stageIdToStageMetrics.size - retainedStages
+    if (extraStages > 0) {
+      val toRemove = _stageIdToStageMetrics.take(extraStages).keys
+      _stageIdToStageMetrics --= toRemove
+    }
+  }
+
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
     if (taskEnd.taskMetrics != null) {
       updateTaskAccumulatorValues(
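The trimming above works because `mutable.LinkedHashMap` preserves insertion order, so `take(extraStages)` returns the oldest tracked stages; a plain `HashMap` has no well-defined "oldest" entry to evict, which is why the map type changes. A self-contained sketch of the same eviction pattern, where `StageMetrics` and the values are illustrative stand-ins rather than the real `SQLStageMetrics` class:

```scala
import scala.collection.mutable

// Illustrative stand-in for SQLStageMetrics; the real class tracks task accumulator values.
case class StageMetrics(stageId: Long)

object EvictionSketch {
  def main(args: Array[String]): Unit = {
    val retainedStages = 3
    val stageMetrics = mutable.LinkedHashMap[Long, StageMetrics]()

    // Track five stages; LinkedHashMap preserves insertion order.
    (1L to 5L).foreach(id => stageMetrics(id) = StageMetrics(id))

    // Same trimming step as the new onStageCompleted handler:
    // drop the oldest entries once the map exceeds the limit.
    val extraStages = stageMetrics.size - retainedStages
    if (extraStages > 0) {
      val toRemove = stageMetrics.take(extraStages).keys
      stageMetrics --= toRemove
    }

    assert(stageMetrics.keys.toSeq == Seq(3L, 4L, 5L)) // stages 1 and 2 were evicted
  }
}
```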
sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerMemorySuite.scala

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.ui
+
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config
+import org.apache.spark.sql.{Column, SparkSession}
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.functions._
+
+class SQLListenerMemorySuite extends SparkFunSuite {
+
+  test("SPARK-22471 - _stageIdToStageMetrics grows too large on long executions") {
+    quietly {
+      val conf = new SparkConf()
+        .setMaster("local[*]")
+        .setAppName("MemoryLeakTest")
+        /* Don't retry the tasks to run this test quickly */
+        .set(config.MAX_TASK_FAILURES, 1)
+        .set("spark.ui.retainedStages", "50")
+      withSpark(new SparkContext(conf)) { sc =>
+        SparkSession.sqlListener.set(null)
+        val spark = new SparkSession(sc)
+        import spark.implicits._
+
+        val sample = List(
+          (1, 10),
+          (2, 20),
+          (3, 30)
+        ).toDF("id", "value")
+
+        /* Some complex computation with many stages. */
+        val joins = 1 to 100
+        val summedCol: Column = joins
+          .map(j => col(s"value$j"))
+          .reduce(_ + _)
+        val res = joins
+          .map { j =>
+            sample.select($"id", $"value" * j as s"value$j")
+          }
+          .reduce(_.join(_, "id"))
+          .select($"id", summedCol as "value")
+          .groupBy("id")
+          .agg(sum($"value") as "value")
+          .orderBy("id")
+        res.collect()
+
+        sc.listenerBus.waitUntilEmpty(10000)
+        assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 50)
+      }
+    }
+  }
+
+  test("no memory leak") {
+    quietly {
+      val conf = new SparkConf()
+        .setMaster("local")
+        .setAppName("test")
+        .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly
+        .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
+      withSpark(new SparkContext(conf)) { sc =>
+        SparkSession.sqlListener.set(null)
+        val spark = new SparkSession(sc)
+        import spark.implicits._
+        // Run 100 successful executions and 100 failed executions.
+        // Each execution only has one job and one stage.
+        for (i <- 0 until 100) {
+          val df = Seq(
+            (1, 1),
+            (2, 2)
+          ).toDF()
+          df.collect()
+          try {
+            df.foreach(_ => throw new RuntimeException("Oops"))
+          } catch {
+            case e: SparkException => // This is expected for a failed job
+          }
+        }
+        sc.listenerBus.waitUntilEmpty(10000)
+        assert(spark.sharedState.listener.getCompletedExecutions.size <= 50)
+        assert(spark.sharedState.listener.getFailedExecutions.size <= 50)
+        // 50 for successful executions and 50 for failed executions
+        assert(spark.sharedState.listener.executionIdToData.size <= 100)
+        assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100)
+        assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100)
+      }
+    }
+  }
+}

sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala

Lines changed: 1 addition & 46 deletions
@@ -24,14 +24,12 @@ import org.mockito.Mockito.mock
 
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.config
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler._
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.test.SharedSQLContext
@@ -485,46 +483,3 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe
     sc.emptyRDD
   }
 }
-
-
-class SQLListenerMemoryLeakSuite extends SparkFunSuite {
-
-  test("no memory leak") {
-    quietly {
-      val conf = new SparkConf()
-        .setMaster("local")
-        .setAppName("test")
-        .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly
-        .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
-      val sc = new SparkContext(conf)
-      try {
-        SparkSession.sqlListener.set(null)
-        val spark = new SparkSession(sc)
-        import spark.implicits._
-        // Run 100 successful executions and 100 failed executions.
-        // Each execution only has one job and one stage.
-        for (i <- 0 until 100) {
-          val df = Seq(
-            (1, 1),
-            (2, 2)
-          ).toDF()
-          df.collect()
-          try {
-            df.foreach(_ => throw new RuntimeException("Oops"))
-          } catch {
-            case e: SparkException => // This is expected for a failed job
-          }
-        }
-        sc.listenerBus.waitUntilEmpty(10000)
-        assert(spark.sharedState.listener.getCompletedExecutions.size <= 50)
-        assert(spark.sharedState.listener.getFailedExecutions.size <= 50)
-        // 50 for successful executions and 50 for failed executions
-        assert(spark.sharedState.listener.executionIdToData.size <= 100)
-        assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100)
-        assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100)
-      } finally {
-        sc.stop()
-      }
-    }
-  }
-}
