10 changes: 6 additions & 4 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -71,6 +71,7 @@
StringType,
)
from pyspark.sql.utils import enum_to_value as _enum_to_value
from pyspark.util import JVM_INT_MAX

# The implementation of pandas_udf is embedded in pyspark.sql.functions.pandas_udf
# for code reuse.
@@ -1126,11 +1127,12 @@ def grouping_id(*cols: "ColumnOrName") -> Column:

def count_min_sketch(
col: "ColumnOrName",
eps: "ColumnOrName",
confidence: "ColumnOrName",
seed: "ColumnOrName",
eps: Union[Column, float],
confidence: Union[Column, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
_seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed)
return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed)


count_min_sketch.__doc__ = pysparkfuncs.count_min_sketch.__doc__
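A minimal usage sketch of the new Connect-side signature, assuming an active SparkSession named `spark` (hypothetical values, not part of the diff). Note the behavioral detail in the change above: when `seed` is omitted, the client draws a random seed locally (between 0 and JVM_INT_MAX) and embeds it in the plan, so the server always receives an explicit seed:

    from pyspark.sql import functions as sf

    df = spark.range(100)

    # Explicit seed: deterministic, reproducible across runs.
    deterministic = df.select(sf.hex(sf.count_min_sketch("id", 0.5, 0.9, 42)))

    # Omitted seed: a fresh random seed is drawn per call, so two
    # invocations generally produce different sketches.
    randomized = df.select(sf.hex(sf.count_min_sketch("id", 0.5, 0.9)))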
71 changes: 59 additions & 12 deletions python/pyspark/sql/functions/builtin.py
@@ -6015,9 +6015,9 @@ def grouping_id(*cols: "ColumnOrName") -> Column:
@_try_remote_functions
def count_min_sketch(
col: "ColumnOrName",
eps: "ColumnOrName",
confidence: "ColumnOrName",
seed: "ColumnOrName",
eps: Union[Column, float],
confidence: Union[Column, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
"""
Returns a count-min sketch of a column with the given eps, confidence and seed.
@@ -6031,26 +6031,73 @@
----------
col : :class:`~pyspark.sql.Column` or str
target column to compute on.
eps : :class:`~pyspark.sql.Column` or str
eps : :class:`~pyspark.sql.Column` or float
relative error, must be positive
confidence : :class:`~pyspark.sql.Column` or str

.. versionchanged:: 4.0.0
`eps` now accepts float value.

confidence : :class:`~pyspark.sql.Column` or float
confidence, must be positive and less than 1.0
seed : :class:`~pyspark.sql.Column` or str

.. versionchanged:: 4.0.0
`confidence` now accepts float value.

seed : :class:`~pyspark.sql.Column` or int, optional
random seed

.. versionchanged:: 4.0.0
`seed` now accepts int value.

Returns
-------
:class:`~pyspark.sql.Column`
count-min sketch of the column

Examples
--------
>>> df = spark.createDataFrame([[1], [2], [1]], ['data'])
>>> df = df.agg(count_min_sketch(df.data, lit(0.5), lit(0.5), lit(1)).alias('sketch'))
>>> df.select(hex(df.sketch).alias('r')).collect()
[Row(r='0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000')]
"""
return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
Example 1: Using columns as arguments

>>> from pyspark.sql import functions as sf
>>> spark.range(100).select(
... sf.hex(sf.count_min_sketch(sf.col("id"), sf.lit(3.0), sf.lit(0.1), sf.lit(1)))
... ).show(truncate=False)
+------------------------------------------------------------------------+
|hex(count_min_sketch(id, 3.0, 0.1, 1)) |
+------------------------------------------------------------------------+
|0000000100000000000000640000000100000001000000005D8D6AB90000000000000064|
+------------------------------------------------------------------------+

Example 2: Using numbers as arguments

>>> from pyspark.sql import functions as sf
>>> spark.range(100).select(
... sf.hex(sf.count_min_sketch("id", 1.0, 0.3, 2))
... ).show(truncate=False)
+----------------------------------------------------------------------------------------+
|hex(count_min_sketch(id, 1.0, 0.3, 2)) |
+----------------------------------------------------------------------------------------+
|0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032|
+----------------------------------------------------------------------------------------+

Example 3: Using a random seed

>>> from pyspark.sql import functions as sf
>>> spark.range(100).select(
... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6))
... ).show(truncate=False) # doctest: +SKIP
+----------------------------------------------------------------------------------------------------------------------------------------+
|hex(count_min_sketch(id, 1.5, 0.6, 2120704260)) |
+----------------------------------------------------------------------------------------------------------------------------------------+
|0000000100000000000000640000000200000002000000005ADECCEE00000000153EBE090000000000000033000000000000003100000000000000320000000000000032|
+----------------------------------------------------------------------------------------------------------------------------------------+
""" # noqa: E501
_eps = lit(eps)
_conf = lit(confidence)
if seed is None:
return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf)
else:
return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf, lit(seed))


@_try_remote_functions
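Because the new implementation wraps every argument in `lit()`, and `lit()` passes `Column` inputs through unchanged, scalar and `Column` arguments can be mixed freely. A small sketch, again assuming an active session `spark`:

    from pyspark.sql import functions as sf

    df = spark.range(100)

    # These two calls build the same expression:
    # count_min_sketch(id, 1.0, 0.3, 2)
    a = df.select(sf.count_min_sketch("id", 1.0, 0.3, 2))
    b = df.select(sf.count_min_sketch("id", sf.lit(1.0), 0.3, sf.lit(2)))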
12 changes: 12 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -389,6 +389,18 @@ object functions {
def count_min_sketch(e: Column, eps: Column, confidence: Column, seed: Column): Column =
Column.fn("count_min_sketch", e, eps, confidence, seed)

/**
* Returns a count-min sketch of a column with the given eps, confidence and seed. The result is
* an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min
* sketch is a probabilistic data structure used for frequency estimation using sub-linear
* space.
*
* @group agg_funcs
* @since 4.0.0
*/
def count_min_sketch(e: Column, eps: Column, confidence: Column): Column =
count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt))
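The Scaladoc above mentions that the bytes can be deserialized to a `CountMinSketch`; a hypothetical round trip with the new overload, assuming a local SparkSession `spark` (illustration only, not part of this change):

    import org.apache.spark.sql.functions.{col, count_min_sketch, lit}
    import org.apache.spark.util.sketch.CountMinSketch

    // Seed omitted, so a random one is chosen when the Column is built.
    val bytes = spark.range(100)
      .agg(count_min_sketch(col("id"), lit(1.0), lit(0.9)))
      .head()
      .getAs[Array[Byte]](0)

    val sketch = CountMinSketch.readFrom(bytes)
    println(sketch.estimateCount(42L)) // each id occurs once, so expect >= 1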
Review comment (Member):

    Would need a connect version as well.

Reply, @zhengruifeng (Contributor, Author), Sep 19, 2024:

    Good catch! Let me add one for the Scala client.

Reply, @zhengruifeng (Contributor, Author), Sep 19, 2024:

    @HyukjinKwon it seems we don't have a separate functions.scala for the Scala client now, after a lot of refactoring?
private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
