From 87840b81e9a4a2ff7f595376bd1d20f086c5b384 Mon Sep 17 00:00:00 2001
From: HyukjinKwon
Date: Tue, 27 Aug 2019 15:21:24 +0900
Subject: [PATCH] Add a test to make sure toPandas with Arrow optimization
 throws an exception per maxResultSize

---
 python/pyspark/sql/tests/test_arrow.py | 31 +++++++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index f5330835a3f22..50c82b0b5f88a 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -22,7 +22,7 @@
 import unittest
 import warnings
 
-from pyspark.sql import Row
+from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
@@ -421,6 +421,35 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
         run_test(*case)
 
 
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message)
+class MaxResultArrowTests(unittest.TestCase):
+    # These tests are separate as 'spark.driver.maxResultSize' configuration
+    # is a static configuration to Spark context.
+
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config("spark.driver.maxResultSize", "10k") \
+            .getOrCreate()
+
+        # Explicitly enable Arrow and disable fallback.
+        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
+        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
+
+    @classmethod
+    def tearDownClass(cls):
+        if hasattr(cls, "spark"):
+            cls.spark.stop()
+
+    def test_exception_by_max_results(self):
+        with self.assertRaisesRegexp(Exception, "is bigger than"):
+            self.spark.range(0, 10000, 1, 100).toPandas()
+
+
 class EncryptionArrowTests(ArrowTests):
 
     @classmethod