23 changes: 23 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
@@ -1033,6 +1033,29 @@ def test_np_scalar_input(self):
        res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
        self.assertEqual([Row(c=1), Row(c=0)], res)

    @unittest.skipIf(not have_numpy, "NumPy not installed")
    def test_ndarray_input(self):
        import numpy as np

        arr_dtype_to_spark_dtypes = [
            ("int8", [("b", "array<smallint>")]),
            ("int16", [("b", "array<smallint>")]),
            ("int32", [("b", "array<int>")]),
            ("int64", [("b", "array<bigint>")]),
            ("float32", [("b", "array<float>")]),
            ("float64", [("b", "array<double>")]),
        ]
        for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes:
            arr = np.array([1, 2]).astype(t)
            self.assertEqual(
                expected_spark_dtypes, self.spark.range(1).select(lit(arr).alias("b")).dtypes
            )
        arr = np.array([1, 2]).astype(np.uint)
        with self.assertRaisesRegex(
            TypeError, "The type of array scalar '%s' is not supported" % arr.dtype
        ):
            self.spark.range(1).select(lit(arr).alias("b"))

    def test_binary_math_function(self):
        funcs, expected = zip(*[(atan2, 0.13664), (hypot, 8.07527), (pow, 2.14359), (pmod, 1.1)])
        df = self.spark.range(1).select(*(func(1.1, 8) for func in funcs))
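For context, a minimal sketch of the behavior the new test pins down (hypothetical session; assumes an active SparkSession bound to `spark` and `import numpy as np`):

>>> from pyspark.sql.functions import lit
>>> arr = np.array([1, 2], dtype="int64")
>>> spark.range(1).select(lit(arr).alias("b")).dtypes
[('b', 'array<bigint>')]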
49 changes: 48 additions & 1 deletion python/pyspark/sql/types.py
@@ -46,7 +46,7 @@
)

from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject

from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.utils import has_numpy
@@ -2268,12 +2268,59 @@ def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
        return obj.item()


class NumpyArrayConverter:
    def _from_numpy_type_to_java_type(
        self, nt: "np.dtype", gateway: JavaGateway
    ) -> Optional[JavaClass]:
        """Convert NumPy type to Py4J Java type."""
        if nt in [np.dtype("int8"), np.dtype("int16")]:
            # Mapping int8 to gateway.jvm.byte causes
            # TypeError: 'bytes' object does not support item assignment
            return gateway.jvm.short
        elif nt == np.dtype("int32"):
            return gateway.jvm.int
        elif nt == np.dtype("int64"):
            return gateway.jvm.long
        elif nt == np.dtype("float32"):
            return gateway.jvm.float
        elif nt == np.dtype("float64"):
            return gateway.jvm.double
        elif nt == np.dtype("bool"):
            return gateway.jvm.boolean

        return None

    def can_convert(self, obj: Any) -> bool:
        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1

    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
        from pyspark import SparkContext

        gateway = SparkContext._gateway
        assert gateway is not None
        plist = obj.tolist()

        if len(obj) > 0 and isinstance(plist[0], str):
            jtpe = gateway.jvm.String
        else:
            jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
            if jtpe is None:
                raise TypeError("The type of array scalar '%s' is not supported" % (obj.dtype))
        jarr = gateway.new_array(jtpe, len(obj))
        for i in range(len(plist)):
            jarr[i] = plist[i]
        return jarr
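A hedged illustration of the int8/int16 workaround noted in the comment inside _from_numpy_type_to_java_type: through a bare py4j gateway (assumes a GatewayServer is already running; not part of this patch), a Java byte[] surfaces in Python as an immutable bytes object, while short[] supports item assignment:

>>> from py4j.java_gateway import JavaGateway
>>> gateway = JavaGateway()
>>> jbytes = gateway.new_array(gateway.jvm.byte, 2)
>>> jbytes[0] = 1
Traceback (most recent call last):
  ...
TypeError: 'bytes' object does not support item assignment
>>> jshorts = gateway.new_array(gateway.jvm.short, 2)
>>> jshorts[0] = 1  # fine, hence the int8/int16 -> short mapping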


# datetime is a subclass of date, we should register DatetimeConverter first
register_input_converter(DatetimeNTZConverter())
register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())
register_input_converter(DayTimeIntervalTypeConverter())
register_input_converter(NumpyScalarConverter())
# NumPy array satisfies py4j.java_collections.ListConverter,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> from py4j.java_collections import ListConverter
>>> ndarr = np.array([1, 2])
>>> ListConverter().can_convert(ndarr)
True

# so prepend NumpyArrayConverter
register_input_converter(NumpyArrayConverter(), prepend=True)
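Why prepend=True matters: py4j tries its registered input converters in order and uses the first whose can_convert returns True, and ListConverter (as the doctest above shows) already claims a 1-D ndarray. A rough sketch of that lookup, assuming py4j's module-level INPUT_CONVERTER registry (the list that register_input_converter mutates):

>>> from py4j.protocol import INPUT_CONVERTER
>>> next(c for c in INPUT_CONVERTER if c.can_convert(np.array([1, 2])))
<pyspark.sql.types.NumpyArrayConverter object at ...>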


def _test() -> None: