Skip to content

Commit 5bf4a29

Browse files
committed
[SPARK-53433][PYTHON][TESTS] Add test for Arrow UDF with VariantType
### What changes were proposed in this pull request?

Add a test for Arrow UDF with VariantType.

### Why are the changes needed?

Arrow UDFs natively support all datatypes that are Arrow-compatible, so they should support VariantType. This PR adds tests to guard it.

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52172 from zhengruifeng/array_test_variant.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 5b2c4cf commit 5bf4a29

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,66 @@ def build_time(h, mi, s):
489489
result.collect(),
490490
)
491491

492+
def test_arrow_udf_input_variant(self):
    """A VariantType column reaches an Arrow UDF as a StructArray with binary children.

    Both the scalar and the iterator UDF styles should observe the variant as a
    struct of two binary fields ("metadata" and "value") and be able to compute
    on those children.
    """
    import pyarrow as pa

    @arrow_udf("int")
    def scalar_value_len(batch: pa.Array) -> pa.Array:
        # The variant column is materialized as struct<metadata: binary, value: binary>.
        assert isinstance(batch, pa.Array)
        assert isinstance(batch, pa.StructArray)
        assert isinstance(batch.field("metadata"), pa.BinaryArray)
        assert isinstance(batch.field("value"), pa.BinaryArray)
        return pa.compute.binary_length(batch.field("value"))

    @arrow_udf("int")
    def iter_value_len(batches: Iterator[pa.Array]) -> Iterator[pa.Array]:
        # Same checks as the scalar form, applied per incoming batch.
        for batch in batches:
            assert isinstance(batch, pa.Array)
            assert isinstance(batch, pa.StructArray)
            assert isinstance(batch.field("metadata"), pa.BinaryArray)
            assert isinstance(batch.field("value"), pa.BinaryArray)
            yield pa.compute.binary_length(batch.field("value"))

    df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
    # Every single-digit integer encodes to a 2-byte variant value.
    expected = [Row(l=2)] * 10

    for udf in (scalar_value_len, iter_value_len):
        self.assertEqual(df.select(udf("v").alias("l")).collect(), expected)
519+
def test_arrow_udf_output_variant(self):
    """An Arrow UDF can return VariantType by producing a matching StructArray.

    Mirrors test_udf_with_variant_output in test_pandas_udf_scalar: the UDF
    hand-builds variant-encoded binary "value"/"metadata" children and Spark
    reads them back as variant values.
    """
    import pyarrow as pa

    # Mirrors to_arrow_type in pyspark.sql.pandas.types: a variant is a struct of
    # two non-nullable binary fields, with the b"variant" marker on "metadata".
    fields = [
        pa.field("value", pa.binary(), nullable=False),
        pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
    ]
    variant_type = pa.struct(fields)

    @arrow_udf("variant")
    def scalar_f(v: pa.Array) -> pa.Array:
        assert isinstance(v, pa.Array)
        # Header byte 12 then the value byte — presumably the variant int8
        # primitive encoding; see the variant spec to confirm.
        v = pa.array([bytes([12, i.as_py()]) for i in v], pa.binary())
        # Minimal variant metadata blob (same for every row).
        m = pa.array([bytes([1, 0, 0]) for i in v], pa.binary())
        return pa.StructArray.from_arrays([v, m], type=variant_type)

    @arrow_udf("variant")
    def iter_f(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
        for v in it:
            assert isinstance(v, pa.Array)
            # Pass the explicit binary type, consistent with scalar_f above,
            # instead of relying on type inference.
            v = pa.array([bytes([12, i.as_py()]) for i in v], pa.binary())
            m = pa.array([bytes([1, 0, 0]) for i in v], pa.binary())
            yield pa.StructArray.from_arrays([v, m], type=variant_type)

    df = self.spark.range(0, 10)
    # Casting the returned variant back to int should round-trip each id.
    expected = [Row(l=i) for i in range(10)]

    for f in [scalar_f, iter_f]:
        result = df.select(f("id").cast("int").alias("l")).collect()
        self.assertEqual(result, expected)
551+
492552
def test_arrow_udf_null_boolean(self):
493553
data = [(True,), (True,), (None,), (False,)]
494554
schema = StructType().add("bool", BooleanType())

0 commit comments

Comments
 (0)