
Commit 0835fa2

Cleanup
1 parent d5c50b1 commit 0835fa2

File tree

1 file changed: +14 -36 lines changed


tests/core/test_lightning_optimizer.py

Lines changed: 14 additions & 36 deletions
@@ -17,7 +17,6 @@
 
 import torch
 from torch.optim import Adam, Optimizer, SGD
-from torch.optim.optimizer import _RequiredParameter
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -308,31 +307,17 @@ def configure_optimizers(self):
     assert zero_grad.call_count == max_iter
 
 
-required = _RequiredParameter()
-
-
 class OptimizerWithHooks(Optimizer):
 
-    def __init__(self, model, lr=required, u0=required):
-        if lr is not required and lr < 0.0:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-
-        defaults = dict(lr=lr)
-        self.steps = 0
-
-        self.params = []
-
+    def __init__(self, model):
         self._fwd_handles = []
         self._bwd_handles = []
-
-        self.model = model
-
-        for _, mod in model.named_modules():  # iterates over modules of model
+        self.params = []
+        for _, mod in model.named_modules():
             mod_class = mod.__class__.__name__
-            if mod_class not in ['Linear']:  # silently skips other layers
+            if mod_class != 'Linear':
                 continue
 
-            # save the inputs and gradients for the kfac matrix computation
             handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
             self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
             handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
@@ -347,21 +332,21 @@ def __init__(self, model, lr=required, u0=required):
             d = {'params': params, 'mod': mod, 'layer_type': mod_class}
             self.params.append(d)
 
-        super(OptimizerWithHooks, self).__init__(self.params, defaults)
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
 
     def _save_input(self, mod, i):
         """Saves input of layer"""
         if mod.training:
             self.state[mod]['x'] = i[0]
 
-    def _save_grad_output(self, mod, grad_input, grad_output):
+    def _save_grad_output(self, mod, _, grad_output):
         """
         Saves grad on output of layer to
         grad is scaled with batch_size since gradient is spread over samples in mini batch
         """
-        bs = grad_output[0].shape[0]  # batch_size
+        batch_size = grad_output[0].shape[0]
         if mod.training:
-            self.state[mod]['grad'] = grad_output[0] * bs
+            self.state[mod]['grad'] = grad_output[0] * batch_size
 
     def step(self, closure=None):
         closure()
@@ -371,14 +356,11 @@ def step(self, closure=None):
         return True
 
 
-def test_lightning_optimizer_dont_delete_wrapped_optimizer(tmpdir):
+def test_lightning_optimizer_keeps_hooks(tmpdir):
 
     class TestModel(BoringModel):
-
-        def __init__(self):
-            super().__init__()
-            self.count_on_train_batch_start = 0
-            self.count_on_train_batch_end = 0
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
 
         def configure_optimizers(self):
             return OptimizerWithHooks(self)
@@ -390,15 +372,11 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
 
         def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
            self.count_on_train_batch_end += 1
-            # delete the lightning_optimizers
-            self.trainer._lightning_optimizers = None
-            gc.collect()
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
 
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
     model = TestModel()
-    # Initialize a trainer
-    trainer = Trainer(limit_train_batches=4, limit_val_batches=1, max_epochs=1)
-
-    # Train the model ⚡
     trainer.fit(model)
     assert model.count_on_train_batch_start == 4
     assert model.count_on_train_batch_end == 4
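
For context, a small standalone sketch (not part of this commit) of the hook mechanism the OptimizerWithHooks test exercises: a forward pre-hook receives a module's inputs and a backward hook receives the gradients flowing out of it. The names save_input and save_grad_output below are illustrative only.

    import torch
    from torch import nn

    # Illustrative sketch, not repository code: shows what register_forward_pre_hook
    # and register_backward_hook deliver for a Linear layer, mirroring _save_input /
    # _save_grad_output in the diff above. Note that register_backward_hook is
    # deprecated in newer PyTorch releases in favour of register_full_backward_hook.
    linear = nn.Linear(4, 2)
    saved = {}

    def save_input(mod, inputs):
        saved['x'] = inputs[0]  # forward pre-hooks get the positional inputs as a tuple

    def save_grad_output(mod, grad_input, grad_output):
        saved['grad'] = grad_output[0]  # gradient w.r.t. the layer's output

    linear.register_forward_pre_hook(save_input)
    linear.register_backward_hook(save_grad_output)

    out = linear(torch.randn(3, 4))
    out.sum().backward()
    assert saved['x'].shape == (3, 4)
    assert saved['grad'].shape == (3, 2)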
