[SPARK-11968][ML][MLLIB] Optimize MLLIB ALS recommendForAll #17742
Conversation
Test build #76100 has finished for PR 17742 at commit
Interesting - I was working on something very similar - a rough draft of it is in a branch.

Test build #76111 has finished for PR 17742 at commit

Could you post updated performance numbers? I think we can do the same optimization in
    k += 1
    }
    output.toSeq
    case (users, items) =>

Put the case statement on the previous line: `flatMap { case (... =>`
    k += 1
    }
    output.toSeq
    case (users, items) =>

Prefer `case (srcIter, dstIter)` rather than users / items (as they can be swapped depending on which recommendation method is being called).
    val n = math.min(items.size, num)
    val output = new Array[(Int, (Int, Double))](m * n)
    var j = 0
    users.foreach (user => {

`srcIter.foreach { case (srcId, srcFactor) =>`
    def order(a: (Int, Double)) = a._2
    val pq: BoundedPriorityQueue[(Int, Double)] =
      new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(order))
    items.foreach (item => {

Similarly here: `dstIter.foreach { case (dstId, dstFactor) =>`
    var rate: Double = 0
    var k = 0
    while(k < rank) {
    rate += user._2(k) * item._2(k)

Then we can have `rate += srcFactor(k) * dstFactor(k)`.

Also, can we call it `score` or `prediction` rather than `rate`?
    val n = math.min(items.size, num)
    val output = new Array[(Int, (Int, Double))](m * n)
    var j = 0
    users.foreach (user => {

Will there be a performance benefit to using a while loop here vs foreach?

I will test while here, thanks.
    */
    var rate: Double = 0
    var k = 0
    while(k < rank) {

Space here: `while (`
    })
    val pqIter = pq.iterator
    var i = 0
    while(i < n) {

Space here: `while (`
    rate += user._2(k) * item._2(k)
    k += 1
    }
    pq += ((item._1, rate))

Here we can then use `dstId` instead.
    val pqIter = pq.iterator
    var i = 0
    while(i < n) {
    output(j + i) = (user._1, pqIter.next())

And here `srcId` instead.
Test build #76176 has finished for PR 17742 at commit

Hi @MLnick, the new test results are:

Another case

Thanks very much @MLnick.

Test build #76180 has finished for PR 17742 at commit

@mpjlu yeah, we can do the ML version in a follow-up PR, that is ok (I can help if needed).
    users.foreach (user => {
    srcIter.foreach { case (srcId, srcFactor) =>
    def order(a: (Int, Double)) = a._2
    val pq: BoundedPriorityQueue[(Int, Double)] =

We could remove the type sig from the val definition here to make it fit on one line.
    srcIter.foreach { case (srcId, srcFactor) =>
    def order(a: (Int, Double)) = a._2
    val pq: BoundedPriorityQueue[(Int, Double)] =
      new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(order))

I believe you can just do `Ordering.by(_._2)` without needing to define `def order(...` above.
    * blas.ddot (F2jBLAS) is the same performance with the following code.
    * the performance of blas.ddot with NativeBLAS is very bad.
    * blas.ddot (F2jBLAS) is about 10% improvement comparing with linalg.dot.
    * val rate = blas.ddot(rank, user._2, 1, item._2, 1)

We can perhaps say here instead "The below code is equivalent to `val score = blas.ddot(rank, srcFactor, 1, dstFactor, 1)`".
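For reference, a minimal standalone sketch of that equivalence (the factor values below are made up for illustration):

```scala
// The hand-rolled while loop computes the same dot product as
// blas.ddot(rank, srcFactor, 1, dstFactor, 1), while avoiding the per-call
// overhead of native BLAS, which dominates for small vectors (rank is small).
val srcFactor = Array(0.1, 0.2, 0.3)
val dstFactor = Array(0.4, 0.5, 0.6)
val rank = srcFactor.length

var score = 0.0
var k = 0
while (k < rank) {
  score += srcFactor(k) * dstFactor(k)
  k += 1
}
// score == 0.1 * 0.4 + 0.2 * 0.5 + 0.3 * 0.6 = 0.32
```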
    /**
     * Blockifies features to use Level-3 BLAS.
     */

We should adjust the comment here as we're not using Level-3 BLAS any more.
    k += 1
    }
    output.toSeq
    val ratings = srcBlocks.cartesian(dstBlocks).flatMap { case (srcIter, dstIter) =>

I'd like to add more detail to the doc string comment for this method to explain the approach used for efficiency.
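For illustration, a hedged draft (my wording, not taken from the PR) of the kind of doc string being asked for here:

```scala
/**
 * Makes recommendations for all users (or all items). The source and destination factor
 * RDDs are grouped into blocks; each block pair from the cartesian product is scored with
 * a hand-rolled dot product, and only the top `num` candidates per source id are kept in a
 * BoundedPriorityQueue, so no intermediate buffer proportional to the full item count is
 * ever materialized. This avoids the GC pressure and OOMs of the previous Level-3 BLAS
 * implementation.
 */
```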
I did some tests with the PR. With a larger dataset (3.29 million users and 0.21 million products), the recommendProductsForUsers time drops from 48h to 39min, 73x faster than the original method.

Test build #76254 has finished for PR 17742 at commit
    val output = new Array[(Int, (Int, Double))](m * n)
    var j = 0
    srcIter.foreach { case (srcId, srcFactor) =>
    val pq = new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(_._2))

Nit: there are several 4-space indents here that should be 2.

Thanks @srowen
Test build #76258 has finished for PR 17742 at commit

Test build #76260 has started for PR 17742 at commit

I changed the commit message to drop the
This PR is a `DataFrame` version of #17742 for [SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968), for improving the performance of `recommendAll` methods.

## How was this patch tested?

Existing unit tests.

Author: Nick Pentreath <[email protected]>

Closes #17845 from MLnick/ml-als-perf.

(cherry picked from commit 10b00ab)
Signed-off-by: Nick Pentreath <[email protected]>
I think the problem is not BLAS-3 ops, nor the 256MB total memory. @mpjlu Could you test the following?
    Iterator.range(0, m).flatMap { i =>
      Iterator.range(0, n).map { j =>
        (srcIds(i), (dstIds(j), ratings(i, j)))
      }
    }

The second option is just a quick test, sacrificing some performance. The temp objects created this way have a very short life, and GC should be able to handle them. Then very likely we don't need to do top-k inside ALS, because the
Thanks @mengxr

I don't think we should use BLAS 3 here, because whether or not we use `output` here, you need a big buffer to save the BLAS result. That still causes GC problems.

A single buffer doesn't lead to long GC pauses. If it requests a lot of memory, it might trigger GC to collect other objects, but it is itself a single object, which can be easily GC'ed. The problem here is having many small long-living objects as in

Thanks, I will do some tests based on BLAS 3.

The most optimized version would be doing a quickselect on each row and selecting the top k.
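To illustrate that remark, a self-contained sketch of mine (not Spark code) of selecting the top k entries of one score row with quickselect-style partitioning, which is O(n) on average rather than the O(n log n) of a full sort; it assumes 1 <= k <= row.length:

```scala
import scala.util.Random

// Returns the indices of the k largest scores in `row` (in no particular order),
// by repeatedly partitioning an index array around a random pivot.
def topKIndices(row: Array[Double], k: Int): Array[Int] = {
  val idx = Array.range(0, row.length)
  var lo = 0
  var hi = row.length - 1
  while (lo < hi) {
    val pivot = row(idx(lo + Random.nextInt(hi - lo + 1)))
    var i = lo
    var j = hi
    while (i <= j) {
      while (row(idx(i)) > pivot) i += 1   // larger scores move to the front
      while (row(idx(j)) < pivot) j -= 1
      if (i <= j) {
        val tmp = idx(i); idx(i) = idx(j); idx(j) = tmp
        i += 1; j -= 1
      }
    }
    if (k - 1 <= j) hi = j        // the k-th boundary lies in the left part
    else if (k - 1 >= i) lo = i   // ... or in the right part
    else lo = hi                  // boundary falls between j and i: done
  }
  idx.take(k)
}
```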
BLAS 3 with still keeping the output size as

I had a version using BLAS 3 followed by a sort per row (see https://issues.apache.org/jira/browse/SPARK-11968 for the branch link and test details). For MLLIB it was slower than this approach by a factor of 1.5x. I just re-tested for ML and it is 56s vs 16s for this approach, so really significantly slower. Comparatively, both approaches created the intermediate

Even without that, this approach is a lot faster than
It's true, I think my native BLAS is not working, will have to check - but yeah, 1.5-2x matches what I've seen in my comparisons.

@mpjlu Could you try not linking with native BLAS or system BLAS in your test? Just let it fall back to f2j BLAS. I can do some tests on my end too.

F2jBLAS is faster than MKL BLAS. The following test is based on F2jBLAS. NA means that case was not tested. 3 workers: each worker 40 cores, each worker 120G memory, each worker 1 executor.
@mpjlu do you have a link to the code for Method 1?

    val srcBlocks = blockify(rank, srcFeatures)

The code is like this. Thanks.

I haven't validated whether this code is right, just tested the performance.
Small clean ups from #17742 and #17845.

## How was this patch tested?

Existing unit tests.

Author: Nick Pentreath <[email protected]>

Closes #17919 from MLnick/SPARK-20677-als-perf-followup.

(cherry picked from commit 25b4f41)
Signed-off-by: Nick Pentreath <[email protected]>
I found why F2j BLAS is much faster than native BLAS for xiangrui's method (using GEMM) here.

I have rewritten recommendForAll with BLAS GEMM, and got about a 50% performance improvement.
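Without the follow-up PR's actual code at hand, here is a hedged sketch of the GEMM-based idea being described (the layout, names, and the final sort are my assumptions, not the PR's code): score a pair of factor blocks with a single `dgemm` call, then keep the top `num` items per user from the resulting block of scores.

```scala
import com.github.fommil.netlib.F2jBLAS

object GemmRecommendSketch {
  private val blas = new F2jBLAS

  // srcFactors / dstFactors: each factor vector (length `rank`) stored contiguously,
  // i.e. rank x m and rank x n column-major matrices.
  def scoreBlock(
      rank: Int,
      num: Int,
      srcIds: Array[Int], srcFactors: Array[Double],
      dstIds: Array[Int], dstFactors: Array[Double]): Seq[(Int, (Int, Double))] = {
    val m = srcIds.length
    val n = dstIds.length
    val scores = new Array[Double](m * n)  // m x n, column-major
    // scores = srcFactors^T * dstFactors, computed with one Level-3 BLAS call
    blas.dgemm("T", "N", m, n, rank, 1.0, srcFactors, rank, dstFactors, rank, 0.0, scores, m)
    (0 until m).flatMap { i =>
      // Row i holds user i's score for every item in this block; keep the best `num`.
      (0 until n).map(j => (dstIds(j), scores(i + j * m)))
        .sortBy(t => -t._2)
        .take(num)
        .map { case (dstId, score) => (srcIds(i), (dstId, score)) }
    }
  }
}
```

A per-row quickselect (as suggested above) or a bounded priority queue could replace the full sort to reduce the selection cost further.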
I have submitted a PR for ALS optimization with GEMM, and it is ready for review.
What changes were proposed in this pull request?
The recommendForAll of MLLIB ALS is very slow.
GC is a key problem of the current method.
The task uses the following code to keep the temp result:
val output = new Array[(Int, (Int, Double))](m * n)
m = n = 4096 (the default value; there is no method to set it)
so output is about 4k * 4k * (4 + 4 + 8) bytes = 256MB. This is a lot of memory, causes serious GC problems, and frequently OOMs.
Actually, we don't need to save all the temp results. Suppose we recommend topK (topK is about 10 or 20) products for each user; then we only need 4k * topK * (4 + 4 + 8) bytes of memory to save the temp result.
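A quick back-of-the-envelope check of those numbers (a sketch; 16 bytes is just the raw payload of an (Int, (Int, Double)) entry, ignoring JVM object overhead):

```scala
val blockSize = 4096
val bytesPerEntry = 4 + 4 + 8                                  // Int + Int + Double
val oldBuffer = blockSize.toLong * blockSize * bytesPerEntry   // 268,435,456 B ≈ 256 MB per task
val topK = 10
val newBuffer = blockSize.toLong * topK * bytesPerEntry        // 655,360 B ≈ 640 KB per task
println(s"old: ${oldBuffer >> 20} MB, new: ${newBuffer >> 10} KB")
```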
The Test Environment:
3 workers: each worker 10 cores, each worker 30G memory, each worker 1 executor.
The Data: 480,000 users and 17,000 items.

| BlockSize | 1024 | 2048 | 4096 | 8192 |
| --- | --- | --- | --- | --- |
| Old method | 245s | 332s | 488s | OOM |
| This solution | 121s | 118s | 117s | 120s |
How was this patch tested?
The existing UT.