Commit d4f204c

[SPARK-23942][PYTHON][SQL][BRANCH-2.3] Makes collect in PySpark as action for a query executor listener
## What changes were proposed in this pull request?

This PR proposes to make `collect` in PySpark notify the query execution listener as an action. Currently, `collect`, and `collect` with Arrow (`toPandas`), are not recognised as actions by `QueryExecutionListener`. For example, given a custom listener as below:

```scala
package org.apache.spark.sql

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener


class TestQueryExecutionListener extends QueryExecutionListener with Logging {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    logError("Look at me! I'm 'onSuccess'")
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
}
```

with `spark.sql.queryExecutionListeners` set to `org.apache.spark.sql.TestQueryExecutionListener`, other operations on the PySpark and Scala sides work as expected:

```python
>>> sql("SELECT * FROM range(1)").show()
```
```
18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
+---+
| id|
+---+
|  0|
+---+
```

```scala
scala> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
res1: Array[org.apache.spark.sql.Row] = Array([0])
```

but `collect` and `toPandas` in PySpark do not trigger the listener.

**Before**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
   id
0   0
```

**After**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
   id
0   0
```

## How was this patch tested?

Manually tested as described above, and a unit test was added.

Author: hyukjinkwon <[email protected]>

Closes #21060 from HyukjinKwon/PR_TOOL_PICK_PR_21007_BRANCH-2.3.
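For reference, a listener like the one above has to be registered when the session is built, because `spark.sql.queryExecutionListeners` is a static configuration that cannot be changed on a live session. The sketch below mirrors how the new PySpark test wires it up, here from the Scala side; the `local[4]` master is only illustrative, and a `--conf` flag on `spark-submit` works equally well:

```scala
import org.apache.spark.sql.SparkSession

// Listener classes are instantiated once, when the SparkSession is created;
// this config cannot be changed afterwards via spark.conf.set.
val spark = SparkSession.builder()
  .master("local[4]")  // illustrative; any deployment mode works
  .config("spark.sql.queryExecutionListeners",
    "org.apache.spark.sql.TestQueryExecutionListener")
  .getOrCreate()

// Alternatively, an instance can be attached to an existing session at runtime:
// spark.listenerManager.register(new TestQueryExecutionListener)
```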
1 parent dfdf1bb commit d4f204c

3 files changed: +140, -23 lines


python/pyspark/sql/tests.py

Lines changed: 83 additions & 16 deletions
```diff
@@ -185,22 +185,12 @@ def __init__(self, key, value):
         self.value = value


-class ReusedSQLTestCase(ReusedPySparkTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
-
-    def assertPandasEqual(self, expected, result):
-        msg = ("DataFrames are not equal: " +
-               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
-               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
-        self.assertTrue(expected.equals(result), msg=msg)
+class SQLTestUtils(object):
+    """
+    This util assumes the instance of this to have 'spark' attribute, having a spark session.
+    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
+    the implementation of this class has 'spark' attribute.
+    """

     @contextmanager
     def sql_conf(self, pairs):
@@ -209,6 +199,7 @@ def sql_conf(self, pairs):
         `value` to the configuration `key` and then restores it back when it exits.
         """
         assert isinstance(pairs, dict), "pairs should be a dictionary."
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

         keys = pairs.keys()
         new_values = pairs.values()
@@ -225,6 +216,24 @@ def sql_conf(self, pairs):
                     self.spark.conf.set(key, old_value)


+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+    def assertPandasEqual(self, expected, result):
+        msg = ("DataFrames are not equal: " +
+               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
+               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
+        self.assertTrue(expected.equals(result), msg=msg)
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
@@ -2980,6 +2989,64 @@ def test_sparksession_with_stopped_sparkcontext(self):
         sc.stop()


+class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
+    # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
+    # static and immutable. This can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "TestQueryExecutionListener.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.queryExecutionListeners",
+                "org.apache.spark.sql.TestQueryExecutionListener") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def tearDown(self):
+        self.spark._jvm.OnSuccessCall.clear()
+
+    def test_query_execution_listener_on_collect(self):
+        self.assertFalse(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should not be called before 'collect'")
+        self.spark.sql("SELECT * FROM range(1)").collect()
+        self.assertTrue(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should be called after 'collect'")
+
+    @unittest.skipIf(
+        not _have_pandas or not _have_pyarrow,
+        _pandas_requirement_message or _pyarrow_requirement_message)
+    def test_query_execution_listener_on_collect_with_arrow(self):
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
+            self.assertFalse(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should not be "
+                "called before 'toPandas'")
+            self.spark.sql("SELECT * FROM range(1)").toPandas()
+            self.assertTrue(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should be called after 'toPandas'")
+
+
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
         if SparkSession._instantiatedSession is not None:
```

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 13 additions & 7 deletions
```diff
@@ -3189,10 +3189,10 @@ class Dataset[T] private[sql](

   private[sql] def collectToPython(): Int = {
     EvaluatePython.registerPicklers()
-    withNewExecutionId {
+    withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter = new SerDeUtil.AutoBatchedPickler(
-        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+        plan.executeCollect().iterator.map(toJava))
       PythonRDD.serveIterator(iter, "serve-DataFrame")
     }
   }
@@ -3201,8 +3201,9 @@ class Dataset[T] private[sql](
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Int = {
-    withNewExecutionId {
-      val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
+    withAction("collectAsArrowToPython", queryExecution) { plan =>
+      val iter: Iterator[Array[Byte]] =
+        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
       PythonRDD.serveIterator(iter, "serve-Arrow")
     }
   }
@@ -3311,14 +3312,19 @@ class Dataset[T] private[sql](
   }

   /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
-    queryExecution.toRdd.mapPartitionsInternal { iter =>
+    plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
       ArrowConverters.toPayloadIterator(
         iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
     }
   }
+
+  // This is only used in tests, for now.
+  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+    toArrowPayload(queryExecution.executedPlan)
+  }
 }
```
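The key change above is that `collectToPython` and `collectAsArrowToPython` now go through `withAction`, the same wrapper the Scala-side actions such as `collect` already use, instead of a bare `withNewExecutionId`. `withAction` is what notifies the session's `ExecutionListenerManager`, and therefore any registered `QueryExecutionListener`, on success or failure. A simplified sketch of its shape, paraphrased rather than copied from the Spark 2.3 sources (metric resetting and other details are omitted):

```scala
// Paraphrased sketch of Dataset.withAction; it is a private member of Dataset[T],
// so `sparkSession` below is the Dataset's own session. Not the verbatim implementation.
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U): U = {
  try {
    val start = System.nanoTime()
    // Run the physical plan under a tracked execution id, as before.
    val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
      action(qe.executedPlan)
    }
    val end = System.nanoTime()
    // This call is what reaches QueryExecutionListener.onSuccess.
    sparkSession.listenerManager.onSuccess(name, qe, end - start)
    result
  } catch {
    case e: Exception =>
      sparkSession.listenerManager.onFailure(name, qe, e)
      throw e
  }
}
```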
sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala

Lines changed: 44 additions & 0 deletions
```diff
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
+
+class TestQueryExecutionListener extends QueryExecutionListener {
+  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+    OnSuccessCall.isOnSuccessCalled.set(true)
+  }
+
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
+}
+
+/**
+ * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for
+ * the test case in PySpark. See SPARK-23942.
+ */
+object OnSuccessCall {
+  val isOnSuccessCalled = new AtomicBoolean(false)
+
+  def isCalled(): Boolean = isOnSuccessCalled.get()
+
+  def clear(): Unit = isOnSuccessCalled.set(false)
+}
```
