From 8f6f47c254cbca5ef32b7e189db99a7517e772f9 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 7 Jun 2024 18:09:35 +0800 Subject: [PATCH 1/3] init --- python/pyspark/sql/connect/dataframe.py | 29 +++-- .../test_connect_dataframe_property.py | 100 ++++++++++++++++++ 2 files changed, 123 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index dae5f3d122ee8..39df6de2909e9 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1131,24 +1131,31 @@ def observe( def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: print(self._show_string(n, truncate, vertical)) + def _merge_cached_schema(self, other: ParentDataFrame) -> Optional[StructType]: + if self._cached_schema is not None and self._cached_schema == other._cached_schema: # type: ignore[arg-type] + return self.schema + return None + def union(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) return self.unionAll(other) def unionAll(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "union", is_all=True # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def unionByName( self, other: ParentDataFrame, allowMissingColumns: bool = False ) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, # type: ignore[arg-type] @@ -1158,42 +1165,52 @@ def unionByName( ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def subtract(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "except", is_all=False # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "except", is_all=True # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def intersect(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "intersect", is_all=False # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "intersect", is_all=True # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def where(self, condition: Union[Column, str]) -> ParentDataFrame: if not isinstance(condition, (str, Column)): diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py index c87c44760256e..4a7e1e1ea7606 100644 --- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py +++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py @@ -293,6 +293,106 @@ def summarize(left, right): self.assertEqual(cdf3.schema, sdf3.schema) self.assertEqual(cdf3.collect(), sdf3.collect()) + def test_cached_schema_set_op(self): + data1 = [(1, 2, 3)] + data2 = [(6, 2, 5)] + data3 = [(6, 2, 5.0)] + + cdf1 = self.connect.createDataFrame(data1, ["a", "b", "c"]) + sdf1 = self.spark.createDataFrame(data1, ["a", "b", "c"]) + cdf2 = self.connect.createDataFrame(data2, ["a", "b", "c"]) + sdf2 = self.spark.createDataFrame(data2, ["a", "b", "c"]) + cdf3 = self.connect.createDataFrame(data3, ["a", "b", "c"]) + sdf3 = self.spark.createDataFrame(data3, ["a", "b", "c"]) + + # schema not yet cached + self.assertTrue(cdf1._cached_schema is None) + self.assertTrue(cdf2._cached_schema is None) + self.assertTrue(cdf3._cached_schema is None) + + # no cached schema in result dataframe + self.assertTrue(cdf1.union(cdf1)._cached_schema is None) + self.assertTrue(cdf1.union(cdf2)._cached_schema is None) + self.assertTrue(cdf1.union(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.unionAll(cdf1)._cached_schema is None) + self.assertTrue(cdf1.unionAll(cdf2)._cached_schema is None) + self.assertTrue(cdf1.unionAll(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.unionByName(cdf1)._cached_schema is None) + self.assertTrue(cdf1.unionByName(cdf2)._cached_schema is None) + self.assertTrue(cdf1.unionByName(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.subtract(cdf1)._cached_schema is None) + self.assertTrue(cdf1.subtract(cdf2)._cached_schema is None) + self.assertTrue(cdf1.subtract(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.exceptAll(cdf1)._cached_schema is None) + self.assertTrue(cdf1.exceptAll(cdf2)._cached_schema is None) + self.assertTrue(cdf1.exceptAll(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.intersect(cdf1)._cached_schema is None) + self.assertTrue(cdf1.intersect(cdf2)._cached_schema is None) + self.assertTrue(cdf1.intersect(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.intersectAll(cdf1)._cached_schema is None) + self.assertTrue(cdf1.intersectAll(cdf2)._cached_schema is None) + self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None) + + # trigger analysis of cdf1.schema + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertTrue(cdf1._cached_schema is not None) + + self.assertEqual(cdf1.union(cdf1)._cached_schema, cdf1._cached_schema) + # cannot infer when cdf2 doesn't cache schema + self.assertTrue(cdf1.union(cdf2)._cached_schema is None) + # cannot infer when cdf3 doesn't cache schema + self.assertTrue(cdf1.union(cdf3)._cached_schema is None) + + # trigger analysis of cdf2.schema, cdf3.schema + self.assertEqual(cdf2.schema, sdf2.schema) + self.assertEqual(cdf3.schema, sdf3.schema) + + # now all the schemas are cached + self.assertTrue(cdf1._cached_schema is not None) + self.assertTrue(cdf2._cached_schema is not None) + self.assertTrue(cdf3._cached_schema is not None) + + self.assertEqual(cdf1.union(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.union(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.union(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.unionAll(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.unionAll(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.unionAll(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.unionByName(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.unionByName(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.unionByName(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.subtract(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.subtract(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.subtract(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.exceptAll(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.exceptAll(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.exceptAll(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.intersect(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.intersect(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.intersect(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.intersectAll(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.intersectAll(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401 From 9772edc722eac21afa25a8a5f8c2730d1c38375d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 7 Jun 2024 18:17:35 +0800 Subject: [PATCH 2/3] nit --- python/pyspark/sql/connect/dataframe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 39df6de2909e9..9e93dec2e9312 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1132,6 +1132,8 @@ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = print(self._show_string(n, truncate, vertical)) def _merge_cached_schema(self, other: ParentDataFrame) -> Optional[StructType]: + # to avoid type coercion, only propagate the schema + # when the cached schemas are exactly the same if self._cached_schema is not None and self._cached_schema == other._cached_schema: # type: ignore[arg-type] return self.schema return None From 56d1917073fd2bb0e10b59518f251a0cfc1df849 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 8 Jun 2024 07:51:21 +0800 Subject: [PATCH 3/3] nit --- python/pyspark/sql/connect/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 9e93dec2e9312..3cba002b87d1f 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1134,7 +1134,7 @@ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = def _merge_cached_schema(self, other: ParentDataFrame) -> Optional[StructType]: # to avoid type coercion, only propagate the schema # when the cached schemas are exactly the same - if self._cached_schema is not None and self._cached_schema == other._cached_schema: # type: ignore[arg-type] + if self._cached_schema is not None and self._cached_schema == other._cached_schema: return self.schema return None