Conversation

@xinrong-meng
Member

@xinrong-meng xinrong-meng commented Aug 24, 2022

What changes were proposed in this pull request?

Support NumPy ndarray in built-in functions (pyspark.sql.functions) by introducing the Py4J input converter NumpyArrayConverter. The converter converts an ndarray to a Java array.

The mapping between ndarray dtypes and Java primitive types is defined as follows:

            np.dtype("int64"): gateway.jvm.long,
            np.dtype("int32"): gateway.jvm.int,
            np.dtype("int16"): gateway.jvm.short,
            # Mapping to gateway.jvm.byte causes
            #   TypeError: 'bytes' object does not support item assignment
            np.dtype("int8"): gateway.jvm.short,
            np.dtype("float32"): gateway.jvm.float,
            np.dtype("float64"): gateway.jvm.double,
            np.dtype("bool"): gateway.jvm.boolean,
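
For illustration, the lookup above can be sketched without a live JVM by mapping dtypes to Java type names instead of gateway.jvm classes. The dict and function names below are hypothetical stand-ins, not the PR's actual code:

```python
import numpy as np

# Hypothetical stand-in for the converter's dtype lookup: values are Java
# type names rather than gateway.jvm classes, so this runs without a JVM.
NP_TO_JAVA_NAME = {
    np.dtype("int64"): "long",
    np.dtype("int32"): "int",
    np.dtype("int16"): "short",
    np.dtype("int8"): "short",  # byte is avoided; see the comment above
    np.dtype("float32"): "float",
    np.dtype("float64"): "double",
    np.dtype("bool"): "boolean",
}

def java_element_type(arr: np.ndarray) -> str:
    jtpe = NP_TO_JAVA_NAME.get(arr.dtype)
    if jtpe is None:
        raise TypeError("The type of array scalar %s is not supported" % arr.dtype)
    return jtpe
```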

Why are the changes needed?

As part of SPARK-39405 for NumPy support in SQL.

Does this PR introduce any user-facing change?

Yes. NumPy ndarray is supported in built-in functions.

Take lit for example,

>>> spark.range(1).select(lit(np.array([1, 2], dtype='int16'))).dtypes
[('ARRAY(1S, 2S)', 'array<smallint>')]
>>> spark.range(1).select(lit(np.array([1, 2], dtype='int32'))).dtypes
[('ARRAY(1, 2)', 'array<int>')]
>>> spark.range(1).select(lit(np.array([1, 2], dtype='float32'))).dtypes
[("ARRAY(CAST('1.0' AS FLOAT), CAST('2.0' AS FLOAT))", 'array<float>')]
>>> spark.range(1).select(lit(np.array([]))).dtypes
[('ARRAY()', 'array<double>')]
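
The last example follows from NumPy's default dtype: an empty np.array([]) is float64, which Spark maps to array&lt;double&gt;:

```python
import numpy as np

# np.array([]) defaults to float64, hence array<double> in the lit example.
print(np.array([]).dtype)  # float64
```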

How was this patch tested?

Unit tests.

@xinrong-meng xinrong-meng changed the title [WIP] Support NumPy arrays in built-in functions [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions Aug 24, 2022
@AmplabJenkins

Can one of the admins verify this patch?

Member Author

>>> from py4j.java_collections import ListConverter
>>> ndarr = np.array([1, 2])
>>> ListConverter().can_convert(ndarr)
True
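
This matters because Py4J dispatches to the first registered converter whose can_convert accepts the object. A toy model (class names here are illustrative, not Py4J's) shows why NumpyArrayConverter must be registered ahead of ListConverter to win the dispatch for ndarrays:

```python
import numpy as np

# Toy model of Py4J's converter dispatch: the first registered converter
# whose can_convert() accepts the object handles the conversion.
class ToyNumpyArrayConverter:
    def can_convert(self, obj):
        return isinstance(obj, np.ndarray)

class ToyListConverter:
    def can_convert(self, obj):
        # Roughly why an ndarray also passes ListConverter.can_convert.
        return hasattr(obj, "__iter__")

converters = [ToyNumpyArrayConverter(), ToyListConverter()]

def pick(obj):
    return next(type(c).__name__ for c in converters if c.can_convert(obj))

print(pick(np.array([1, 2])))  # ToyNumpyArrayConverter
print(pick([1, 2]))            # ToyListConverter
```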

Member Author

The Java type of the array is required in order to create a Java array. So tpe_dict is created to map Python types to Java types.

@xinrong-meng xinrong-meng changed the title [WIP][SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions Aug 25, 2022
@xinrong-meng xinrong-meng marked this pull request as ready for review August 25, 2022 19:09
@xinrong-meng
Member Author

May I get a review? Thanks! @HyukjinKwon @ueshin @zhengruifeng

Member

Shouldn't we map this type from NumPy dtype?

Member Author

@xinrong-meng xinrong-meng Aug 26, 2022

Since plist = obj.tolist(), plist is a list of Python scalars, see https://numpy.org/doc/stable/reference/generated/numpy.ndarray.tolist.html.

So tpe_dict maps Python types to Java types.

That's consistent with NumpyScalarConverter.convert which calls obj.item(), see https://numpy.org/doc/stable/reference/generated/numpy.ndarray.item.html.
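
The dtype loss is easy to see: both tolist() and item() hand back plain built-in Python scalars, so the NumPy dtype is gone by the time the elements are inspected:

```python
import numpy as np

# tolist() and item() both return built-in Python scalars, dropping the
# NumPy dtype, which is why the mapping keys off Python types.
arr = np.array([1, 2], dtype="int16")
print(type(arr.tolist()[0]))     # <class 'int'>, not np.int16
print(type(np.int16(7).item()))  # <class 'int'>
```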

Member Author

Let me know if there is a better approach :)

Member

Hm, unlike obj.item, which returns a Python primitive type (so the JVM-side type precision cannot be specified), here we can use the more correct element size in the JVM array.

I think it's better to have the correct type in the element ... Ideally we should make obj.item respect the NumPy dtype too.

Member

and I believe we already have the type mapping defined in pandas API on Spark somewhere IIRC

Member Author

Makes sense!

One limitation is that np.dtype("int8") can't be mapped to gateway.jvm.byte: when we create jarr with that element type and then do the per-element assignment, jarr[i] = plist[i] raises TypeError: 'bytes' object does not support item assignment.

So both int8 and int16 are mapped to gateway.jvm.short.
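
The failure mode can be reproduced without a JVM, assuming the jvm.byte array surfaces on the Python side as an immutable bytes object (which the error message suggests):

```python
# A plain bytes object reproduces the error seen with gateway.jvm.byte:
# bytes is immutable, so per-element assignment fails.
buf = bytes(2)
try:
    buf[0] = 1
except TypeError as e:
    print(e)  # 'bytes' object does not support item assignment
```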

@xinrong-meng
Member Author

Rebased to resolve conflicts. Only bc90498 is new after the review.

gateway = SparkContext._gateway
assert gateway is not None
plist = obj.tolist()
tpe_np_to_java = {
Contributor

nit, what about moving this dict outside of convert, so it can be reused

Member Author

@xinrong-meng xinrong-meng Sep 5, 2022

We cannot import SparkContext at the module level. And we may want to do a nullability check on SparkContext._gateway. So _from_numpy_type_to_java_type is introduced instead for code reuse. Let me know if you have a better idea :)

Member

@HyukjinKwon HyukjinKwon left a comment

LGTM except @zhengruifeng's comment.

Contributor

@itholic itholic left a comment

Looks good otherwise

Comment on lines 2306 to 2307
if jtpe is None:
raise TypeError("The type of array scalar is not supported")
Contributor

Can we have a test for this ?

Member

oops, yeah. let's add one negative test.

Member Author

Sounds good!
Optimized the TypeError message as well.

return None


def _from_numpy_type_to_java_type(nt: "np.dtype", gateway: JavaGateway) -> Optional[JavaClass]:
Member

You can actually add this as a NumpyArrayConverter's class attribute

Member Author

Did you mean an instance method?

@xinrong-meng
Member Author

Thank you all! Merged to master.
