From 98ba08bbb35e2bd2e5cc9779aa8618cab4f5856e Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 11 Oct 2024 15:46:47 +0800 Subject: [PATCH 1/5] box plot --- python/pyspark/errors/error-conditions.json | 5 + python/pyspark/sql/plot/core.py | 157 +++++++++++++++++++- python/pyspark/sql/plot/plotly.py | 77 +++++++++- 3 files changed, 236 insertions(+), 3 deletions(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 6ca21d55555d..ab01d386645b 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1103,6 +1103,11 @@ "`` is not supported, it should be one of the values from " ] }, + "UNSUPPORTED_PLOT_BACKEND_PARAM": { + "message": [ + "`` does not support `` set to , it should be one of the values from " + ] + }, "UNSUPPORTED_SIGNATURE": { "message": [ "Unsupported signature: ." diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index f9667ee2c0d6..ceeea0affb0e 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -15,15 +15,17 @@ # limitations under the License. # -from typing import Any, TYPE_CHECKING, Optional, Union +from typing import Any, TYPE_CHECKING, List, Optional, Union from types import ModuleType from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.sql import Column, functions as F from pyspark.sql.types import NumericType -from pyspark.sql.utils import require_minimum_plotly_version +from pyspark.sql.utils import is_remote, require_minimum_plotly_version if TYPE_CHECKING: from pyspark.sql import DataFrame + from pyspark.sql._typing import ColumnOrName import pandas as pd from plotly.graph_objs import Figure @@ -338,3 +340,154 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": }, ) return self(kind="pie", x=x, y=y, **kwargs) + + def box(self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any): + """ + Make a box plot of the DataFrame columns. + + Make a box-and-whisker plot from DataFrame columns, optionally grouped by some + other columns. A box plot is a method for graphically depicting groups of numerical + data through their quartiles. The box extends from the Q1 to Q3 quartile values of + the data, with a line at the median (Q2). The whiskers extend from the edges of box + to show the range of the data. By default, they extend no more than + 1.5 * IQR (IQR = Q3 - Q1) from the edges of the box, ending at the farthest data point + within that interval. Outliers are plotted as separate dots. + + Parameters + ---------- + column: str or list of str + Column name or list of names to be used for creating the boxplot. + precision: float, default = 0.01 + This argument is used by pyspark to compute approximate statistics + for building a boxplot. + **kwargs + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + Return an custom object when ``backend!=plotly``. + Return an ndarray when ``subplots=True`` (matplotlib-only). + + Notes + ----- + There are behavior differences between pandas-on-Spark and pandas. + + * pandas-on-Spark computes approximate statistics - expect differences between + pandas and pandas-on-Spark boxplots, especially regarding 1st and 3rd quartiles. + * The `whis` argument is only supported as a single number. + * pandas-on-Spark doesn't support the following argument(s) (matplotlib-only). 
+ + * `bootstrap` argument is not supported + * `autorange` argument is not supported + + Examples + -------- + Draw a box plot from a DataFrame with four columns of randomly + generated data. + + For Series: + + .. plotly:: + + >>> data = np.random.randn(25, 4) + >>> df = ps.DataFrame(data, columns=list('ABCD')) + >>> df['A'].plot.box() # doctest: +SKIP + + This is an unsupported function for DataFrame type + """ + return self(kind="box", column=column, precision=precision, **kwargs) + + +class PySparkBoxPlotBase: + @staticmethod + def compute_box( + sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool + ): + assert len(colnames) > 0 + formatted_colnames = ["`{}`".format(colname) for colname in colnames] + + stats_scols = [] + for i, colname in enumerate(formatted_colnames): + percentiles = F.percentile_approx(colname, [0.25, 0.50, 0.75], int(1.0 / precision)) + q1 = F.get(percentiles, 0) + med = F.get(percentiles, 1) + q3 = F.get(percentiles, 2) + iqr = q3 - q1 + lfence = q1 - F.lit(whis) * iqr + ufence = q3 + F.lit(whis) * iqr + + stats_scols.append( + F.struct( + F.mean(colname).alias("mean"), + med.alias("med"), + q1.alias("q1"), + q3.alias("q3"), + lfence.alias("lfence"), + ufence.alias("ufence"), + ).alias(f"_box_plot_stats_{i}") + ) + + sdf_stats = sdf.select(*stats_scols) + + result_scols = [] + for i, colname in enumerate(formatted_colnames): + value = F.col(colname) + + lfence = F.col(f"_box_plot_stats_{i}.lfence") + ufence = F.col(f"_box_plot_stats_{i}.ufence") + mean = F.col(f"_box_plot_stats_{i}.mean") + med = F.col(f"_box_plot_stats_{i}.med") + q1 = F.col(f"_box_plot_stats_{i}.q1") + q3 = F.col(f"_box_plot_stats_{i}.q3") + + outlier = ~value.between(lfence, ufence) + + # Computes min and max values of non-outliers - the whiskers + upper_whisker = F.max(F.when(~outlier, value).otherwise(F.lit(None))) + lower_whisker = F.min(F.when(~outlier, value).otherwise(F.lit(None))) + + # If it shows fliers, take the top 1k with the highest absolute values + # Here we normalize the values by subtracting the median. 
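+            # The struct above sorts on its first field, |value - med|, so the
+            # top-k collection below keeps only the outliers farthest from the
+            # median, capped at 1001 values per column, which bounds the amount
+            # of flier data brought back to the driver.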
+ if showfliers: + pair = F.when( + outlier, + F.struct(F.abs(value - med), value.alias("val")), + ).otherwise(F.lit(None)) + topk = collect_top_k(pair, 1001, False) + fliers = F.when(F.size(topk) > 0, topk["val"]).otherwise(F.lit(None)) + else: + fliers = F.lit(None) + + result_scols.append( + F.struct( + F.first(mean).alias("mean"), + F.first(med).alias("med"), + F.first(q1).alias("q1"), + F.first(q3).alias("q3"), + upper_whisker.alias("upper_whisker"), + lower_whisker.alias("lower_whisker"), + fliers.alias("fliers"), + ).alias(f"_box_plot_results_{i}") + ) + + sdf_result = sdf.join(sdf_stats.hint("broadcast")).select(*result_scols) + return sdf_result.first() + + +def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + + return _invoke_function_over_columns(name, *cols) + + else: + from pyspark.sql.classic.column import _to_seq, _to_java_column + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column))) + + +def collect_top_k(col: Column, num: int, reverse: bool) -> Column: + return _invoke_internal_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py index 91f536346471..7c1deb6cfc3b 100644 --- a/python/pyspark/sql/plot/plotly.py +++ b/python/pyspark/sql/plot/plotly.py @@ -17,7 +17,8 @@ from typing import TYPE_CHECKING, Any -from pyspark.sql.plot import PySparkPlotAccessor +from pyspark.errors import PySparkValueError +from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -29,6 +30,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": if kind == "pie": return plot_pie(data, **kwargs) + if kind == "box": + return plot_box(data, **kwargs) return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) @@ -43,3 +46,75 @@ def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure": fig = express.pie(pdf, values=y, names=x, **kwargs) return fig + + +def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure": + import plotly.graph_objs as go + + # 'whis' isn't actually an argument in plotly (but in matplotlib). But seems like + # plotly doesn't expose the reach of the whiskers to the beyond the first and + # third quartiles (?). Looks they use default 1.5. 
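+    # In this backend 'whis' only feeds the fence computation in compute_box;
+    # plotly itself just renders the precomputed whisker values passed below
+    # via lowerfence/upperfence.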
+ whis = kwargs.pop("whis", 1.5) + # 'precision' is pyspark specific to control precision for approx_percentile + precision = kwargs.pop("precision", 0.01) + colnames = kwargs.pop("column", None) + if isinstance(colnames, str): + colnames = [colnames] + + # Plotly options + boxpoints = kwargs.pop("boxpoints", "suspectedoutliers") + notched = kwargs.pop("notched", False) + if boxpoints not in ["suspectedoutliers", False]: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "boxpoints", + "value": str(boxpoints), + "supported_values": ", ".join(["suspectedoutliers", "False"]), + }, + ) + if notched: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "notched", + "value": str(notched), + "supported_values": ", ".join(["False"]), + }, + ) + + fig = go.Figure() + + results = PySparkBoxPlotBase.compute_box( + data, + colnames, + whis, + precision, + boxpoints is not None, + ) + assert len(results) == len(colnames) + + for i, colname in enumerate(colnames): + result = results[i] + + fig.add_trace( + go.Box( + x=[i], + name=colname, + q1=[result["q1"]], + median=[result["med"]], + q3=[result["q3"]], + mean=[result["mean"]], + lowerfence=[result["lower_whisker"]], + upperfence=[result["upper_whisker"]], + y=[result["fliers"]] if result["fliers"] else None, + boxpoints=boxpoints, + notched=notched, + **kwargs, + ) + ) + + fig["layout"]["yaxis"]["title"] = "value" + return fig From f517f68aad9eb65d476b4bc2a711388e825a4a44 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 11 Oct 2024 15:46:52 +0800 Subject: [PATCH 2/5] test --- .../sql/tests/plot/test_frame_plot_plotly.py | 77 ++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index b92b5a91cb76..d870cdbf9959 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -19,7 +19,7 @@ from datetime import datetime import pyspark.sql.plot # noqa: F401 -from pyspark.errors import PySparkTypeError +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message @@ -48,6 +48,22 @@ def sdf3(self): columns = ["sales", "signups", "visits", "date"] return self.spark.createDataFrame(data, columns) + @property + def sdf4(self): + data = [ + ("A", 50, 55), + ("B", 55, 60), + ("C", 60, 65), + ("D", 65, 70), + ("E", 70, 75), + # outliers + ("F", 10, 15), + ("G", 85, 90), + ("H", 5, 150), + ] + columns = ["student", "math_score", "english_score"] + return self.spark.createDataFrame(data, columns) + def _check_fig_data(self, fig_data, **kwargs): for key, expected_value in kwargs.items(): if key in ["x", "y", "labels", "values"]: @@ -300,6 +316,65 @@ def test_pie_plot(self): messageParameters={"arg_name": "y", "arg_type": "StringType()"}, ) + def test_box_plot(self): + fig = self.sdf4.plot.box(column="math_score") + expected_fig_data = { + "boxpoints": "suspectedoutliers", + "lowerfence": (5,), + "mean": (50.0,), + "median": (55,), + "name": "math_score", + "notched": False, + "q1": (10,), + "q3": (65,), + "upperfence": (85,), + "x": [0], + "type": "box", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"]) + 
self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "boxpoints": "suspectedoutliers", + "lowerfence": (55,), + "mean": (72.5,), + "median": (65,), + "name": "english_score", + "notched": False, + "q1": (55,), + "q3": (75,), + "upperfence": (90,), + "x": [1], + "y": [[150, 15]], + "type": "box", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + with self.assertRaises(PySparkValueError) as pe: + self.sdf4.plot.box(column="math_score", boxpoints=True) + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "boxpoints", + "value": "True", + "supported_values": ", ".join(["suspectedoutliers", "False"]), + }, + ) + with self.assertRaises(PySparkValueError) as pe: + self.sdf4.plot.box(column="math_score", notched=True) + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "notched", + "value": "True", + "supported_values": ", ".join(["False"]), + }, + ) + class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass From cb8ae7ad1247b54c16d6e63c3b67fbefb9ca8891 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 14 Oct 2024 14:05:14 +0800 Subject: [PATCH 3/5] typing; doc --- python/pyspark/sql/plot/core.py | 52 +++++++++++++------------------ python/pyspark/sql/plot/plotly.py | 4 +-- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index ceeea0affb0e..ad26fa0d10f4 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: - from pyspark.sql import DataFrame + from pyspark.sql import DataFrame, Row from pyspark.sql._typing import ColumnOrName import pandas as pd from plotly.graph_objs import Figure @@ -341,7 +341,9 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": ) return self(kind="pie", x=x, y=y, **kwargs) - def box(self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any): + def box( + self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any + ) -> "Figure": """ Make a box plot of the DataFrame columns. @@ -366,35 +368,23 @@ def box(self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Returns ------- :class:`plotly.graph_objs.Figure` - Return an custom object when ``backend!=plotly``. - Return an ndarray when ``subplots=True`` (matplotlib-only). - - Notes - ----- - There are behavior differences between pandas-on-Spark and pandas. - - * pandas-on-Spark computes approximate statistics - expect differences between - pandas and pandas-on-Spark boxplots, especially regarding 1st and 3rd quartiles. - * The `whis` argument is only supported as a single number. - * pandas-on-Spark doesn't support the following argument(s) (matplotlib-only). - - * `bootstrap` argument is not supported - * `autorange` argument is not supported Examples -------- - Draw a box plot from a DataFrame with four columns of randomly - generated data. - - For Series: - - .. plotly:: - - >>> data = np.random.randn(25, 4) - >>> df = ps.DataFrame(data, columns=list('ABCD')) - >>> df['A'].plot.box() # doctest: +SKIP - - This is an unsupported function for DataFrame type + >>> data = [ + ... ("A", 50, 55), + ... ("B", 55, 60), + ... ("C", 60, 65), + ... ("D", 65, 70), + ... ("E", 70, 75), + ... ("F", 10, 15), + ... ("G", 85, 90), + ... ("H", 5, 150), + ... 
] + >>> columns = ["student", "math_score", "english_score"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.box(column="math_score") # doctest: +SKIP + >>> df.plot.box(column=["math_score", "english_score"]) # doctest: +SKIP """ return self(kind="box", column=column, precision=precision, **kwargs) @@ -403,7 +393,7 @@ class PySparkBoxPlotBase: @staticmethod def compute_box( sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool - ): + ) -> Optional["Row"]: assert len(colnames) > 0 formatted_colnames = ["`{}`".format(colname) for colname in colnames] @@ -486,7 +476,9 @@ def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> from pyspark import SparkContext sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column))) + return Column( + sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column)) # type: ignore + ) def collect_top_k(col: Column, num: int, reverse: bool) -> Column: diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py index 7c1deb6cfc3b..71d40720e874 100644 --- a/python/pyspark/sql/plot/plotly.py +++ b/python/pyspark/sql/plot/plotly.py @@ -94,10 +94,10 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure": precision, boxpoints is not None, ) - assert len(results) == len(colnames) + assert len(results) == len(colnames) # type: ignore for i, colname in enumerate(colnames): - result = results[i] + result = results[i] # type: ignore fig.add_trace( go.Box( From d75b5c9a914c3c9699f2628a76a486fea91bac61 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 14 Oct 2024 14:27:14 +0800 Subject: [PATCH 4/5] type --- python/pyspark/sql/plot/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index ad26fa0d10f4..a34486e12e75 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -477,7 +477,9 @@ def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> sc = SparkContext._active_spark_context return Column( - sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column)) # type: ignore + sc._jvm.PythonSQLUtils.internalFn( # type: ignore + name, _to_seq(sc, cols, _to_java_column) + ) ) From 0f2287fa53915ea6814b487c5188e03c6cb610f6 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 14 Oct 2024 18:52:30 +0800 Subject: [PATCH 5/5] type --- python/pyspark/sql/plot/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index a34486e12e75..4bf75474d92c 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -478,7 +478,7 @@ def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> sc = SparkContext._active_spark_context return Column( sc._jvm.PythonSQLUtils.internalFn( # type: ignore - name, _to_seq(sc, cols, _to_java_column) + name, _to_seq(sc, cols, _to_java_column) # type: ignore ) )
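Note: a minimal end-to-end sketch of the API added by this series, assuming a running SparkSession with plotly installed; the session bootstrap and the `fig.show()` call are illustrative and not part of the patch:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    data = [
        ("A", 50, 55), ("B", 55, 60), ("C", 60, 65), ("D", 65, 70),
        ("E", 70, 75), ("F", 10, 15), ("G", 85, 90), ("H", 5, 150),
    ]
    df = spark.createDataFrame(data, ["student", "math_score", "english_score"])

    # Box plot over two numeric columns; quartiles are approximate, and a
    # smaller `precision` tightens them (it is passed to percentile_approx
    # as accuracy = int(1.0 / precision)).
    fig = df.plot.box(column=["math_score", "english_score"], precision=0.001)
    fig.show()

With the plotly backend added here, `boxpoints` values other than "suspectedoutliers"/False and `notched=True` raise UNSUPPORTED_PLOT_BACKEND_PARAM, as exercised in the tests of patch 2/5.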