 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
+from typing import Any
 from unittest.mock import DEFAULT, patch

 import torch
@@ -303,3 +305,78 @@ def configure_optimizers(self):
     lbfgs = model.optimizers()
     max_iter = lbfgs.param_groups[0]["max_iter"]
     assert zero_grad.call_count == max_iter
+
+
+class OptimizerWithHooks(Optimizer):
+
+    def __init__(self, model):
+        self._fwd_handles = []
+        self._bwd_handles = []
+        self.params = []
+        for _, mod in model.named_modules():
+            mod_class = mod.__class__.__name__
+            if mod_class != 'Linear':
+                continue
+
+            handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
+            self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
+            handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
+            self._bwd_handles.append(handle)  # collect backward-save-grad hook in list
+
+            # save the parameters
+            params = [mod.weight]
+            if mod.bias is not None:
+                params.append(mod.bias)
+
+            # save a param_group for each module
+            d = {'params': params, 'mod': mod, 'layer_type': mod_class}
+            self.params.append(d)
+
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
+
+    def _save_input(self, mod, i):
+        """Saves input of layer"""
+        if mod.training:
+            self.state[mod]['x'] = i[0]
+
+    def _save_grad_output(self, mod, _, grad_output):
| 343 | + """ |
| 344 | + Saves grad on output of layer to |
| 345 | + grad is scaled with batch_size since gradient is spread over samples in mini batch |
| 346 | + """ |
+        batch_size = grad_output[0].shape[0]
+        if mod.training:
+            self.state[mod]['grad'] = grad_output[0] * batch_size
+
+    def step(self, closure=None):
+        closure()
+        for group in self.param_groups:
+            _ = self.state[group['mod']]['x']
+            _ = self.state[group['mod']]['grad']
+        return True
+
+
+def test_lightning_optimizer_keeps_hooks(tmpdir):
+
+    class TestModel(BoringModel):
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
+
+        def configure_optimizers(self):
+            return OptimizerWithHooks(self)
+
+        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_start += 1
+            optimizer = self.optimizers(use_pl_optimizer=False)
+            assert len(optimizer._fwd_handles) == 1
+
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_end += 1
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
+    model = TestModel()
+    trainer.fit(model)
+    assert model.count_on_train_batch_start == 4
+    assert model.count_on_train_batch_end == 4
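
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch (not part of this diff) of the hook-registering optimizer idea that `OptimizerWithHooks` relies on, in plain PyTorch without Lightning. The `HookedSGD` name, the tiny `nn.Linear` model, and the update rule are illustrative assumptions; the sketch also uses `register_full_backward_hook` instead of the `register_backward_hook` call in the test above, since the latter is deprecated in recent PyTorch releases.

```python
import torch
from torch import nn
from torch.optim import Optimizer


class HookedSGD(Optimizer):
    """Toy optimizer that saves layer inputs / grad-outputs through hooks (illustration only)."""

    def __init__(self, model, lr=0.01):
        self.handles = []  # keep the hook handles alive for the optimizer's lifetime
        param_groups = []
        for mod in model.modules():
            if isinstance(mod, nn.Linear):
                # forward pre-hook: called as fn(module, inputs) before forward()
                self.handles.append(mod.register_forward_pre_hook(self._save_input))
                # full backward hook: called as fn(module, grad_input, grad_output)
                self.handles.append(mod.register_full_backward_hook(self._save_grad_output))
                param_groups.append({"params": list(mod.parameters()), "mod": mod})
        super().__init__(param_groups, {"lr": lr})

    def _save_input(self, mod, inputs):
        if mod.training:
            self.state[mod]["x"] = inputs[0].detach()

    def _save_grad_output(self, mod, grad_input, grad_output):
        if mod.training:
            # scale by the batch size, mirroring the test above
            self.state[mod]["grad"] = grad_output[0].detach() * grad_output[0].shape[0]

    @torch.no_grad()
    def step(self, closure=None):
        if closure is not None:
            with torch.enable_grad():
                closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"])  # plain SGD update


model = nn.Linear(4, 2)
opt = HookedSGD(model)


def closure():
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    return loss


opt.step(closure)
# both hooks fired during the closure and populated the optimizer state
assert "x" in opt.state[model] and "grad" in opt.state[model]
```

The test added in this diff guards exactly this behaviour: wrapping the user optimizer for Lightning must not drop or re-create the object, otherwise the hook handles and the per-module `state` entries would silently stop pointing at the modules being trained.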