diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index f5330835a3f22..50c82b0b5f88a 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -22,7 +22,7 @@
 import unittest
 import warnings
 
-from pyspark.sql import Row
+from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
@@ -421,6 +421,35 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
             run_test(*case)
 
 
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message)
+class MaxResultArrowTests(unittest.TestCase):
+    # These tests are separate because 'spark.driver.maxResultSize' is a static
+    # configuration that must be set before the SparkContext is created.
+
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config("spark.driver.maxResultSize", "10k") \
+            .getOrCreate()
+
+        # Explicitly enable Arrow and disable fallback.
+        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
+        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
+
+    @classmethod
+    def tearDownClass(cls):
+        if hasattr(cls, "spark"):
+            cls.spark.stop()
+
+    def test_exception_by_max_results(self):
+        with self.assertRaisesRegexp(Exception, "is bigger than"):
+            self.spark.range(0, 10000, 1, 100).toPandas()
+
+
 class EncryptionArrowTests(ArrowTests):
 
     @classmethod
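
For reviewers who want to see the failure mode this test pins down outside the test harness, here is a minimal standalone sketch. It assumes a local PySpark installation (3.0+, where the `spark.sql.execution.arrow.pyspark.*` config names used in the diff exist) with pandas and pyarrow available; the app name and the try/except reporting are illustrative and not part of the patch.

```python
# Minimal sketch of the condition test_exception_by_max_results asserts on:
# collecting more serialized Arrow results than 'spark.driver.maxResultSize'
# allows. Assumes a local PySpark 3.0+ install with pandas and pyarrow.
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .master("local[4]") \
    .appName("max-result-size-demo") \
    .config("spark.driver.maxResultSize", "10k") \
    .getOrCreate()

# Mirror the test setup: force the Arrow path and disable the non-Arrow
# fallback so we know which collection path produced the error.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")

try:
    # 10,000 rows spread over 100 partitions easily exceeds the 10k limit,
    # so the driver aborts the collect with an "... is bigger than ..." error.
    spark.range(0, 10000, 1, 100).toPandas()
except Exception as e:
    print("toPandas() failed as expected:", e)
finally:
    spark.stop()
```

The fallback is disabled so the test is guaranteed to exercise the Arrow collection path rather than the row-based one, and the session is built directly from `SparkSession.builder` instead of `ReusedSQLTestCase` because `spark.driver.maxResultSize` cannot be changed on an already-running context.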