diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index dae5f3d122ee8..3cba002b87d1f 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1131,24 +1131,33 @@ def observe( def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: print(self._show_string(n, truncate, vertical)) + def _merge_cached_schema(self, other: ParentDataFrame) -> Optional[StructType]: + # to avoid type coercion, only propagate the schema + # when the cached schemas are exactly the same + if self._cached_schema is not None and self._cached_schema == other._cached_schema: + return self.schema + return None + def union(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) return self.unionAll(other) def unionAll(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "union", is_all=True # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def unionByName( self, other: ParentDataFrame, allowMissingColumns: bool = False ) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, # type: ignore[arg-type] @@ -1158,42 +1167,52 @@ def unionByName( ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def subtract(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "except", is_all=False # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "except", is_all=True # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def intersect(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "intersect", is_all=False # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame: self._check_same_session(other) - return DataFrame( + res = DataFrame( plan.SetOperation( self._plan, other._plan, "intersect", is_all=True # type: ignore[arg-type] ), session=self._session, ) + res._cached_schema = self._merge_cached_schema(other) + return res def where(self, condition: Union[Column, str]) -> ParentDataFrame: if not isinstance(condition, (str, Column)): diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py index c87c44760256e..4a7e1e1ea7606 100644 --- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py +++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py @@ -293,6 +293,106 @@ def summarize(left, right): self.assertEqual(cdf3.schema, sdf3.schema) self.assertEqual(cdf3.collect(), sdf3.collect()) + def test_cached_schema_set_op(self): + data1 = [(1, 2, 3)] + data2 = [(6, 2, 5)] + data3 = [(6, 2, 5.0)] + + cdf1 = self.connect.createDataFrame(data1, ["a", "b", "c"]) + sdf1 = self.spark.createDataFrame(data1, ["a", "b", "c"]) + cdf2 = self.connect.createDataFrame(data2, ["a", "b", "c"]) + sdf2 = self.spark.createDataFrame(data2, ["a", "b", "c"]) + cdf3 = self.connect.createDataFrame(data3, ["a", "b", "c"]) + sdf3 = self.spark.createDataFrame(data3, ["a", "b", "c"]) + + # schema not yet cached + self.assertTrue(cdf1._cached_schema is None) + self.assertTrue(cdf2._cached_schema is None) + self.assertTrue(cdf3._cached_schema is None) + + # no cached schema in result dataframe + self.assertTrue(cdf1.union(cdf1)._cached_schema is None) + self.assertTrue(cdf1.union(cdf2)._cached_schema is None) + self.assertTrue(cdf1.union(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.unionAll(cdf1)._cached_schema is None) + self.assertTrue(cdf1.unionAll(cdf2)._cached_schema is None) + self.assertTrue(cdf1.unionAll(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.unionByName(cdf1)._cached_schema is None) + self.assertTrue(cdf1.unionByName(cdf2)._cached_schema is None) + self.assertTrue(cdf1.unionByName(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.subtract(cdf1)._cached_schema is None) + self.assertTrue(cdf1.subtract(cdf2)._cached_schema is None) + self.assertTrue(cdf1.subtract(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.exceptAll(cdf1)._cached_schema is None) + self.assertTrue(cdf1.exceptAll(cdf2)._cached_schema is None) + self.assertTrue(cdf1.exceptAll(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.intersect(cdf1)._cached_schema is None) + self.assertTrue(cdf1.intersect(cdf2)._cached_schema is None) + self.assertTrue(cdf1.intersect(cdf3)._cached_schema is None) + + self.assertTrue(cdf1.intersectAll(cdf1)._cached_schema is None) + self.assertTrue(cdf1.intersectAll(cdf2)._cached_schema is None) + self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None) + + # trigger analysis of cdf1.schema + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertTrue(cdf1._cached_schema is not None) + + self.assertEqual(cdf1.union(cdf1)._cached_schema, cdf1._cached_schema) + # cannot infer when cdf2 doesn't cache schema + self.assertTrue(cdf1.union(cdf2)._cached_schema is None) + # cannot infer when cdf3 doesn't cache schema + self.assertTrue(cdf1.union(cdf3)._cached_schema is None) + + # trigger analysis of cdf2.schema, cdf3.schema + self.assertEqual(cdf2.schema, sdf2.schema) + self.assertEqual(cdf3.schema, sdf3.schema) + + # now all the schemas are cached + self.assertTrue(cdf1._cached_schema is not None) + self.assertTrue(cdf2._cached_schema is not None) + self.assertTrue(cdf3._cached_schema is not None) + + self.assertEqual(cdf1.union(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.union(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.union(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.unionAll(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.unionAll(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.unionAll(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.unionByName(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.unionByName(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.unionByName(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.subtract(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.subtract(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.subtract(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.exceptAll(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.exceptAll(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.exceptAll(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.intersect(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.intersect(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.intersect(cdf3)._cached_schema is None) + + self.assertEqual(cdf1.intersectAll(cdf1)._cached_schema, cdf1._cached_schema) + self.assertEqual(cdf1.intersectAll(cdf2)._cached_schema, cdf1._cached_schema) + # cannot infer when schemas mismatch + self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401