
Commit c4bb486

HyukjinKwon and BryanCutler authored and committed
[SPARK-27992][SPARK-28881][PYTHON][2.4] Allow Python to join with connection thread to propagate errors
### What changes were proposed in this pull request?

This PR proposes to backport #24834 with minimised changes, together with the tests added at #25594. #24834 was not backported before because it basically targeted a better exception by propagating the exception from the JVM. However, that change also fixed another problem accidentally (see #25594 and [SPARK-28881](https://issues.apache.org/jira/browse/SPARK-28881)).

This regression seems to have been introduced by #21546. The root cause is that, at https://github.com/apache/spark/blob/23bed0d3c08e03085d3f0c3a7d457eedd30bd67f/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L3370-L3384, `runJob` with a `resultHandler` can write partial output. The JVM throws an exception, but since that exception is not propagated to the Python process, Python cannot tell whether the JVM failed (it only sees the socket close), which results in the following:

```
./bin/pyspark --conf spark.driver.maxResultSize=1m
```

```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.range(10000000).toPandas()
```

```
Empty DataFrame
Columns: [id]
Index: []
```

With this change, the Python process catches exceptions thrown from the JVM.

### Why are the changes needed?

It returns incorrect data, and potentially returns partial results when an exception happens on the JVM side. This is a regression; the code works fine in Spark 2.3.3.

### Does this PR introduce any user-facing change?

Yes.

```
./bin/pyspark --conf spark.driver.maxResultSize=1m
```

```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.range(10000000).toPandas()
```

now throws an exception as expected:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../pyspark/sql/dataframe.py", line 2122, in toPandas
    batches = self._collectAsArrow()
  File "/.../pyspark/sql/dataframe.py", line 2184, in _collectAsArrow
    jsocket_auth_server.getResult()  # Join serving thread and raise any exceptions
  File "/.../lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/.../pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/.../lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o42.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult:
...
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Total size of serialized results of 1 tasks (6.5 MB) is bigger than spark.driver.maxResultSize (1024.0 KB)
```

### How was this patch tested?

Manually as described above; a unit test was also added.

Closes #25593 from HyukjinKwon/SPARK-27992.

Lead-authored-by: HyukjinKwon <[email protected]>
Co-authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 0d0686e commit c4bb486
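To illustrate the failure mode described in the commit message above, here is a minimal, self-contained Python sketch. It is not Spark code; the class and function names are invented for illustration. A consumer that only reads until the socket closes silently receives partial output, while joining the serving thread and re-raising its recorded exception — the pattern this backport adds via `getResult()` — surfaces the error instead:

```python
import socket
import threading


class ServingThread(threading.Thread):
    """Toy stand-in for the JVM-side serving thread: it writes rows to a
    socket and records any exception so the consumer can re-raise it."""

    def __init__(self, conn, fail_after):
        super(ServingThread, self).__init__()
        self.conn = conn
        self.fail_after = fail_after
        self.error = None

    def run(self):
        try:
            for i in range(10):
                if i == self.fail_after:
                    raise RuntimeError("simulated JVM failure (result too large)")
                self.conn.sendall(("%d\n" % i).encode("ascii"))
        except Exception as e:
            self.error = e
        finally:
            self.conn.close()  # all the consumer sees is the socket closing

    def get_result(self):
        """Analogue of jsocket_auth_server.getResult(): join and re-raise."""
        self.join()
        if self.error is not None:
            raise self.error


listener = socket.socket()
listener.bind(("127.0.0.1", 0))
listener.listen(1)
client = socket.create_connection(listener.getsockname())
conn, _ = listener.accept()

thread = ServingThread(conn, fail_after=3)
thread.start()

rows = client.makefile().read().splitlines()
print("rows received before the failure:", rows)  # partial output, no error so far
thread.get_result()  # joining the serving thread re-raises the server-side error
```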

File tree

4 files changed (+72, -5 lines)

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 37 additions & 0 deletions
```diff
@@ -440,6 +440,29 @@ private[spark] object PythonRDD extends Logging {
     Array(port, secret)
   }
 
+  /**
+   * Create a socket server object and background thread to execute the writeFunc
+   * with the given OutputStream.
+   *
+   * This is the same as serveToStream, only it returns a server object that
+   * can be used to sync in Python.
+   */
+  private[spark] def serveToStreamWithSync(
+      threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+
+    val handleFunc = (sock: Socket) => {
+      val out = new BufferedOutputStream(sock.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
+      }
+    }
+
+    val server = new SocketFuncServer(authHelper, threadName, handleFunc)
+    Array(server.port, server.secret, server)
+  }
+
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
       baseConf: Configuration): Configuration = {
     val conf = PythonHadoopUtil.mapToConf(confAsMap)
@@ -957,3 +980,17 @@ private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
   }
 }
 
+/**
+ * Create a socket server class and run user function on the socket in a background thread.
+ * This is the same as calling SocketAuthServer.setupOneConnectionServer except it creates
+ * a server object that can then be synced from Python.
+ */
+private[spark] class SocketFuncServer(
+    authHelper: SocketAuthHelper,
+    threadName: String,
+    func: Socket => Unit) extends PythonServer[Unit](authHelper, threadName) {
+
+  override def handleConnection(sock: Socket): Unit = {
+    func(sock)
+  }
+}
```

python/pyspark/sql/dataframe.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -2175,9 +2175,13 @@ def _collectAsArrow(self):
 
         .. note:: Experimental.
         """
-        with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+        with SCCallSiteSync(self._sc):
+            from pyspark.rdd import _load_from_socket
+            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
+            try:
+                return list(_load_from_socket((port, auth_secret), ArrowStreamSerializer()))
+            finally:
+                jsocket_auth_server.getResult()  # Join serving thread and raise any exceptions
 
     ##########################################################################################
     # Pandas compatibility
```
python/pyspark/sql/tests.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -80,7 +80,7 @@
 _have_pyarrow = _pyarrow_requirement_message is None
 _test_compiled = _test_not_compiled_message is None
 
-from pyspark import SparkContext
+from pyspark import SparkContext, SparkConf
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
@@ -4550,6 +4550,32 @@ def test_timestamp_dst(self):
         self.assertPandasEqual(pdf, df_from_pandas.toPandas())
 
 
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
+class MaxResultArrowTests(unittest.TestCase):
+    # These tests are separate as 'spark.driver.maxResultSize' configuration
+    # is a static configuration to Spark context.
+
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = SparkSession(SparkContext(
+            'local[4]', cls.__name__, conf=SparkConf().set("spark.driver.maxResultSize", "10k")))
+
+        # Explicitly enable Arrow and disable fallback.
+        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
+
+    @classmethod
+    def tearDownClass(cls):
+        if hasattr(cls, "spark"):
+            cls.spark.stop()
+
+    def test_exception_by_max_results(self):
+        with self.assertRaisesRegexp(Exception, "is bigger than"):
+            self.spark.range(0, 10000, 1, 100).toPandas()
+
+
 class EncryptionArrowTests(ArrowTests):
 
     @classmethod
```
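The new `MaxResultArrowTests` class is picked up by the project's regular `python/run-tests` script, but it can also be exercised in isolation with a small ad-hoc runner such as the sketch below (assumes a Spark 2.4 build with `pyspark`, `pandas`, and `pyarrow` importable):

```python
import unittest

from pyspark.sql.tests import MaxResultArrowTests

# Load and run only the new test class; it creates its own SparkContext
# with spark.driver.maxResultSize=10k in setUpClass.
suite = unittest.TestLoader().loadTestsFromTestCase(MaxResultArrowTests)
unittest.TextTestRunner(verbosity=2).run(suite)
```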

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -3284,7 +3284,7 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStreamWithSync("serve-Arrow") { out =>
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
         val arrowBatchRdd = toArrowBatchRdd(plan)
         val numPartitions = arrowBatchRdd.partitions.length
```
