
Commit faae25c

Merge e663641 into 352e8f0
2 parents 352e8f0 + e663641

File tree: 3 files changed (+81, −22 lines)

  CHANGELOG.md
  pytorch_lightning/callbacks/pruning.py
  tests/callbacks/test_pruning.py

CHANGELOG.md (3 additions, 0 deletions)

@@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
 
 
+- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
+
+
 - Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
 
 
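For reference, a minimal sketch of how the affected option is enabled (the pruning function and amount below are illustrative, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelPruning

    # Prune 50% of the selected weights with L1-unstructured pruning each
    # epoch. With make_pruning_permanent=True, the callback strips the
    # pruning re-parametrization (`weight_orig` / `weight_mask`) so that
    # checkpoints contain plain `weight` tensors.
    pruning = ModelPruning("l1_unstructured", amount=0.5, make_pruning_permanent=True)
    trainer = Trainer(callbacks=[pruning], max_epochs=3)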

pytorch_lightning/callbacks/pruning.py (27 additions, 16 deletions)

@@ -18,7 +18,7 @@
 import inspect
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union, Dict
 
 import torch
 import torch.nn.utils.prune as pytorch_prune
@@ -27,7 +27,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 _PYTORCH_PRUNING_FUNCTIONS = {
@@ -246,14 +246,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]:
     def _wrap_pruning_fn(pruning_fn, **kwargs):
         return partial(pruning_fn, **kwargs)
 
-    def make_pruning_permanent(self):
-        """ Makes ``parameters_to_prune`` current pruning permanent. """
-        for module, param_name in self._parameters_to_prune:
-            try:
-                pytorch_prune.remove(module, param_name)
-            except ValueError:
-                # pruning already made permanent
-                pass
+    def make_pruning_permanent(self, pl_module: LightningModule):
+        """
+        Removes pruning buffers from any pruned modules
+
+        Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
+        """
+        for _, module in pl_module.named_modules():
+            for k in list(module._forward_pre_hooks):
+                hook = module._forward_pre_hooks[k]
+                if isinstance(hook, pytorch_prune.BasePruningMethod):
+                    hook.remove(module)
+                    del module._forward_pre_hooks[k]
 
     def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
         trained = getattr(module, tensor_name)
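To see what the rewritten make_pruning_permanent removes, here is a small standalone sketch of PyTorch's pruning re-parametrization in plain torch (independent of this commit):

    import torch.nn as nn
    import torch.nn.utils.prune as prune

    layer = nn.Linear(4, 4)
    prune.l1_unstructured(layer, name="weight", amount=0.5)

    # Pruning re-parametrizes the module: `weight` is now computed from a
    # `weight_orig` parameter and a `weight_mask` buffer by a forward
    # pre-hook that reapplies the mask on every forward pass.
    assert hasattr(layer, "weight_orig")
    assert "weight_mask" in dict(layer.named_buffers())

    # Making pruning permanent means removing that machinery, which is
    # what the loop over `_forward_pre_hooks` in the new method does:
    for k, hook in list(layer._forward_pre_hooks.items()):
        if isinstance(hook, prune.BasePruningMethod):
            hook.remove(layer)  # bakes the mask into `weight`
            del layer._forward_pre_hooks[k]

    assert not hasattr(layer, "weight_orig")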
pytorch_lightning/callbacks/pruning.py (continued)

@@ -351,7 +355,7 @@ def _log_sparsity_stats(
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )
 
-    def on_before_accelerator_backend_setup(self, trainer, pl_module):
+    def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )
@@ -367,7 +371,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
             self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
             self._original_layers[id_]["names"].append((i, name))
 
-    def on_train_epoch_end(self, trainer, pl_module, *args):
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
         current_epoch = trainer.current_epoch
         prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
         amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
@@ -381,13 +385,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_end(self, *args):
+    def on_train_end(self, trainer, pl_module: LightningModule):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
+            self.make_pruning_permanent(pl_module)
 
-    def on_save_checkpoint(self, *args):
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
+            prev_device = pl_module.device
+            # prune a copy so training can continue with the same buffers
+            copy = deepcopy(pl_module.to("cpu"))
+            self.make_pruning_permanent(copy)
+            checkpoint["state_dict"] = copy.state_dict()
+            pl_module.to(prev_device)
 
     @staticmethod
     def sanitize_parameters_to_prune(
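The on_save_checkpoint change applies a common pattern: make a CPU deepcopy, mutate the copy for serialization, and leave the live module untouched. A condensed sketch of the same idea, with a hypothetical make_permanent helper standing in for the callback method:

    from copy import deepcopy

    import torch.nn as nn
    import torch.nn.utils.prune as prune

    def make_permanent(root: nn.Module) -> None:
        # Hypothetical helper: same hook-removal loop as the new
        # ModelPruning.make_pruning_permanent above.
        for module in root.modules():
            for k, hook in list(module._forward_pre_hooks.items()):
                if isinstance(hook, prune.BasePruningMethod):
                    hook.remove(module)
                    del module._forward_pre_hooks[k]

    model = nn.Linear(4, 4)
    prune.random_unstructured(model, name="weight", amount=0.5)

    # Serialize a permanently pruned copy; the original keeps its
    # weight_orig/weight_mask buffers so training can continue.
    pruned_copy = deepcopy(model).cpu()
    make_permanent(pruned_copy)
    state_dict = pruned_copy.state_dict()

    assert "weight_orig" not in state_dict  # copy is permanent
    assert hasattr(model, "weight_orig")    # live model still re-parametrized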

tests/callbacks/test_pruning.py (51 additions, 6 deletions)

@@ -15,7 +15,6 @@
 import platform
 from collections import OrderedDict
 from logging import INFO
-from unittest import mock
 
 import pytest
 import torch
@@ -24,7 +23,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning
+from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 
@@ -42,6 +41,10 @@ def __init__(self):
             ])
         )
 
+    def training_step(self, batch, batch_idx):
+        self.log("test", -batch_idx)
+        return super().training_step(batch, batch_idx)
+
 
 class TestPruningMethod(pytorch_prune.BasePruningMethod):
     PRUNING_TYPE = "unstructured"
@@ -219,7 +222,6 @@ def apply_lottery_ticket_hypothesis(self):
 
 
 @pytest.mark.parametrize("make_pruning_permanent", (False, True))
-@mock.patch.dict(os.environ, {}, clear=True)
 def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     seed_everything(0)
     model = TestModel()
@@ -244,8 +246,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     with caplog.at_level(INFO):
         trainer.fit(model)
 
-    actual = [m.strip() for m in caplog.messages[-9:]]
-    expected = [
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
         "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)",  # noqa: E501
@@ -256,11 +259,53 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)",  # noqa: E501
     ]
-    assert actual == expected
 
     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)
 
     model.load_from_checkpoint(filepath, strict=False)
     has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
     assert not has_pruning if make_pruning_permanent else has_pruning
+
+
+def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
+    """
+    When a model is saved multiple times and make_permanent=True, we need to
+    make sure a copy is pruned and not the trained model if we want to continue
+    with the same pruning buffers.
+    """
+    seed_everything(0)
+
+    class TestPruning(ModelPruning):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            super().on_save_checkpoint(trainer, pl_module, checkpoint)
+            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
+            assert hasattr(pl_module.layer.mlp_3, "weight_orig")
+
+    model = TestModel()
+    pruning_callback = TestPruning(
+        "random_unstructured",
+        parameters_to_prune=[(model.layer.mlp_3, "weight")],
+        verbose=1,
+        make_pruning_permanent=True
+    )
+    ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
+    trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
+        "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
+        "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
+        "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
+    ]
+
+    # removed on_train_end
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+
+    model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+    model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
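As the first test above hints, a checkpoint written with make_pruning_permanent=False stores the re-parametrized keys (e.g. layer.mlp_1.weight_orig and layer.mlp_1.weight_mask instead of layer.mlp_1.weight), which is why it is loaded with strict=False. A quick way to inspect a saved checkpoint (the path mirrors the test's foo.ckpt):

    import torch

    ckpt = torch.load("foo.ckpt", map_location="cpu")
    # With make_pruning_permanent=False, expect `*_orig` / `*_mask` keys;
    # with make_pruning_permanent=True, plain `*.weight` keys only.
    print(sorted(ckpt["state_dict"].keys()))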
