Improve Lite Examples #10195
Conversation
```diff
-    def run(self, args):
-        train_kwargs = {"batch_size": args.batch_size}
-        test_kwargs = {"batch_size": args.test_batch_size}
+    def run(self, hparams):
```
I am having a hard time understanding why we need to subclass LightningLite here.
Since we're calling run() directly from the main program, we're not gaining anything by making it a member of this class. As a matter of fact, if I take this run function and make it a top-level function that takes an instance of the base class (LightningLite) as its first argument, I'll achieve exactly the same thing without changing any code. What am I missing here?
It is made so that we can spawn processes automatically for users when they call the run method. We could have done it the way you suggested, but that would mean that if a user implements their code on a single GPU, for example, and then wishes to move to multiple GPUs or a TPU, they would need to change their code again. Instead, with the requirement for a run method, we can be fully agnostic to the accelerator and plugins. The user only has to change the arguments to Lite.
It is very likely that users will ask the same question as you! We can make this clearer in the docs.
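For context, here is a minimal sketch of the pattern (illustrative code, not taken from the examples in this PR; the model and loop are placeholders):

```python
import torch
from torch import nn
from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self, num_steps):
        # The whole training loop lives inside run(), so Lite can spawn one
        # process per device before invoking it.
        model = nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # setup() moves the model to the right device and wraps it for the
        # chosen strategy (e.g. DDP).
        model, optimizer = self.setup(model, optimizer)
        for _ in range(num_steps):
            loss = model(torch.randn(8, 32, device=self.device)).sum()
            optimizer.zero_grad()
            self.backward(loss)  # instead of loss.backward()
            optimizer.step()


# The same run() works on CPU...
Lite(accelerator="cpu", devices=1).run(num_steps=10)
# ...and scales out by changing only the constructor arguments, e.g.:
# Lite(accelerator="gpu", devices=2, strategy="ddp_spawn").run(num_steps=10)
```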
aribornstein left a comment:
Looks much better than yesterday. Main comments:
- At the beginning of each example script, add an ordered list that explains all the changes made since the previous script, together with their line numbers.
- Before each change, put a comment with a short explanation of what the change does and its number from the ordered list at the top of the script.
- Use comments to delineate the data processing code from the model training code, so that it is clearer what happens in the 5 scripts. (A sketch of this layout follows the list.)
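To illustrate the suggested layout (hypothetical header and section markers; the wording and numbering are mine, not the PR's):

```python
# Changes in this script compared to the previous one:
#   1. run() takes hparams instead of args
#   2. loss.backward() is replaced by self.backward(loss)
#   3. test_loss is gathered across processes with self.all_gather

####################
# DATA PROCESSING
####################
# (dataset and dataloader creation goes here)

####################
# MODEL TRAINING
####################
# (training and testing loops go here)
```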
```python
def run(self, hparams):
    self.hparams = hparams
    seed_everything(hparams.seed)
```
Can we put a comment between lines 33 and 44 to delineate the data code?
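Something like this could work (hypothetical wording):

```python
####################
# DATA PROCESSING: load the dataset and build the train/test dataloaders
####################
```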
```python
loss = F.nll_loss(output, target)

####################
self.backward(loss)
```
Let's put a list at the top of all the changes and their line numbers.
And a comment here that says: change loss.backward() to self.backward(loss).
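For example (illustrative wording):

```python
####################
# Change: loss.backward() becomes self.backward(loss), so that Lite can
# handle precision scaling and gradient synchronization automatically.
####################
```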
```python
if hparams.dry_run:
    break

test_loss = self.all_gather(test_loss).sum() / len(test_loader.dataset)
```
Put a comment explaining this change.
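For example (illustrative wording; my reading of the line above):

```python
####################
# Change: in a multi-process run each process only evaluates its own shard of
# the test set, so self.all_gather collects the per-process loss sums and the
# division by len(test_loader.dataset) turns them into a global average.
####################
```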
```diff
     break

-if args.save_model and self.is_global_zero:
+if hparams.save_model and self.is_global_zero:
```
And another comment explaining this one.
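For example (illustrative wording):

```python
####################
# Change: args becomes hparams, and self.is_global_zero ensures that only the
# rank-0 process writes the checkpoint in a distributed run.
####################
```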
```diff
 hparams = parser.parse_args()

-Lite(**lite_kwargs).run(args)
+Lite(accelerator="gpu" if torch.cuda.is_available() else "cpu", devices=torch.cuda.device_count()).run(hparams)
```
Put a comment explaining this.
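For example (illustrative wording):

```python
####################
# Change: Lite is instantiated directly with the accelerator and device count
# chosen at runtime: the "gpu" accelerator with all visible GPUs when CUDA is
# available, and "cpu" otherwise.
####################
```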
What does this PR do?
Set devices to 1 when it's just Trainer(accelerator='auto').

Does your PR introduce any breaking changes? If yes, please list them.
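A hedged reading of that one-line summary (assuming the described default; this snippet is mine, not from the PR):

```python
from pytorch_lightning import Trainer

# With this behavior, omitting devices while using accelerator="auto"...
trainer = Trainer(accelerator="auto")
# ...is assumed to behave like explicitly requesting a single device:
trainer = Trainer(accelerator="auto", devices=1)
```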
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃