Initial integration with FairScale Pipe Module for model partitioning/gradient checkpointing #4443
Hello @SeanNaren! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-11-18 15:09:15 UTC
So the code works, allowing the user to split the model across GPUs in an SPSD fashion (I've just made that hardcoded, not sure we want to worry about SPMD support). Buuut imo it involves too much of the user's input, particularly in their train/val/test step logic:

```python
def training_step(self, batch, batch_idx):
    output = self(batch)
    if self.trainer.accelerator_backend.final_stage:
        loss = self.loss(batch, output)
        self.log('loss', loss)
        print(loss)
        return {"loss": loss}
    else:
        self.trainer.accelerator_backend.sync_gradients(output)

def validation_step(self, batch, batch_idx):
    output = self(batch)
    if self.trainer.accelerator_backend.final_stage:
        loss = self.loss(batch, output)
        self.log('x', loss)

def test_step(self, batch, batch_idx):
    output = self.layers(batch)
    if self.trainer.accelerator_backend.final_stage:
        loss = self.loss(batch, output)
        self.log('y', loss)
```

This check is required because some GPUs are just intermediates (they only hold a portion of the model whose activations get passed downstream to other processes; no loss is calculated there). This happens to include GPU 0, so right now logging is a mess. There's a neat refactor from @froody which will allow control from just one process, which should make things much cleaner and potentially remove the manual final-stage check within the step functions (facebookresearch/fairscale#173 (comment))
tchaton left a comment:
Awesome addition! I wonder if we could block the call inside self.log using self.trainer.accelerator_backend.final_stage. However, I guess sync_gradients can't be moved out as it needs output.
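A purely hypothetical sketch of what such a guard inside self.log could look like; only final_stage and accelerator_backend come from the snippets in this PR, the rest is made up for illustration:

```python
# Hypothetical sketch, not Lightning's actual implementation: self.log could no-op
# on intermediate pipeline stages so user code doesn't need the final_stage check.
def log(self, name, value, **kwargs):
    backend = getattr(self.trainer, "accelerator_backend", None)
    if backend is not None and not getattr(backend, "final_stage", True):
        return  # intermediate pipe stage: no loss/metrics to record here
    # ... continue with the existing logging path ...
```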
```python
style=PipelineStyle.MultiProcess,
input_device=torch.cuda.current_device(),
worker_map=get_worker_map(),
checkpoint='never',
```
What is checkpoint='never'?
The checkpoint argument refers to activation checkpointing, i.e. not saving the intermediate activations needed for grads on the initial forward pass, and then re-running the forward pass with grads enabled during the backward pass.
Will be exposed eventually of course :)
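For context, a minimal sketch of activation checkpointing in plain PyTorch using torch.utils.checkpoint — an illustration of the concept, not the fairscale Pipe internals:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# The checkpointed block does not keep its intermediate activations during the
# forward pass; they are recomputed (forward re-run under grad) during backward.
block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
x = torch.randn(4, 32, requires_grad=True)

out = checkpoint(block, x)  # forward pass without storing intermediates
out.sum().backward()        # backward triggers the re-computation
```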
```python
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
)
mpu.initialize_model_parallel(model_parallel_size_=1, pipeline_length=len(self.pipe_module.layer_partitions))
```
Might the user want to modify those parameters?
Thanks for your comments @tchaton :) I had a question around the API, since I think this is eventually going to tie into a few other important components (like standalone gradient checkpointing, and potentially even the ShardedDDP stuff). In the current API I've wrapped the sequential module in a lightning module, and defined a custom accelerator in the code:

```python
train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

# model
model = BoringModel()

# model.layers is a sequential module that needs to be manually wrapped
model.layers = LightningPipeModule(
    model.layers,
    layer_partitions=[1, 1],  # Puts 1 layer on each GPU
    microbatches=8  # Bubble partitioning under the hood for device utilization
)

accelerator = PipeAccelerator(model.layers, cluster_environment=TorchElasticEnvironment())

trainer = Trainer(
    default_root_dir=os.getcwd(),
    max_epochs=1,
    gpus=2,
    accelerator=accelerator
)

trainer.fit(model, train_data, val_data)
```

This is meh and I'd prefer to do something like:

```python
train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

# model
model = BoringModel()

# model.layers is a sequential module that needs to be manually wrapped
model.layers = LightningPipeModule(
    model.layers,
    layer_partitions=[1, 1],  # Puts 1 layer on each GPU
    microbatches=8  # Bubble partitioning under the hood for device utilization
)

trainer = Trainer(
    default_root_dir=os.getcwd(),
    max_epochs=1,
    gpus=2,
    accelerator='ddp_pipe'  # Skip initializing the accelerator beforehand!
)

trainer.fit(model, train_data, val_data)
```

Is there a way I could 'register' the pipe module for the accelerator to pick up automatically?
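One hypothetical shape for such a registration hook, sketched purely for discussion (register_accelerator and the registry below are made-up names, not an existing Lightning API):

```python
# Hypothetical sketch: a string-keyed accelerator registry that the Trainer could
# consult, so accelerator='ddp_pipe' resolves to a factory instead of an instance.
ACCELERATOR_REGISTRY = {}

def register_accelerator(name):
    def decorator(cls):
        ACCELERATOR_REGISTRY[name] = cls
        return cls
    return decorator

@register_accelerator("ddp_pipe")
class PipeAccelerator:  # stand-in for the accelerator added in this PR
    def __init__(self, cluster_environment=None):
        self.cluster_environment = cluster_environment

# Inside the Trainer, the string could then be resolved lazily:
accelerator = ACCELERATOR_REGISTRY["ddp_pipe"](cluster_environment=None)
```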
What does this PR do?
Closes #4322.
Relates to #4178 but I want to keep separate the FairScale ZeRO + ShardedDDP integration and the Pipe + Checkpointing integration since they're decoupled (for now) in fairscale.
Pipe allows a sequential model to be split onto separate GPUs. It comes with its own hyper-parameters, and because it's tied to a torch.nn.Sequential, it needs closer integration from the user's perspective. Still unsure on the API, but I want to throw something out for us to discuss.
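For reference, a rough sketch of the single-process fairscale Pipe API this builds on (based on the torchgpipe-derived fairscale of that era; exact parameter names and defaults may differ):

```python
import torch.nn as nn
from fairscale.nn import Pipe  # single-process pipe; this PR targets the multi-process variant

# 'balance' splits the Sequential across devices; 'chunks' divides each batch into
# micro-batches so the pipeline stages overlap; 'checkpoint' controls activation
# checkpointing ('always' / 'except_last' / 'never').
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
pipe = Pipe(model, balance=[2, 1], chunks=8, checkpoint='never')
```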
All feedback welcome :)
TODO
- backward_helper rather than calling the loss

cc @ananthsub @williamFalcon
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃