From e9200400aa25af4ec15a0d0be34d61c44822ef4c Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Sun, 30 Jun 2024 23:21:50 +0200 Subject: [PATCH 1/2] [SPARK-48756] Support for `df.debug()` in Connect Mode --- .../reference/pyspark.sql/dataframe.rst | 1 + .../connect_execution_info_and_debug.rst | 60 +++++++++++ python/docs/source/user_guide/index.rst | 1 + python/pyspark/errors/error-conditions.json | 5 + python/pyspark/errors/utils.py | 55 ++++++++--- python/pyspark/sql/classic/dataframe.py | 6 ++ python/pyspark/sql/connect/dataframe.py | 35 ++++++- python/pyspark/sql/connect/observation.py | 24 ++++- python/pyspark/sql/connect/plan.py | 4 + python/pyspark/sql/dataframe.py | 15 +++ python/pyspark/sql/metrics.py | 99 ++++++++++++++++++- python/pyspark/sql/observation.py | 6 +- .../sql/tests/connect/test_df_debug.py | 93 +++++++++++++++-- 13 files changed, 373 insertions(+), 31 deletions(-) create mode 100644 python/docs/source/user_guide/connect_execution_info_and_debug.rst diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst index d0196baa7a05b..abc9441160e10 100644 --- a/python/docs/source/reference/pyspark.sql/dataframe.rst +++ b/python/docs/source/reference/pyspark.sql/dataframe.rst @@ -46,6 +46,7 @@ DataFrame DataFrame.crossJoin DataFrame.crosstab DataFrame.cube + DataFrame.debug DataFrame.describe DataFrame.distinct DataFrame.drop diff --git a/python/docs/source/user_guide/connect_execution_info_and_debug.rst b/python/docs/source/user_guide/connect_execution_info_and_debug.rst new file mode 100644 index 0000000000000..aba96999562af --- /dev/null +++ b/python/docs/source/user_guide/connect_execution_info_and_debug.rst @@ -0,0 +1,60 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +=========== +Spark Connect - Execution Info and Debug +=========== + + +Execution Info +-------------- + +The ``executionInfo`` property of the DataFrame allows users to access execution +metrics about a previously executed operation. In Spark Connect mode, the +plan metrics of the execution are always submitted as the last elements of the +response allowing users an easy way to present this information. + +.. code-block:: python + df = spark.range(100) + df.collect() + ei = df.executionInfo + + # Access the execution metrics: + metrics = ei.metrics + print(metrics.toText()) + +Debugging DataFrame Data Flows +------------------------------- +Sometimes it is useful to understand the data flow of a DataFrame operation. Whereas +metrics allow to track row counts between different operators, the execution plan +does not always resemble the semantic execution. + +The ``debug`` method allows users to inject predefiend observation points into the +query execution. 
After execution the user can access the observations and access +the associated metrics. + +By default, calling ``debug()`` will inject a single observation that counts the number +of rows flowing out of the DataFrame. + + +.. code-block:: python + df = spark.range(100).debug() + filtered = df.where(df.id < 10).debug() + filtered.collect() + ei = df.executionInfo + for op in ei.observations: + print(op.debugString()) diff --git a/python/docs/source/user_guide/index.rst b/python/docs/source/user_guide/index.rst index 67f8c8d4d0fe3..4264c48d06baf 100644 --- a/python/docs/source/user_guide/index.rst +++ b/python/docs/source/user_guide/index.rst @@ -28,6 +28,7 @@ PySpark specific user guides are available here: python_packaging sql/index pandas_on_spark/index + connect_execution_info_and_debug There are also basic programming guides covering multiple languages available in `the Spark documentation `_, including these: diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index dd70e814b1ea8..e756ce5642b4b 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1028,6 +1028,11 @@ "Unknown value for ``." ] }, + "UNSUPPORTED_DATADEBUGOP": { + "message": [ + "Argument `` should be a DataDebugOp, got ." + ] + }, "UNSUPPORTED_DATA_TYPE": { "message": [ "Unsupported DataType ``." diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py index 9155bfb54abe8..20936f2081c08 100644 --- a/python/pyspark/errors/utils.py +++ b/python/pyspark/errors/utils.py @@ -20,7 +20,7 @@ import inspect import os import threading -from typing import Any, Callable, Dict, Match, TypeVar, Type, Optional, TYPE_CHECKING +from typing import Any, Callable, Dict, Match, TypeVar, Type, Optional, TYPE_CHECKING, List import pyspark from pyspark.errors.error_classes import ERROR_CLASSES_MAP @@ -165,19 +165,11 @@ def _capture_call_site(spark_session: "SparkSession", depth: int) -> str: The call site information is used to enhance error messages with the exact location in the user code that led to the error. """ - # Filtering out PySpark code and keeping user code only - pyspark_root = os.path.dirname(pyspark.__file__) - stack = [ - frame_info for frame_info in inspect.stack() if pyspark_root not in frame_info.filename - ] - - selected_frames = stack[:depth] - - # We try import here since IPython is not a required dependency + selected_frames = call_site_stack(depth) try: - from IPython import get_ipython + import IPython - ipython = get_ipython() + ipython = IPython.get_ipython() except ImportError: ipython = None @@ -189,7 +181,6 @@ def _capture_call_site(spark_session: "SparkSession", depth: int) -> str: else: call_sites = [f"{frame.filename}:{frame.lineno}" for frame in selected_frames] call_sites_str = "\n".join(call_sites) - return call_sites_str @@ -257,3 +248,41 @@ def with_origin_to_class(cls: Type[T]) -> Type[T]: ): setattr(cls, name, _with_origin(method)) return cls + + +def call_site_stack(depth: int = 10) -> List[inspect.FrameInfo]: + """ + Capture the call site stack and filter out all stack frames that are not user code. + + This function will return the call stack above all PySpark code and IPython code. Usually + the first frame will be the place where the user code reached the PySpark API. + + If SPARK_TESTING is set in the environment, all frames will be returned. 
+ + Parameters + ---------- + depth : int + How many stack frames to select + + """ + + # Filtering out PySpark code and keeping user code only + pyspark_root = os.path.dirname(pyspark.__file__) + stack = [ + frame_info + for frame_info in inspect.stack() + if pyspark_root not in frame_info.filename or "SPARK_TESTING" in os.environ + ] + + selected_frames = stack[:depth] + + # We try import here since IPython is not a required dependency + try: + import IPython + + ipy_root = IPython.__file__ + selected_frames = [f for f in selected_frames if ipy_root not in f.filename] + except ImportError: + pass + + return selected_frames diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index 1bedd624603e1..ad78653e9cdfa 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -1843,6 +1843,12 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: message_parameters={"member": "queryExecution"}, ) + def debug(self) -> "DataFrame": + raise PySparkValueError( + error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF", + message_parameters={"member": "debug"}, + ) + def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject": """ diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 1aa8fc00cfcc9..fcc815a6b62c4 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -50,8 +50,10 @@ import warnings from collections.abc import Iterable import functools +from uuid import uuid4 from pyspark import _NoValue +from pyspark.errors.utils import call_site_stack from pyspark._globals import _NoValueType from pyspark.util import is_remote_only from pyspark.sql.types import Row, StructType, _create_row @@ -84,6 +86,7 @@ from pyspark.sql.connect.functions import builtin as F from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] +from pyspark.sql.metrics import DataDebugOp if TYPE_CHECKING: @@ -101,7 +104,7 @@ from pyspark.sql.connect.observation import Observation from pyspark.sql.connect.session import SparkSession from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame - from pyspark.sql.metrics import ExecutionInfo + from pyspark.sql.metrics import ExecutionInfo, DataDebugOp class DataFrame(ParentDataFrame): @@ -2227,8 +2230,38 @@ def rdd(self) -> "RDD[Row]": @property def executionInfo(self) -> Optional["ExecutionInfo"]: + # Update the observations if needed. + if self._plan.observations: + if self._execution_info and not self._execution_info.observations: + self._execution_info.setObservations(self._plan.observations) return self._execution_info + def debug(self, *other: List["DataDebugOp"]) -> "DataFrame": + # Needs to be imported here to avoid the recursive import. + from pyspark.sql.connect.observation import Observation + + # Extract the stack + stack = call_site_stack(depth=10) + frames = [f"{s.filename}:{s.lineno}@{s.function}" for s in stack] + + # Check that all elements are of type 'DataDebugOp' + for op in other: + if not isinstance(op, DataDebugOp): + raise PySparkTypeError( + error_class="UNSUPPORTED_DATADEBUGOP", + message_parameters={"arg_name": "other", "arg_type": type(op).__name__}, + ) + + # Capture the expressions for the debug op. 
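+        # A default count_values() op is always included first; any user-provided ops follow it.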
+        ops: List[DataDebugOp] = [
+            DataDebugOp.count_values(),
+        ] + list(other)
+        exprs = list(map(lambda x: x(), ops))
+
+        # Create the Observation that captures all the expressions for this "debug" op.
+        obs = Observation(name=f"debug:{uuid4()}", call_site=frames, plan_id=self._plan.plan_id)
+        return self.observe(obs, *exprs)
+

 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/connect/observation.py b/python/pyspark/sql/connect/observation.py
index 2471cf04cfbe7..1d6e576d4f1f1 100644
--- a/python/pyspark/sql/connect/observation.py
+++ b/python/pyspark/sql/connect/observation.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
 import uuid

 from pyspark.errors import (
@@ -33,7 +33,9 @@ class Observation:
-    def __init__(self, name: Optional[str] = None) -> None:
+    def __init__(
+        self, name: Optional[str] = None, call_site: Optional[List[str]] = None, plan_id: int = -1
+    ) -> None:
         if name is not None:
             if not isinstance(name, str):
                 raise PySparkTypeError(
@@ -47,9 +49,19 @@ def __init__(self, name: Optional[str] = None) -> None:
             )
         self._name = name
         self._result: Optional[Dict[str, Any]] = None
+        self._call_site = call_site
+        self._plan_id = plan_id

     __init__.__doc__ = PySparkObservation.__init__.__doc__

+    @property
+    def callSite(self) -> Optional[List[str]]:
+        return self._call_site
+
+    @property
+    def planId(self) -> int:
+        return self._plan_id
+
     def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
         if self._result is not None:
             raise PySparkAssertionError(error_class="REUSE_OBSERVATION", message_parameters={})
@@ -75,6 +87,14 @@ def get(self) -> Dict[str, Any]:

         return self._result

+    def debugString(self) -> str:
+        call_site = self._call_site[0] if self._call_site is not None else ""
+        metrics = ", ".join([f"{k}={v}" for k, v in self.get.items()])
+        return (
+            f"Observation(name={self._name}, planId={self._plan_id},"
+            f" metrics={metrics}, callSite={call_site})"
+        )
+
     get.__doc__ = PySparkObservation.get.__doc__

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 19377515ed28c..af02978a39707 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -87,6 +87,10 @@ def _fresh_plan_id() -> int:
         assert plan_id is not None
         return plan_id

+    @property
+    def plan_id(self) -> int:
+        return self._plan_id
+
     def _create_proto_relation(self) -> proto.Relation:
         plan = proto.Relation()
         plan.common.plan_id = self._plan_id
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 625678588bf9e..33b5ff0a97c1e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6307,6 +6307,21 @@ def executionInfo(self) -> Optional["ExecutionInfo"]:
         """
         ...

+    def debug(self) -> "DataFrame":
+        """
+        Helper method that allows debugging the query execution with custom observations.
+
+        Essentially, this method is a wrapper around the `observe()` method, but simplifies
+        its usage. In addition, it makes sure that the captured metrics are properly collected
+        as part of the execution info.
+
+        .. versionadded:: 4.0.0
+        Returns
+        -------
+        DataFrame instance with the observations added.
+        """
+        ...
+

 class DataFrameNaFunctions:
     """Functionality for working with missing data in :class:`DataFrame`.
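A rough usage sketch of the combined API, assuming an active Spark Connect session ``spark`` and
a column named ``id`` (both assumptions, mirroring the tests below); it exercises ``debug()``
together with the ``DataDebugOp`` helpers added in ``python/pyspark/sql/metrics.py`` just below:

.. code-block:: python

    from pyspark.sql.metrics import DataDebugOp

    # `spark` is assumed to be an existing Spark Connect session.
    df = spark.range(100).debug(DataDebugOp.count_distinct_values("id"))
    df.collect()

    # After the action, the injected observation is available via executionInfo.
    ei = df.executionInfo
    for obs in ei.observations:
        print(obs.debugString())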
diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index 6664582952014..fe3af1a6b304f 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -18,7 +18,10 @@ import dataclasses from typing import Optional, List, Tuple, Dict, Any, Union, TYPE_CHECKING, Sequence -from pyspark.errors import PySparkValueError +from pyspark.errors import PySparkValueError, PySparkTypeError +from pyspark.sql import Observation, Column + +import pyspark.sql.functions as F if TYPE_CHECKING: from pyspark.testing.connectutils import have_graphviz @@ -98,6 +101,72 @@ def metrics(self) -> List[MetricValue]: return self._metrics +class DataDebugOp: + """ + The DataDebugOp class is a helper class that allows to encapsulate different reusable + data debug operations. + """ + + @classmethod + def count_values(cls) -> "DataDebugOp": + return DataDebugOp("count_values", F.count(F.lit(1)).alias("count_values")) + + @classmethod + def count_null_values(cls, name: str) -> "DataDebugOp": + if not isinstance(name, str): + raise PySparkTypeError( + error_class="NOT_STR", + message_parameters={"arg_name": "name", "arg_type": type(name).__name__}, + ) + return DataDebugOp( + f"count_null_values_{name}", + F.count(F.when(F.col(name).isNull(), 1)).alias(f"count_null_values_{name}"), + ) + + @classmethod + def count_distinct_values(cls, name: str) -> "DataDebugOp": + if not isinstance(name, str): + raise PySparkTypeError( + error_class="NOT_STR", + message_parameters={"arg_name": "name", "arg_type": type(name).__name__}, + ) + return DataDebugOp( + f"count_distinct_values_{name}", + F.approx_count_distinct(name).alias(f"count_distinct_values_{name}"), + ) + + @classmethod + def max_value(cls, name: str) -> "DataDebugOp": + if not isinstance(name, str): + raise PySparkTypeError( + error_class="NOT_STR", + message_parameters={"arg_name": "name", "arg_type": type(name).__name__}, + ) + return DataDebugOp( + f"max_value_{name}", + F.max(name).alias(f"max_value_{name}"), + ) + + @classmethod + def min_value(cls, name: str) -> "DataDebugOp": + if not isinstance(name, str): + raise PySparkTypeError( + error_class="NOT_STR", + message_parameters={"arg_name": "name", "arg_type": type(name).__name__}, + ) + return DataDebugOp( + f"min_value_{name}", + F.min(name).alias(f"min_value_{name}"), + ) + + def __init__(self, name: str, col: Column): + self._name = name + self._col = col + + def __call__(self) -> Column: + return self._col + + class CollectedMetrics: @dataclasses.dataclass class Node: @@ -276,7 +345,12 @@ def __init__( self, metrics: Optional[list[PlanMetrics]], obs: Optional[Sequence[ObservedMetrics]] ): self._metrics = CollectedMetrics(metrics) if metrics else None - self._observations = obs if obs else [] + # These are the metrics that were observed from the + self._observed_metrics = [o for o in obs if o.name.startswith("debug:")] if obs else [] + self._observations: Optional[Dict[str, Observation]] = None + + def setObservations(self, observations: Dict[str, Observation]) -> None: + self._observations = observations @property def metrics(self) -> Optional[CollectedMetrics]: @@ -284,4 +358,23 @@ def metrics(self) -> Optional[CollectedMetrics]: @property def flows(self) -> List[Tuple[str, Dict[str, Any]]]: - return [(f.name, f.pairs) for f in self._observations] + return [(f.name, f.pairs) for f in self._observed_metrics] + + @property + def observedMetrics(self) -> List[ObservedMetrics]: + return self._observed_metrics + + @property + def observations(self) -> 
Optional[List[Observation]]:
+        """
+        Returns the observations that were collected during the execution of the query. The
+        observations are returned as a list ordered by the plan ID under the assumption that
+        lower plan IDs have been created earlier.
+
+        Returns
+        -------
+        A list of observations, or None if no observations were collected.
+        """
+        if self._observations:
+            return sorted(self._observations.values(), key=lambda x: x._plan_id)
+        return None
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 4ef4c78ba3c33..3a160c096423a 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import os
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Dict, Optional, TYPE_CHECKING, List

 from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkAssertionError
 from pyspark.sql.column import Column
@@ -77,7 +77,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
             return ConnectObservation(*args, **kwargs)
         return super().__new__(cls)

-    def __init__(self, name: Optional[str] = None) -> None:
+    def __init__(self, name: Optional[str] = None, call_site: Optional[List[str]] = None) -> None:
         """Constructs a named or unnamed Observation instance.

         Parameters
@@ -99,6 +99,8 @@
         self._name = name
         self._jvm: Optional[JVMView] = None
         self._jo: Optional["JavaObject"] = None
+        self._call_site = call_site
+        self._plan_id = -1

     def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
         """Attaches this observation to the given :class:`DataFrame` to observe aggregations.
diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py
index 8a4ec68fda844..a0a5324ff0cca 100644
--- a/python/pyspark/sql/tests/connect/test_df_debug.py
+++ b/python/pyspark/sql/tests/connect/test_df_debug.py
@@ -14,23 +14,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# - +import os import unittest +from pyspark.sql.metrics import DataDebugOp from pyspark.testing.connectutils import ( should_test_connect, have_graphviz, graphviz_requirement_message, + ReusedConnectTestCase, ) -from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame -class SparkConnectDataFrameDebug(SparkConnectSQLTestCase): +class SparkConnectDataFrameDebug(ReusedConnectTestCase): def test_df_debug_basics(self): - df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() x = df.collect() # noqa: F841 ei = df.executionInfo @@ -38,12 +39,12 @@ def test_df_debug_basics(self): self.assertIn(root, graph, "The root must be rooted in the graph") def test_df_quey_execution_empty_before_execution(self): - df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() ei = df.executionInfo self.assertIsNone(ei, "The query execution must be None before the action is executed") def test_df_query_execution_with_writes(self): - df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() df.write.save("/tmp/test_df_query_execution_with_writes", format="json", mode="overwrite") ei = df.executionInfo self.assertIsNotNone( @@ -51,19 +52,19 @@ def test_df_query_execution_with_writes(self): ) def test_query_execution_text_format(self): - df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() df.collect() self.assertIn("HashAggregate", df.executionInfo.metrics.toText()) # Different execution mode. 
- df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() df.toPandas() self.assertIn("HashAggregate", df.executionInfo.metrics.toText()) @unittest.skipIf(not have_graphviz, graphviz_requirement_message) def test_df_query_execution_metrics_to_dot(self): - df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() - x = df.collect() # noqa: F841 + df: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() + df.collect() ei = df.executionInfo dot = ei.metrics.toDot() @@ -73,6 +74,78 @@ def test_df_query_execution_metrics_to_dot(self): self.assertIn("digraph", source, "The dot representation must contain the digraph keyword") self.assertIn("Metrics", source, "The dot representation must contain the Metrics keyword") + def test_query_with_debug(self): + self.assertIn("SPARK_TESTING", os.environ, "SPARK_TESTING must be set to run this test") + + df: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() + df = df.debug() + df.collect() + ei = df.executionInfo + self.assertIsNotNone( + ei.observations, "Observations must be present when debug has been used" + ) + + # Check that the call site contains this function in the top stack element + observations = ei.observations + + # THe stack should contain call_site_stack (0), df.debug() (1), test_query_with_debug (2) + stack = observations[0]._call_site[2] + self.assertIn( + "test_query_with_debug", + stack, + ( + "The call site must contain this function: stack " + ";".join(observations[0]._call_site) + ), + ) + + def test_query_with_debug_and_plan_id_order(self): + a: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() + a = a.debug() + b: DataFrame = self.spark.range(1000) + b = b.filter(b.id < 10).debug() + c: DataFrame = a.join(b, "id") + c = c.debug() + c.collect() + + ei = c.executionInfo + self.assertIsNotNone(ei, "The query execution must be set after the action is executed") + self.assertEqual( + len(ei.observations), + 3, + "The number of observations must be equal to the number of debug", + ) + + # Check the count values for all observations + o1 = ei.observations[0] + o2 = ei.observations[1] + o3 = ei.observations[2] + + self.assertEqual(o1.get["count_values"], 100, "The count values must be 100") + self.assertEqual(o2.get["count_values"], 10, "The count values must be 10 after filter") + self.assertEqual(o3.get["count_values"], 10, "The count values must be 10 after join") + + def test_query_with_debug_and_other_ops(self): + a: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count() + b = a.debug(DataDebugOp.count_distinct_values("id")) + b.collect() + ei = b.executionInfo + self.assertIsNotNone(ei, "The query execution must be set after the action is executed") + self.assertEqual( + len(ei.observations), + 1, + "The number of observations must be equal to the number of debug observations", + ) + + o1 = ei.observations[0] # count_values + self.assertEqual(2, len(o1.get.values()), "Two metrics were collected") + self.assertEqual(o1.get["count_values"], 100, "The count values must be 100") + self.assertGreater( + o1.get["count_distinct_values_id"], + 90, + "The approx count distinct values must be roughly", + ) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_df_debug import * # noqa: F401 From fff10910245e640663a5dacc614e9b4711c113e7 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Sun, 30 Jun 2024 23:37:12 +0200 Subject: [PATCH 2/2] 
more tests
---
 python/pyspark/sql/tests/connect/test_df_debug.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py
index a0a5324ff0cca..35283f30c88cb 100644
--- a/python/pyspark/sql/tests/connect/test_df_debug.py
+++ b/python/pyspark/sql/tests/connect/test_df_debug.py
@@ -127,7 +127,11 @@ def test_query_with_debug_and_plan_id_order(self):

     def test_query_with_debug_and_other_ops(self):
         a: DataFrame = self.spark.range(100).repartition(10).groupBy("id").count()
-        b = a.debug(DataDebugOp.count_distinct_values("id"))
+        b = a.debug(
+            DataDebugOp.count_distinct_values("id"),
+            DataDebugOp.max_value("id"),
+            DataDebugOp.min_value("id"),
+        )
         b.collect()
         ei = b.executionInfo
         self.assertIsNotNone(ei, "The query execution must be set after the action is executed")
@@ -138,13 +142,17 @@ def test_query_with_debug_and_other_ops(self):
         )

         o1 = ei.observations[0]  # count_values
-        self.assertEqual(2, len(o1.get.values()), "Two metrics were collected")
+        self.assertEqual(
+            4, len(o1.get.values()), "Four metrics were collected (count, min, max, distinct)"
+        )
         self.assertEqual(o1.get["count_values"], 100, "The count values must be 100")
         self.assertGreater(
             o1.get["count_distinct_values_id"],
             90,
             "The approx count distinct values must be roughly",
         )
+        self.assertEqual(o1.get["max_value_id"], 99, "The max value must be 99")
+        self.assertEqual(o1.get["min_value_id"], 0, "The min value must be 0")


 if __name__ == "__main__":