[SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions #37635
Conversation
Can one of the admins verify this patch?
python/pyspark/sql/types.py
Outdated
>>> from py4j.java_collections import ListConverter
>>> ndarr = np.array([1, 2])
>>> ListConverter().can_convert(ndarr)
True
python/pyspark/sql/types.py
Outdated
The Java type of the array is required in order to create a Java array. So tpe_dict is created to map Python types to Java types.
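A minimal sketch of that idea, assuming an active SparkContext gateway; the tpe_dict entries below are illustrative, not the PR's actual mapping:

from pyspark import SparkContext

gateway = SparkContext._gateway
assert gateway is not None

# Illustrative mapping from Python scalar types to Py4J Java primitive classes.
tpe_dict = {
    int: gateway.jvm.long,
    float: gateway.jvm.double,
}

plist = [1, 2, 3]                            # e.g. obj.tolist()
jtpe = tpe_dict[type(plist[0])]
jarr = gateway.new_array(jtpe, len(plist))   # typed Java array of the right length
for i in range(len(plist)):
    jarr[i] = plist[i]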
Force-pushed from b780b46 to 00370b5
May I get a review? Thanks! @HyukjinKwon @ueshin @zhengruifeng
python/pyspark/sql/types.py
Outdated
Shouldn't we map this type from NumPy dtype?
Since plist = obj.tolist(), plist is a list of Python scalars; see https://numpy.org/doc/stable/reference/generated/numpy.ndarray.tolist.html.
So tpe_dict maps Python types to Java types.
That's consistent with NumpyScalarConverter.convert, which calls obj.item(); see https://numpy.org/doc/stable/reference/generated/numpy.ndarray.item.html.
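For example, a quick illustration of the tolist/item behavior described above:

>>> import numpy as np
>>> arr = np.array([1, 2], dtype=np.int64)
>>> plist = arr.tolist()
>>> type(plist[0])  # elements are plain Python scalars; the NumPy dtype is gone
<class 'int'>
>>> arr[0].item()   # what NumpyScalarConverter.convert relies on
1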
Let me know if there is a better approach :)
Hm, unlike obj.item, where we have to pass a Python primitive type (so the type precision cannot be specified on the JVM side), here we can use a more precise element size in the JVM array.
I think it's better to have the correct type for the elements... Ideally we should make obj.item respect the NumPy dtype too.
And I believe we already have the type mapping defined in pandas API on Spark somewhere, IIRC.
Makes sense!
One limitation is that np.dtype("int8") can't be mapped to gateway.jvm.byte: if we create jarr accordingly and then do the per-element assignment, jarr[i] = plist[i] raises TypeError: 'bytes' object does not support item assignment.
So both int8 and int16 are mapped to gateway.jvm.short.
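A rough reproduction of the limitation, assuming an active SparkContext gateway; the failure is as reported in this comment, not re-verified here:

from pyspark import SparkContext

gateway = SparkContext._gateway
assert gateway is not None

# A Java byte[] created through Py4J surfaces as an immutable bytes-like
# object on the Python side, so per-element assignment fails.
jarr = gateway.new_array(gateway.jvm.byte, 3)
jarr[0] = 1   # TypeError: 'bytes' object does not support item assignment

# Widening to short avoids the problem.
jarr = gateway.new_array(gateway.jvm.short, 3)
jarr[0] = 1   # OK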
Force-pushed from b829383 to bc90498
Rebased to resolve conflicts. Only bc90498 is new after the review.
python/pyspark/sql/types.py
Outdated
gateway = SparkContext._gateway
assert gateway is not None
plist = obj.tolist()
tpe_np_to_java = {
nit, what about moving this dict outside of convert, so it can be reused
We cannot import SparkContext at the module level, and we may want to do a nullability check for SparkContext._gateway. So _from_numpy_type_to_java_type is introduced instead for code reuse. Let me know if you have a better idea :)
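A sketch of what such a helper could look like; the name and signature come from this thread, and aside from the int8/int16 -> short mapping discussed above, the dtype entries are assumptions:

from typing import Optional

import numpy as np
from py4j.java_gateway import JavaClass, JavaGateway


def _from_numpy_type_to_java_type(nt: "np.dtype", gateway: JavaGateway) -> Optional[JavaClass]:
    # int8 is widened to short because a Java byte[] is not element-assignable
    # through Py4J (see the discussion above).
    if nt in (np.dtype("int8"), np.dtype("int16")):
        return gateway.jvm.short
    if nt == np.dtype("int32"):
        return gateway.jvm.int
    if nt == np.dtype("int64"):
        return gateway.jvm.long
    if nt == np.dtype("float32"):
        return gateway.jvm.float
    if nt == np.dtype("float64"):
        return gateway.jvm.double
    # Unsupported dtype: let the caller raise a TypeError.
    return None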
HyukjinKwon left a comment
LGTM except @zhengruifeng's comment.
itholic left a comment
Looks good otherwise
python/pyspark/sql/types.py
Outdated
if jtpe is None:
    raise TypeError("The type of array scalar is not supported")
Can we have a test for this?
oops, yeah. let's add one negative test.
Sounds good!
Optimized the TypeError message as well.
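The negative test could look roughly like this; the test class, the unsupported dtype chosen, and where the error surfaces are assumptions for illustration:

import numpy as np

from pyspark.sql import functions as F
from pyspark.testing.sqlutils import ReusedSQLTestCase


class NumpyArrayConverterTests(ReusedSQLTestCase):
    def test_ndarray_with_unsupported_dtype(self):
        # A complex dtype has no Java primitive counterpart (assumed),
        # so converting the ndarray should raise a TypeError.
        with self.assertRaises(TypeError):
            F.lit(np.array([1 + 2j, 3 + 4j]))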
python/pyspark/sql/types.py
Outdated
return None


def _from_numpy_type_to_java_type(nt: "np.dtype", gateway: JavaGateway) -> Optional[JavaClass]:
You can actually add this as a NumpyArrayConverter class attribute.
Did you mean an instance method?
Thank you all! Merged to master.
What changes were proposed in this pull request?
Support NumPy ndarray in built-in functions (pyspark.sql.functions) by introducing a Py4J input converter, NumpyArrayConverter. The converter converts an ndarray to a Java array. The mapping between ndarray dtype and Java primitive type is defined in the converter; as noted in the review above, both int8 and int16 map to short.
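A sketch of the general shape of such a converter and its registration, using Py4J's input-converter protocol (can_convert/convert) and register_input_converter; the details here are illustrative rather than the PR's exact implementation:

import numpy as np
from py4j.protocol import register_input_converter


class NumpyArrayConverter:
    def can_convert(self, obj):
        # Only handle one-dimensional ndarrays in this sketch (assumption).
        return isinstance(obj, np.ndarray) and obj.ndim == 1

    def convert(self, obj, gateway_client):
        from pyspark import SparkContext

        gateway = SparkContext._gateway
        assert gateway is not None
        # _from_numpy_type_to_java_type is the dtype -> Java type helper
        # sketched earlier in the thread.
        jtpe = _from_numpy_type_to_java_type(obj.dtype, gateway)
        if jtpe is None:
            raise TypeError("The type of array scalar is not supported")
        plist = obj.tolist()
        jarr = gateway.new_array(jtpe, len(plist))
        for i in range(len(plist)):
            jarr[i] = plist[i]
        return jarr


# Register ahead of Py4J's default ListConverter, which would otherwise
# also claim ndarrays (see the can_convert doctest earlier in the thread).
register_input_converter(NumpyArrayConverter(), prepend=True)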
Why are the changes needed?
As part of SPARK-39405 for NumPy support in SQL.
Does this PR introduce any user-facing change?
Yes. NumPy ndarray is supported in built-in functions.
Take lit for example.
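A sketch of the usage this enables, assuming an active SparkSession named spark (the resulting column type is not shown here):

>>> import numpy as np
>>> from pyspark.sql.functions import lit
>>> df = spark.range(1).select(lit(np.array([1, 2, 3])))  # ndarray literal becomes an array column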
How was this patch tested?
Unit tests.