[SPARK-28736][SPARK-28735][PYTHON][ML][TESTS] Fix PySpark ML tests to pass in JDK 11 #25475
Conversation
cc @WeichenXu123, @srowen, @dongjoon-hyun, this fixes PySpark tests on JDK 11.
Wow. Thank you, @HyukjinKwon!
    True
    >>> model.predict([-0.1,-0.05])
    0
    >>> softPredicted = model.predictSoft([-0.1,-0.05])
For instance, weights within Gaussian mixture model:
JDK 8
weights: WrappedArray(0.49520257460263445, 0.33813075873069875, 0.16666666666666685)
JDK 11
weights: WrappedArray(0.5000000000000001, 0.33333333333333326, 0.16666666666666666)
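As a sketch (not part of the PR itself), the two sets of weights quoted above can be compared with an approximate check rather than exact equality; a loosened absolute tolerance absorbs the cross-JDK difference (`atol=1e-2` here is an illustrative choice, not the value used in the PR):

```python
import numpy as np

# GaussianMixture weights reported in this thread for the same test on two JDKs.
jdk8_weights = [0.49520257460263445, 0.33813075873069875, 0.16666666666666685]
jdk11_weights = [0.5000000000000001, 0.33333333333333326, 0.16666666666666666]

# A tight tolerance fails across JDKs (differences are ~0.005) ...
print(np.allclose(jdk8_weights, jdk11_weights, atol=1e-4))  # False

# ... while a loosened tolerance passes.
print(np.allclose(jdk8_weights, jdk11_weights, atol=1e-2))  # True
```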
Also probably OK for the same reason. The test was too specific.
Test build #109210 has finished for PR 25475 at commit
      self.assertTrue(result.prediction, expected_prediction)
      self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4))
    - self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4))
    + self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1))
Is 1 the minimum difference?
Yup ..
JDK 8:
[-11.19194106875243,-7.677866573997363,21.280214474039443]
JDK 11:
[-11.608192299802019,-8.158279986906651,22.177570449962918]
It seems small float differences affect the results across JDKs, while the values are still roughly correct.
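As a quick sketch using the rawPrediction values quoted above, the tightened/loosened tolerances behave as discussed: the cross-JDK differences are around 0.4 to 0.9, so `atol=1E-4` fails while `atol=1` passes:

```python
import numpy as np

# rawPrediction values reported in this thread for the same test on two JDKs.
jdk8 = [-11.19194106875243, -7.677866573997363, 21.280214474039443]
jdk11 = [-11.608192299802019, -8.158279986906651, 22.177570449962918]

print(np.max(np.abs(np.subtract(jdk8, jdk11))))  # largest difference, ~0.9
print(np.allclose(jdk8, jdk11, atol=1E-4))       # False: too strict across JDKs
print(np.allclose(jdk8, jdk11, atol=1))          # True: the loosened bound passes
```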
I'm not sure where the difference comes from, but it could be subtle differences in randomization or something across the JDKs. If these two tests are the only ones that vary, I think we're OK. I agree with loosening the bound here as these are log-odds, and I suspect the test values were picked just because they are what some previous run spit out (that is, the test was too specific).
+1. This PR looks reasonable and good to me.
I'm going to just merge it. This is a test-only PR and can always be fixed later. I roughly checked with @WeichenXu123 offline as well.
Merged to master.
What changes were proposed in this pull request?
This PR proposes to fix both tests below so that they pass on JDK 11.
The root cause seems to be different float values being produced and then transferred via Py4J. This issue was also found in #25132 before.
When floats are transferred from Python to the JVM, the values are sent as they are. Python floats are not "precise" due to their representation limits - https://docs.python.org/3/tutorial/floatingpoint.html.
For some reason, the floats produced on JDK 8 and JDK 11 differ slightly, which is already explicitly not guaranteed to be consistent.
This seems to be why only some PySpark tests involving floats fail.
So, this PR fixes it by increasing the tolerance in the identified test cases in PySpark.
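The float representation issue referenced above (from the linked Python tutorial) can be illustrated in plain Python; this is a generic sketch, not code from the PR:

```python
import math

# Python floats are IEEE-754 doubles, so many decimal literals are inexact.
print(0.1 + 0.2)         # 0.30000000000000004
print(0.1 + 0.2 == 0.3)  # False: exact equality is fragile for floats

# Tolerance-based comparison, as the fixed tests now do, is robust.
print(math.isclose(0.1 + 0.2, 0.3, abs_tol=1e-9))  # True
```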
Why are the changes needed?
To fully support JDK 11. See, for instance, #25443 and #25423 for ongoing efforts.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
Manually tested as described in the JIRAs: