Description
This is a really awesome feature we're looking to add. It's also a super hard problem, so if any ninjas want to take a crack at it, you'll be legendary :)
Problem:
Some models are too big to fit in the memory of a single GPU, so they can't use any of the distributed training modes currently available (even in plain PyTorch).
But... we can break the model up and put parts of it on each GPU. The trick is doing it automatically, because doing this manually is a PITA (trust me, I've spent weeks dealing with it haha).
Proposed solution:
Add a hook in LightningModule where the user returns the modules they want balanced.
```python
class MyModule(LightningModule):
    def __init__(self, ...):
        super().__init__()
        self.model_a = SomeModel()
        self.layer_1 = Linear(...)
        self.layer_2 = Linear(...)

    def forward(self, x):
        # in each of these module calls, auto place the input x on that module's GPU
        x = self.model_a(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        return x

    def self_balance(self):
        return [self.model_a, self.layer_1, self.layer_2]
```

So the above does two cool things:
- The user says how they want to break up the model.
- In the forward pass, we automatically place the input on that module's GPU (a rough sketch of how that could work is below).
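For the second point, one possible mechanism (just a sketch; `AutoPlacedModule` and `_move_to_device` are made-up names, not existing Lightning API): wrap each module returned from `self_balance()` so its inputs get moved onto that module's GPU before the wrapped forward runs.

```python
import torch
from torch import nn


def _move_to_device(x, device):
    # Recursively move tensors (and lists/tuples of tensors) onto the target device.
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, (list, tuple)):
        return type(x)(_move_to_device(item, device) for item in x)
    return x


class AutoPlacedModule(nn.Module):
    # Hypothetical wrapper: keeps the wrapped module on one GPU and moves every
    # input onto that GPU before calling the module's own forward.
    def __init__(self, module: nn.Module, device):
        super().__init__()
        self.device = torch.device(device)
        self.module = module.to(self.device)

    def forward(self, *args, **kwargs):
        args = tuple(_move_to_device(a, self.device) for a in args)
        kwargs = {k: _move_to_device(v, self.device) for k, v in kwargs.items()}
        return self.module(*args, **kwargs)
```

After calling `self_balance()`, the balancer could then swap `self.model_a` etc. for wrapped versions, so the user's `forward` stays untouched.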
That's the easy part lol... the hard part is deciding how to balance: optimizing for speed so you minimize data transfer across GPUs, while also using each GPU's memory efficiently and not blowing it up.
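To make the hard part concrete, a trivially naive baseline might just balance parameter memory greedily (hypothetical sketch; `naive_balance` is not an existing function). The real work is beating something like this by also modelling activation memory and the cost of moving data between GPUs:

```python
import torch


def naive_balance(modules, device_ids):
    # Hypothetical starting point: greedily assign each module to the GPU that
    # currently holds the fewest parameter bytes. This ignores activation memory
    # and cross-GPU transfer costs entirely.
    loads = {d: 0 for d in device_ids}          # parameter bytes assigned per GPU
    placement = []
    for module in modules:
        size = sum(p.numel() * p.element_size() for p in module.parameters())
        device_id = min(loads, key=loads.get)   # least-loaded GPU so far
        loads[device_id] += size
        placement.append((module, torch.device(f"cuda:{device_id}")))
    return placement
```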
Anyone want to give this a shot?