
Conversation

@ahendriksen

This commit adds the `gpu_choice` parameter to Trainer. By default,
this parameter is set to `'manual'`, which causes no observable
difference in behavior.

When `gpu_choice` is set to `"auto"` and `gpus` is an int, the
trainer will automatically allocate the first available GPU.
This is especially useful when GPUs are configured to be in "exclusive
mode", which means that only one process at a time can use them.
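Conceptually, the backoff can be sketched without any Lightning internals: try a tiny allocation on each candidate device and keep the first one that succeeds. The function and the injected `try_allocate` probe below are hypothetical stand-ins for illustration only; in the actual implementation the probe would be something like `torch.ones(1).cuda(i)`, which raises a `RuntimeError` when an exclusive-mode GPU is already in use.

```python
def pick_single_gpu(exclude_gpus, n_gpus, try_allocate):
    """Return the index of the first free GPU (hypothetical sketch).

    try_allocate(i) should attempt a tiny allocation on device i and
    raise RuntimeError if the device is unavailable, e.g. because it
    is busy in exclusive mode.
    """
    for i in range(n_gpus):
        if i in exclude_gpus:
            continue
        try:
            try_allocate(i)
            return i
        except RuntimeError:
            # Device is busy or otherwise unavailable; try the next one.
            continue
    raise RuntimeError("No GPUs available.")
```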


What does this PR do?

Fixes #951

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Comments:

  1. I am not sure whether the `pick_gpu` functions should be placed in `distrib_parts.py`. I chose it because `determine_root_gpu_device` is defined there.
  2. What is the idiomatic way to specify that a test should only run when CUDA is available?
  3. I am unsure what possible interactions with distributed training there are and, if so, how I should handle them.
  4. I added a docstring; does that count as adding docs?
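For question 2, a common pattern is pytest's `skipif` marker. A minimal sketch follows; the `CUDA_AVAILABLE` flag is a stand-in so the example runs without torch installed, and in the real test suite the condition would be `torch.cuda.is_available()` (the exact convention used in this repo may differ):

```python
import pytest

# Assumption: a plain flag stands in for torch.cuda.is_available()
# so this sketch is self-contained.
CUDA_AVAILABLE = False

@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires a CUDA-capable GPU")
def test_pick_gpu_allocates_first_free_device():
    # Body elided; pytest reports this test as skipped when no GPU is present.
    pass
```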

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks

pep8speaks commented Apr 9, 2020

Hello @ahendriksen! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-04-10 08:00:40 UTC

@mergify mergify bot requested a review from a team April 9, 2020 12:58
@ahendriksen ahendriksen force-pushed the add-auto-gpu-choice branch from 0b98dab to 2d5db37 Compare April 9, 2020 13:00
@williamFalcon
Contributor

williamFalcon commented Apr 9, 2020

I don't think the backoff would work because Lightning scopes the visible GPUs, no? So this needs to work with our auto-gpu scoping.

i.e.: do backoff first to identify the GPUs... THEN set the CUDA flags

https://github.com/PyTorchLightning/pytorch-lightning/blob/9754c5da55059dd89cf0a4fd582fe5df9449bbe5/pytorch_lightning/trainer/distrib_data_parallel.py#L267

@ahendriksen
Author

ahendriksen commented Apr 9, 2020

If I understand correctly, this PR does backoff first to identify the GPUs and THEN sets the CUDA flags.

`set_nvidia_flags` is called after the GPUs have been determined.

GPUs are determined here: https://github.com/PyTorchLightning/pytorch-lightning/blob/2d5db37eb7ade5ae1c4a861bfb231c1db08e512e/pytorch_lightning/trainer/trainer.py#L404

Nvidia flags are set a couple lines later:
https://github.com/PyTorchLightning/pytorch-lightning/blob/2d5db37eb7ade5ae1c4a861bfb231c1db08e512e/pytorch_lightning/trainer/trainer.py#L437

Is there something I am missing?
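The ordering described above (device ids determined first, NVIDIA flags set afterwards) can be illustrated with a simplified stand-in for `set_nvidia_flags`. This is a hypothetical sketch, not the actual Lightning code:

```python
import os

def set_nvidia_flags(gpu_ids):
    """Simplified sketch: restrict CUDA visibility to pre-determined ids."""
    # Order GPUs by PCI bus id so device indices match nvidia-smi output.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # Only now, after the ids have been chosen, scope the visible devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return os.environ["CUDA_VISIBLE_DEVICES"]
```

Because the environment variables are set only after the id selection has finished, a backoff-based picker and the flag scoping do not conflict.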

@williamFalcon
Contributor

@ahendriksen makes sense!

@williamFalcon
Contributor

what about just `auto_select_gpus=T|F`?

@mergify
Contributor

mergify bot commented Apr 9, 2020

This pull request is now in conflict... :(

@ahendriksen ahendriksen force-pushed the add-auto-gpu-choice branch from 2d5db37 to 0a18f3d Compare April 10, 2020 07:54
@ahendriksen ahendriksen force-pushed the add-auto-gpu-choice branch from 0a18f3d to db89a2c Compare April 10, 2020 08:00
@codecov

codecov bot commented Apr 10, 2020

Codecov Report

Merging #1426 into master will decrease coverage by 0%.
The diff coverage is 60%.

@@          Coverage Diff           @@
##           master   #1426   +/-   ##
======================================
- Coverage      92%     91%   -0%     
======================================
  Files          66      66           
  Lines        3509    3542   +33     
======================================
+ Hits         3213    3232   +19     
- Misses        296     310   +14     

@Borda Borda added the feature Is an improvement or enhancement label Apr 10, 2020
@williamFalcon williamFalcon merged commit 7ac1580 into Lightning-AI:master Apr 10, 2020
@williamFalcon
Contributor

really cool! @ahendriksen

@ahendriksen
Author

Nice! Thanks for merging. I had fun 🎉

@Borda Borda added this to the 0.7.3 milestone Apr 10, 2020
tullie pushed a commit to tullie/pytorch-lightning that referenced this pull request Jun 7, 2020
* Add automatic GPU choice to trainer

* Rename gpu_choice -> auto_select_gpus
@Borda Borda modified the milestones: 0.7.3, v0.7.x Apr 18, 2021
Successfully merging this pull request may close: Automatically pick available GPU

4 participants