Commit 8682bb1
[SPARK-29627][PYTHON][SQL] Allow array_contains to take column instances
### What changes were proposed in this pull request?

This PR proposes to allow `array_contains` to take column instances.

### Why are the changes needed?

For consistent support in Scala and Python APIs. Scala allows column instances at `array_contains`.

Scala:

```scala
import org.apache.spark.sql.functions._
val df = Seq(Array("a", "b", "c"), Array.empty[String]).toDF("data")
df.select(array_contains($"data", lit("a"))).show()
```

Python:

```python
from pyspark.sql.functions import array_contains, lit
df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
df.select(array_contains(df.data, lit("a"))).show()
```

However, the PySpark side does not allow this.

### Does this PR introduce any user-facing change?

Yes.

```python
from pyspark.sql.functions import array_contains, lit
df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
df.select(array_contains(df.data, lit("a"))).show()
```

**Before:**

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/functions.py", line 1950, in array_contains
    return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
  File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_gateway.py", line 1277, in __call__
  File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_gateway.py", line 1241, in _build_args
  File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_gateway.py", line 1228, in _get_args
  File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_collections.py", line 500, in convert
  File "/.../spark/python/pyspark/sql/column.py", line 344, in __iter__
    raise TypeError("Column is not iterable")
TypeError: Column is not iterable
```

**After:**

```
+-----------------------+
|array_contains(data, a)|
+-----------------------+
|                   true|
|                  false|
+-----------------------+
```

### How was this patch tested?

Manually tested and added a doctest.

Closes #26288 from HyukjinKwon/SPARK-29627.
Authored-by: HyukjinKwon <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 8e667db commit 8682bb1
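The "Before" traceback shows why the call failed: Py4J, unable to map the `Column` to a Java primitive, fell back to treating it as a Python collection (`java_collections.py` → `convert`) and called `iter()` on it, which `Column.__iter__` deliberately blocks. A minimal standalone reproduction of that guard pattern (a sketch, not the actual pyspark code; `Guarded` is a hypothetical stand-in):

```python
class Guarded:
    """Mimics the guard in pyspark.sql.Column.__iter__, which blocks
    accidental iteration (e.g. by Py4J's collection converter)."""

    def __iter__(self):
        raise TypeError("Column is not iterable")


try:
    list(Guarded())  # any attempt to iterate trips the guard
except TypeError as e:
    print(e)  # prints: Column is not iterable
```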

File tree

1 file changed: +4 -1 lines changed


python/pyspark/sql/functions.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -1940,13 +1940,16 @@ def array_contains(col, value):
     given value, and false otherwise.

     :param col: name of column containing array
-    :param value: value to check for in array
+    :param value: value or column to check for in array

     >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
     >>> df.select(array_contains(df.data, "a")).collect()
     [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
+    >>> df.select(array_contains(df.data, lit("a"))).collect()
+    [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
     """
     sc = SparkContext._active_spark_context
+    value = value._jc if isinstance(value, Column) else value
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
```
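The core of the fix is a one-line coercion: if `value` is a `Column`, unwrap its underlying JVM handle (`_jc`) before passing it to Py4J; otherwise pass the Python literal through unchanged. The pattern can be sketched without a running Spark session; `FakeColumn` and `fake_array_contains` below are hypothetical stand-ins for illustration, not PySpark API:

```python
# Sketch of the coercion pattern added in this commit, using hypothetical
# stand-ins (FakeColumn, fake_array_contains) instead of a live Spark session.

class FakeColumn:
    """Stand-in for pyspark.sql.Column: wraps a 'JVM-side' value in _jc."""

    def __init__(self, jc):
        self._jc = jc


def fake_array_contains(data, value):
    # The fix: unwrap a column to its JVM handle, pass literals through as-is.
    value = value._jc if isinstance(value, FakeColumn) else value
    return [value in row for row in data]


data = [["a", "b", "c"], []]
print(fake_array_contains(data, "a"))              # prints [True, False]
print(fake_array_contains(data, FakeColumn("a")))  # prints [True, False]
```

Both call sites now behave identically, mirroring how the patched `array_contains` accepts either a plain value or a `Column` instance.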