Commit 5c47924

[SPARK-27612][PYTHON] Use Python's default protocol instead of highest protocol
## What changes were proposed in this pull request?

This PR partially reverts #20691. After we changed the Python pickle protocol to the highest one, it seems to have introduced a correctness bug that potentially affects all Python-related code paths. I suspect a bug related to Pyrolite (maybe the opcodes `MEMOIZE`, `FRAME`, and/or our `RowPickler`). I would like to stick to the default protocol for now and investigate the issue separately; I will look into bringing the highest protocol back later.

## How was this patch tested?

A unit test was added:

```bash
./run-tests --python-executables=python3.7 --testname "pyspark.sql.tests.test_serde SerdeTests.test_int_array_serialization"
```

Closes #24519 from HyukjinKwon/SPARK-27612.

Authored-by: HyukjinKwon <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
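As background for the `MEMOIZE`/`FRAME` suspicion above, a quick stdlib sketch (not part of the patch): disassembling pickle streams with `pickletools` shows that those two opcodes only appear at protocol 4 and higher, which is why pinning the protocol back down avoids them entirely.

```python
import pickle
import pickletools

def opcode_names(obj, protocol):
    # Disassemble the pickle stream and collect the names of the opcodes used.
    data = pickle.dumps(obj, protocol=protocol)
    return {op.name for op, _arg, _pos in pickletools.genops(data)}

ops_p2 = opcode_names([1, 2, 3, 4], 2)
ops_p4 = opcode_names([1, 2, 3, 4], 4)
print("FRAME" in ops_p2, "MEMOIZE" in ops_p2)  # False False
print("FRAME" in ops_p4, "MEMOIZE" in ops_p4)  # True True
```

On Python 3 before 3.8, `pickle.HIGHEST_PROTOCOL` is 4, so the revert from `HIGHEST_PROTOCOL` to 3 drops exactly these opcodes from the streams Pyrolite has to parse.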
1 parent 3859ca3 commit 5c47924

File tree

2 files changed: +8 −1 lines


python/pyspark/serializers.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -62,11 +62,12 @@
 if sys.version < '3':
     import cPickle as pickle
     from itertools import izip as zip, imap as map
+    pickle_protocol = 2
 else:
     import pickle
     basestring = unicode = str
     xrange = range
-    pickle_protocol = pickle.HIGHEST_PROTOCOL
+    pickle_protocol = 3
 
 from pyspark import cloudpickle
 from pyspark.util import _exception_message
```
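A minimal sketch (plain CPython 3, not part of the patch) of what pinning `pickle_protocol = 3` means on the wire: a pickle stream opens with the `PROTO` opcode followed by the protocol number, and `pickle.loads` reads back any supported protocol regardless of what was pinned for writing.

```python
import pickle

pickle_protocol = 3  # pinned, mirroring the change above

data = pickle.dumps([1, 2, 3, 4], protocol=pickle_protocol)
# The stream starts with the PROTO opcode byte, then the protocol number.
assert data[0:1] == pickle.PROTO
assert data[1] == 3
# Reading back does not depend on the pinned protocol.
assert pickle.loads(data) == [1, 2, 3, 4]
```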

python/pyspark/sql/tests/test_serde.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -126,6 +126,12 @@ def test_BinaryType_serialization(self):
         df = self.spark.createDataFrame(data, schema=schema)
         df.collect()
 
+    def test_int_array_serialization(self):
+        # Note that this test seems dependent on parallelism.
+        data = self.spark.sparkContext.parallelize([[1, 2, 3, 4]] * 100, numSlices=12)
+        df = self.spark.createDataFrame(data, "array<integer>")
+        self.assertEqual(len(list(filter(lambda r: None in r.value, df.collect()))), 0)
+
 
 if __name__ == "__main__":
     import unittest
```
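Read in isolation, the new test's assertion counts collected rows whose `value` array leaked a `None` after the serde round trip. A standalone sketch of that filter pattern, with a hypothetical `rows` list standing in for `df.collect()`:

```python
from collections import namedtuple

# Hypothetical stand-in for a pyspark Row with an array column named "value".
Row = namedtuple("Row", ["value"])

# Stand-in for df.collect(); a correct round trip yields no None elements.
rows = [Row(value=[1, 2, 3, 4]) for _ in range(100)]
corrupted = len(list(filter(lambda r: None in r.value, rows)))
assert corrupted == 0
```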
