diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 6ca21d55555d..ab01d386645b 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1103,6 +1103,11 @@ "`` is not supported, it should be one of the values from " ] }, + "UNSUPPORTED_PLOT_BACKEND_PARAM": { + "message": [ + "`` does not support `` set to , it should be one of the values from " + ] + }, "UNSUPPORTED_SIGNATURE": { "message": [ "Unsupported signature: ." diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index f9667ee2c0d6..4bf75474d92c 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -15,15 +15,17 @@ # limitations under the License. # -from typing import Any, TYPE_CHECKING, Optional, Union +from typing import Any, TYPE_CHECKING, List, Optional, Union from types import ModuleType from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.sql import Column, functions as F from pyspark.sql.types import NumericType -from pyspark.sql.utils import require_minimum_plotly_version +from pyspark.sql.utils import is_remote, require_minimum_plotly_version if TYPE_CHECKING: - from pyspark.sql import DataFrame + from pyspark.sql import DataFrame, Row + from pyspark.sql._typing import ColumnOrName import pandas as pd from plotly.graph_objs import Figure @@ -338,3 +340,148 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": }, ) return self(kind="pie", x=x, y=y, **kwargs) + + def box( + self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any + ) -> "Figure": + """ + Make a box plot of the DataFrame columns. + + Make a box-and-whisker plot from DataFrame columns, optionally grouped by some + other columns. A box plot is a method for graphically depicting groups of numerical + data through their quartiles. The box extends from the Q1 to Q3 quartile values of + the data, with a line at the median (Q2). The whiskers extend from the edges of box + to show the range of the data. By default, they extend no more than + 1.5 * IQR (IQR = Q3 - Q1) from the edges of the box, ending at the farthest data point + within that interval. Outliers are plotted as separate dots. + + Parameters + ---------- + column: str or list of str + Column name or list of names to be used for creating the boxplot. + precision: float, default = 0.01 + This argument is used by pyspark to compute approximate statistics + for building a boxplot. + **kwargs + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [ + ... ("A", 50, 55), + ... ("B", 55, 60), + ... ("C", 60, 65), + ... ("D", 65, 70), + ... ("E", 70, 75), + ... ("F", 10, 15), + ... ("G", 85, 90), + ... ("H", 5, 150), + ... ] + >>> columns = ["student", "math_score", "english_score"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.box(column="math_score") # doctest: +SKIP + >>> df.plot.box(column=["math_score", "english_score"]) # doctest: +SKIP + """ + return self(kind="box", column=column, precision=precision, **kwargs) + + +class PySparkBoxPlotBase: + @staticmethod + def compute_box( + sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool + ) -> Optional["Row"]: + assert len(colnames) > 0 + formatted_colnames = ["`{}`".format(colname) for colname in colnames] + + stats_scols = [] + for i, colname in enumerate(formatted_colnames): + percentiles = F.percentile_approx(colname, [0.25, 0.50, 0.75], int(1.0 / precision)) + q1 = F.get(percentiles, 0) + med = F.get(percentiles, 1) + q3 = F.get(percentiles, 2) + iqr = q3 - q1 + lfence = q1 - F.lit(whis) * iqr + ufence = q3 + F.lit(whis) * iqr + + stats_scols.append( + F.struct( + F.mean(colname).alias("mean"), + med.alias("med"), + q1.alias("q1"), + q3.alias("q3"), + lfence.alias("lfence"), + ufence.alias("ufence"), + ).alias(f"_box_plot_stats_{i}") + ) + + sdf_stats = sdf.select(*stats_scols) + + result_scols = [] + for i, colname in enumerate(formatted_colnames): + value = F.col(colname) + + lfence = F.col(f"_box_plot_stats_{i}.lfence") + ufence = F.col(f"_box_plot_stats_{i}.ufence") + mean = F.col(f"_box_plot_stats_{i}.mean") + med = F.col(f"_box_plot_stats_{i}.med") + q1 = F.col(f"_box_plot_stats_{i}.q1") + q3 = F.col(f"_box_plot_stats_{i}.q3") + + outlier = ~value.between(lfence, ufence) + + # Computes min and max values of non-outliers - the whiskers + upper_whisker = F.max(F.when(~outlier, value).otherwise(F.lit(None))) + lower_whisker = F.min(F.when(~outlier, value).otherwise(F.lit(None))) + + # If it shows fliers, take the top 1k with the highest absolute values + # Here we normalize the values by subtracting the median. + if showfliers: + pair = F.when( + outlier, + F.struct(F.abs(value - med), value.alias("val")), + ).otherwise(F.lit(None)) + topk = collect_top_k(pair, 1001, False) + fliers = F.when(F.size(topk) > 0, topk["val"]).otherwise(F.lit(None)) + else: + fliers = F.lit(None) + + result_scols.append( + F.struct( + F.first(mean).alias("mean"), + F.first(med).alias("med"), + F.first(q1).alias("q1"), + F.first(q3).alias("q3"), + upper_whisker.alias("upper_whisker"), + lower_whisker.alias("lower_whisker"), + fliers.alias("fliers"), + ).alias(f"_box_plot_results_{i}") + ) + + sdf_result = sdf.join(sdf_stats.hint("broadcast")).select(*result_scols) + return sdf_result.first() + + +def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + + return _invoke_function_over_columns(name, *cols) + + else: + from pyspark.sql.classic.column import _to_seq, _to_java_column + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column( + sc._jvm.PythonSQLUtils.internalFn( # type: ignore + name, _to_seq(sc, cols, _to_java_column) # type: ignore + ) + ) + + +def collect_top_k(col: Column, num: int, reverse: bool) -> Column: + return _invoke_internal_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py index 91f536346471..71d40720e874 100644 --- a/python/pyspark/sql/plot/plotly.py +++ b/python/pyspark/sql/plot/plotly.py @@ -17,7 +17,8 @@ from typing import TYPE_CHECKING, Any -from pyspark.sql.plot import PySparkPlotAccessor +from pyspark.errors import PySparkValueError +from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -29,6 +30,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": if kind == "pie": return plot_pie(data, **kwargs) + if kind == "box": + return plot_box(data, **kwargs) return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) @@ -43,3 +46,75 @@ def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure": fig = express.pie(pdf, values=y, names=x, **kwargs) return fig + + +def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure": + import plotly.graph_objs as go + + # 'whis' isn't actually an argument in plotly (but in matplotlib). But seems like + # plotly doesn't expose the reach of the whiskers to the beyond the first and + # third quartiles (?). Looks they use default 1.5. + whis = kwargs.pop("whis", 1.5) + # 'precision' is pyspark specific to control precision for approx_percentile + precision = kwargs.pop("precision", 0.01) + colnames = kwargs.pop("column", None) + if isinstance(colnames, str): + colnames = [colnames] + + # Plotly options + boxpoints = kwargs.pop("boxpoints", "suspectedoutliers") + notched = kwargs.pop("notched", False) + if boxpoints not in ["suspectedoutliers", False]: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "boxpoints", + "value": str(boxpoints), + "supported_values": ", ".join(["suspectedoutliers", "False"]), + }, + ) + if notched: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "notched", + "value": str(notched), + "supported_values": ", ".join(["False"]), + }, + ) + + fig = go.Figure() + + results = PySparkBoxPlotBase.compute_box( + data, + colnames, + whis, + precision, + boxpoints is not None, + ) + assert len(results) == len(colnames) # type: ignore + + for i, colname in enumerate(colnames): + result = results[i] # type: ignore + + fig.add_trace( + go.Box( + x=[i], + name=colname, + q1=[result["q1"]], + median=[result["med"]], + q3=[result["q3"]], + mean=[result["mean"]], + lowerfence=[result["lower_whisker"]], + upperfence=[result["upper_whisker"]], + y=[result["fliers"]] if result["fliers"] else None, + boxpoints=boxpoints, + notched=notched, + **kwargs, + ) + ) + + fig["layout"]["yaxis"]["title"] = "value" + return fig diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index b92b5a91cb76..d870cdbf9959 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -19,7 +19,7 @@ from datetime import datetime import pyspark.sql.plot # noqa: F401 -from pyspark.errors import PySparkTypeError +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message @@ -48,6 +48,22 @@ def sdf3(self): columns = ["sales", "signups", "visits", "date"] return self.spark.createDataFrame(data, columns) + @property + def sdf4(self): + data = [ + ("A", 50, 55), + ("B", 55, 60), + ("C", 60, 65), + ("D", 65, 70), + ("E", 70, 75), + # outliers + ("F", 10, 15), + ("G", 85, 90), + ("H", 5, 150), + ] + columns = ["student", "math_score", "english_score"] + return self.spark.createDataFrame(data, columns) + def _check_fig_data(self, fig_data, **kwargs): for key, expected_value in kwargs.items(): if key in ["x", "y", "labels", "values"]: @@ -300,6 +316,65 @@ def test_pie_plot(self): messageParameters={"arg_name": "y", "arg_type": "StringType()"}, ) + def test_box_plot(self): + fig = self.sdf4.plot.box(column="math_score") + expected_fig_data = { + "boxpoints": "suspectedoutliers", + "lowerfence": (5,), + "mean": (50.0,), + "median": (55,), + "name": "math_score", + "notched": False, + "q1": (10,), + "q3": (65,), + "upperfence": (85,), + "x": [0], + "type": "box", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + + fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"]) + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "boxpoints": "suspectedoutliers", + "lowerfence": (55,), + "mean": (72.5,), + "median": (65,), + "name": "english_score", + "notched": False, + "q1": (55,), + "q3": (75,), + "upperfence": (90,), + "x": [1], + "y": [[150, 15]], + "type": "box", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + with self.assertRaises(PySparkValueError) as pe: + self.sdf4.plot.box(column="math_score", boxpoints=True) + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "boxpoints", + "value": "True", + "supported_values": ", ".join(["suspectedoutliers", "False"]), + }, + ) + with self.assertRaises(PySparkValueError) as pe: + self.sdf4.plot.box(column="math_score", notched=True) + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM", + messageParameters={ + "backend": "plotly", + "param": "notched", + "value": "True", + "supported_values": ", ".join(["False"]), + }, + ) + class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass