Deprecate truncated_bptt_steps flag on Trainer in favor of same setting on the LightningModule
#7323
Changes from all commits
0dcbb8c
b428df3
887028c
2152f22
03b062e
3d0eb17
fc44fc1
e424b2a
1ba38ea
074818d
020adeb
a6179ae
418cb7e
597c643
d239f99
148ab33
d1669fd
@@ -1005,6 +1005,63 @@ Get the model file size (in megabytes) using ``self.model_size`` inside Lightnin

--------------

truncated_bptt_steps
^^^^^^^^^^^^^^^^^^^^

Truncated backpropagation through time performs backprop every k steps of
a much longer sequence.

If this is enabled, your batches will automatically get truncated
and the trainer will apply Truncated Backprop to it.

(`Williams et al. "An efficient gradient-based algorithm for on-line training of
recurrent network trajectories."
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
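For example (an illustrative scenario, not part of this diff): with ``truncated_bptt_steps=2`` and a batch spanning 10 time steps, the trainer would split the batch into 5 chunks of 2 steps each, run a backward pass after each chunk, and carry the hidden state forward between chunks.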
`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: This property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # the training step must be updated to accept a ``hiddens`` argument
            # hiddens are the hiddens from the previous truncated backprop step
            x, y = batch  # assuming the batch is an (input, target) pair
            out, hiddens = self.lstm(x, hiddens)
            return {
                "loss": ...,
                "hiddens": hiddens
            }
Comment on lines +1023 to +1043

Collaborator: why exactly the same example as in sequences.rst?

Contributor (Author): no good reason other than copy/paste. any suggestions on how they should be different?

Collaborator: I may drop the example here and keep only the earlier one...

Contributor: I would make this short here and link to the "sequences" document, which provides all the examples and tutorials.
Lightning takes care to split your batch along the time-dimension.

.. code-block:: python

    # we use the second as the time dimension
    # (batch, time, ...)
    sub_batch = batch[0, 0:t, ...]

To modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

.. testcode:: python

    class LitMNIST(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # do your own splitting on the batch
            return splits
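As a rough sketch of what such an override might look like (the ``LitSeqModel`` name and the single-tensor ``(batch, time, features)`` batch layout are assumptions for illustration, not part of this diff):

.. code-block:: python

    from pytorch_lightning import LightningModule


    class LitSeqModel(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # assume the batch is a single tensor shaped (batch, time, features)
            x = batch
            splits = []
            for t in range(0, x.shape[1], split_size):
                # each split covers ``split_size`` consecutive time steps
                splits.append(x[:, t:t + split_size, ...])
            return splits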
--------------

Hooks
^^^^^
This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``.
@@ -14,7 +14,7 @@

 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
@@ -432,12 +432,13 @@ def _track_gradient_norm(self):

         grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm)
         return grad_norm_dict

-    def tbptt_split_batch(self, batch):
+    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         splits = [batch]
-        if self.trainer.truncated_bptt_steps is not None:
+        truncated_bptt_enabled = self._truncated_bptt_enabled()
+        if truncated_bptt_enabled:
             model_ref = self.trainer.lightning_module
             with self.trainer.profiler.profile("tbptt_split_batch"):
-                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+                splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps())
         return splits

     def run_training_epoch(self):
@@ -612,7 +613,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):

             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

         # lightning module hook
-        splits = self.tbptt_split_batch(batch)
+        splits = self._tbptt_split_batch(batch)

         for split_idx, split_batch in enumerate(splits):
@@ -876,11 +877,22 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):

         )

         # pass hiddens if using tbptt
-        if self.trainer.truncated_bptt_steps is not None:
+        if self._truncated_bptt_enabled():
Collaborator: shall this also be a property?
             args.append(hiddens)

         return args
+    def _truncated_bptt_enabled(self) -> bool:
| """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ | ||||||||
| return self._truncated_bptt_steps() > 0 | ||||||||
|
|
||||||||
| def _truncated_bptt_steps(self) -> int: | ||||||||
+        lightning_module = self.trainer.lightning_module
+        # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5
+        if lightning_module.truncated_bptt_steps > 0:
+            return lightning_module.truncated_bptt_steps
+        return self.trainer.truncated_bptt_steps or 0

     def save_loggers_on_train_batch_end(self):
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
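To illustrate the precedence rule above, a minimal sketch of the two ways to enable truncated backprop during the deprecation window (``MyModel`` refers to the docs example earlier in this diff, which sets ``self.truncated_bptt_steps = 2`` in ``__init__``):

.. code-block:: python

    from pytorch_lightning import Trainer

    # Preferred going forward: the value lives on the LightningModule itself
    # (MyModel sets self.truncated_bptt_steps = 2), so no Trainer flag is needed
    model = MyModel()
    trainer = Trainer()

    # Deprecated: the Trainer flag still works until its removal in v1.5,
    # but a non-zero value on the LightningModule takes precedence when both are set
    trainer = Trainer(truncated_bptt_steps=2)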