Skip to content

Conversation

@jkbradley
Copy link
Member

@jkbradley jkbradley commented Jul 26, 2016

What changes were proposed in this pull request?

RandomForest currently sends the entire forest to each worker on each iteration. This is because (a) the node queue is FIFO and (b) the closure references the entire array of trees (topNodes). (a) causes RFs to handle splits in many trees, especially early on in learning. (b) sends all trees explicitly.

This PR:
(a) Change the RF node queue to be FILO (a stack), so that RFs tend to focus on 1 or a few trees before focusing on others.
(b) Change topNodes to pass only the trees required on that iteration.

How was this patch tested?

Unit tests:

  • Existing tests for correctness of tree learning
  • Manually modifying code and running tests to verify that a small number of trees are communicated on each iteration
    • This last item is hard to test via unit tests given the current APIs.

@jkbradley
Copy link
Member Author

@hhbyyh This is an improvement I had implemented a while back, just a little too late for the 2.0 code freeze. Could you please help review it or find others? Thank you!

q += ((treeIndex, node))
}

/** Remove and return last inserted element. Linear time (unclear in Scala docs though) */
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is not ideal, but it should be insignificant compared to the cost of communicating the trees. I am not aware of an existing solution for a stack with constant-time push/pop in Scala.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"a FILO queue" is an unconventional saying. I like it though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there some reason not to use scala.collection.mutable.Stack[(Int, LearningNode)] here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How critical is this for performance? I know that append is O(1) for scala mutable lists but dropRight(1) is actually implemented in a parent class [1] of MutableList and therefores does not take advantage of the list's reference to its last element. All in all, from it's implementation it seems that dropRight is O(n)

[1] https://github.com/scala/scala/blob/v2.11.8/src/library/scala/collection/LinearSeqOptimized.scala#L194-L204

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, I missed that List.tail is O(1) time. I switched to a List. Thanks all!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also somehow didn't find Stack via Google... Does anyone know of documentation for the performance of Stack? I don't see it in the Scala docs.

@hhbyyh
Copy link
Contributor

hhbyyh commented Jul 26, 2016

Ack. I'll review it and run tests tonight.

@SparkQA
Copy link

SparkQA commented Jul 26, 2016

Test build #62858 has finished for PR 14359 at commit 6fcfb4b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hhbyyh
Copy link
Contributor

hhbyyh commented Jul 26, 2016

If it is not urgent, I'd like to try some large scale training to understand more about the improvements.

@jkbradley
Copy link
Member Author

Not urgent, but I'd like it to be in 2.1

@jodersky
Copy link
Member

@jkbradley , you can find the scaladoc on stacks here http://www.scala-lang.org/api/current/index.html#scala.collection.mutable.Stack

Also this document http://docs.scala-lang.org/overviews/collections/performance-characteristics gives a nice overview of the different collection types in scala and their performances

@SparkQA
Copy link

SparkQA commented Jul 26, 2016

Test build #62900 has finished for PR 14359 at commit 3c00d03.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Copy link
Member Author

Thanks @jodersky I saw those, but the first does not document computational cost & the latter does not really clarify what I need for stacks (push and pop).

@jodersky
Copy link
Member

jodersky commented Jul 26, 2016

Agree, it's not very obvious. In the latter document I think a push is akin to prepend and pop to head

@jkbradley
Copy link
Member Author

Sorry for the long delay; I've been swamped by other things for a while. Re-emerging...

I switched to Stack and then realized Stack has been deprecated in Scala 2.11, so I reverted to the original NodeQueue. But I renamed NodeQueue to NodeStack to be a bit clearer.

@hhbyyh Any luck testing this at scale?

@jkbradley
Copy link
Member Author

Btw, to give back-of-the-envelope estimates, we can look at 2 numbers:
(1) How many nodes will be split on each iteration?
(2) How big is the forest which is serialized and sent to workers on each iteration?

For (1), here's an example:

  • 1000 features, each with 50 bins -> 50000 possible splits
  • set maxMemoryInMB = 256 (default)
  • regression => 3 Double values per possible split
  • 256 * 10^6 / (3 * 50000 * 8) = 213 nodes/iteration

This implies that for trees of depth > 8 or so, many iterations will only split nodes from 1 or 2 trees. I.e., we should avoid communicating most trees.

For (2), the forest can be pretty expensive to send.

  • Each node:
    • leaf node: 5 Doubles
    • internal node: ~8 Doubles/references + Split
      • Split: O(# categories) or 2 values for continuous, say 3 Doubles on average
    • => say 8 Doubles/node on average
  • 100 trees of depth 8 => 25600 nodes => 1.6MB
  • 100 trees of depth 14 => 105MB
  • I've heard of many cases of users wanting to fit 500-1000 trees and use trees of depth 18-20.

@SparkQA
Copy link

SparkQA commented Aug 16, 2016

Test build #63822 has finished for PR 14359 at commit f79f77c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jodersky
Copy link
Member

I switched to Stack and then realized Stack has been deprecated in Scala 2.11...

I think you probably read the immutable stack docs; the mutable stack is not deprecated AFAIK. I can imagine that having a custom stack implementation may allow for additional operations in the future, however we should also consider that using standard collections reduces the load for anyone who will maintain the code then.

Btw, I highly recommend to use the milestone scaladocs over the current ones. Although 2.12 is not officially out yet, the changes to the library are minimal and the UI is much more pleasant to use.

@jkbradley
Copy link
Member Author

Ahh, you're right; I was looking at immutable. I'll update to use the mutable stack. Thanks!

@jkbradley
Copy link
Member Author

Done!

@SparkQA
Copy link

SparkQA commented Aug 17, 2016

Test build #63886 has finished for PR 14359 at commit 41f4297.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
/*
FILO queue of nodes to train: (treeIndex, node)
We make this FILO by always inserting nodes by appending (+=) and removing with dropRight.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the methods described here still refer to the original queue data structure

@jodersky
Copy link
Member

Some comments still refer to the use of queue and should be updated. Other than that, the data structure part now looks good to me.

@jkbradley
Copy link
Member Author

Thanks @jodersky ! Updated.

@SparkQA
Copy link

SparkQA commented Aug 18, 2016

Test build #64020 has finished for PR 14359 at commit 133fdbf.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hhbyyh
Copy link
Contributor

hhbyyh commented Sep 9, 2016

Hi Joseph, Sorry for the late response. I was occupied by a customer Spark project for the past month.

The idea looks reasonable and I tested with MNist dataset and the overall run time decrease from 245 seconds to 225 seconds on average. LGTM.

val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
/*
Stack of nodes to train: (treeIndex, node)
The reason this is FILO is that we train many trees at once, but we want to focus on
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"The reason this is FILO" -> "The reason we use a stack"

@sethah
Copy link
Contributor

sethah commented Sep 20, 2016

This is a really nice improvement. The communication overhead is reduced, based on some simple local tests. I wonder how we can add a test to verify that the algorithm focuses on completing whole trees at once. Potentially, we can add a test of selectNodesToSplit to verify that it chooses nodes from fewer number of trees, but I'm not sure it's necessary. Thoughts?

Also, it might not be too hard to take this a step further. We could group the nodes to be trained by tree, and keep track of the amount of memory they require. Then to select nodes to split, we can simply pick off the trees that require the most memory until we exceed the threshold. This way we truly minimize the number of trees while still occupying the memory size. We could leave it for another JIRA.

@jkbradley
Copy link
Member Author

Thanks @hhbyyh and @sethah !

I agree that a later PR could be more careful about which trees are completed in which order and test this more thoroughly. But I hope this takes us 80% of the way there. If it's Ok with you, I'd like to go ahead and merge it as is once tests pass.

@sethah
Copy link
Contributor

sethah commented Sep 22, 2016

LGTM pending tests.

@SparkQA
Copy link

SparkQA commented Sep 23, 2016

Test build #65798 has finished for PR 14359 at commit d16c2da.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Copy link
Member Author

Merging with master

@asfgit asfgit closed this in 947b8c6 Sep 23, 2016
@jkbradley jkbradley deleted the rfs-fewer-trees branch September 23, 2016 05:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants