Skip to content

Commit d26953c

Browse files
carmocca and awaelchli
authored
Add ModelPruning(prune_on_train_epoch_end) to choose when to apply pruning (#7704)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent b2d77a6 commit d26953c

File tree

3 files changed

+52
-25
lines changed

3 files changed

+52
-25
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818
- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))
1919

2020

21+
- Added `ModelPruning(prune_on_train_epoch_end=True|False)` to choose when to apply pruning ([#7704](https://github.com/PyTorchLightning/pytorch-lightning/pull/7704))
22+
23+
2124
- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))
2225

2326

pytorch_lightning/callbacks/pruning.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ def __init__(
7171
pruning_dim: Optional[int] = None,
7272
pruning_norm: Optional[int] = None,
7373
verbose: int = 0,
74+
prune_on_train_epoch_end: bool = True,
7475
) -> None:
7576
"""
7677
Model pruning Callback, using PyTorch's prune utilities.
@@ -141,6 +142,9 @@ def __init__(
141142
142143
verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity
143144
145+
prune_on_train_epoch_end: whether to apply pruning at the end of the training epoch.
146+
If this is ``False``, then the check runs at the end of the validation epoch.
147+
144148
Raises:
145149
MisconfigurationException:
146150
If ``parameter_names`` is neither ``"weight"`` nor ``"bias"``,
@@ -155,6 +159,7 @@ def __init__(
155159
self._parameters_to_prune = parameters_to_prune
156160
self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
157161
self._resample_parameters = resample_parameters
162+
self._prune_on_train_epoch_end = prune_on_train_epoch_end
158163
self._parameter_names = parameter_names or self.PARAMETER_NAMES
159164
self._global_kwargs: Dict[str, Any] = {}
160165
self._original_layers: Optional[Dict[int, _LayerRef]] = None
@@ -381,8 +386,7 @@ def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module:
381386
self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
382387
self._original_layers[id_]["names"].append((i, name))
383388

384-
def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None: # type: ignore
385-
current_epoch = pl_module.current_epoch
389+
def _run_pruning(self, current_epoch: int) -> None:
386390
prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning
387391
amount = self.amount(current_epoch) if callable(self.amount) else self.amount
388392
if not prune or not amount:
@@ -395,9 +399,19 @@ def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule)
395399
):
396400
self.apply_lottery_ticket_hypothesis()
397401

402+
def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None: # type: ignore
403+
if self._prune_on_train_epoch_end:
404+
rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
405+
self._run_pruning(pl_module.current_epoch)
406+
407+
def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
408+
if not trainer.sanity_checking and not self._prune_on_train_epoch_end:
409+
rank_zero_debug("`ModelPruning.on_validation_epoch_end`. Applying pruning")
410+
self._run_pruning(pl_module.current_epoch)
411+
398412
def on_train_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
399413
if self._make_pruning_permanent:
400-
rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
414+
rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint")
401415
self.make_pruning_permanent(pl_module)
402416

403417
def on_save_checkpoint(
@@ -407,7 +421,7 @@ def on_save_checkpoint(
407421
checkpoint: Dict[str, Any],
408422
) -> Dict[str, Any]:
409423
if self._make_pruning_permanent:
410-
rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
424+
rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")
411425
prev_device = pl_module.device
412426
# prune a copy so training can continue with the same buffers
413427
copy = deepcopy(pl_module.to("cpu"))

tests/callbacks/test_pruning.py

Lines changed: 31 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415
from collections import OrderedDict
1516
from logging import INFO
1617
from typing import Union
@@ -21,7 +22,7 @@
2122
from torch import nn
2223
from torch.nn import Sequential
2324

24-
from pytorch_lightning import seed_everything, Trainer
25+
from pytorch_lightning import Trainer
2526
from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
2627
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2728
from tests.helpers import BoringModel
@@ -224,7 +225,6 @@ def apply_lottery_ticket_hypothesis(self):
224225

225226
@pytest.mark.parametrize("make_pruning_permanent", (False, True))
226227
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):
227-
seed_everything(0)
228228
model = TestModel()
229229
pruning_kwargs = {
230230
'parameters_to_prune': [(model.layer.mlp_1, "weight"), (model.layer.mlp_3, "weight")],
@@ -250,17 +250,20 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
250250

251251
actual = [m.strip() for m in caplog.messages]
252252
actual = [m for m in actual if m.startswith("Applied")]
253-
assert actual == [
254-
"Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
255-
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 503 (49.12%)", # noqa: E501
256-
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 41 (64.06%)", # noqa: E501
257-
"Applied `RandomUnstructured`. Pruned: 544/1122 (48.48%) -> 680/1122 (60.61%)",
258-
"Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 503 (49.12%) -> 629 (61.43%)", # noqa: E501
259-
"Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 41 (64.06%) -> 51 (79.69%)", # noqa: E501
260-
"Applied `L1Unstructured`. Pruned: 680/1122 (60.61%) -> 884/1122 (78.79%)",
261-
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 629 (61.43%) -> 827 (80.76%)", # noqa: E501
262-
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 51 (79.69%) -> 57 (89.06%)", # noqa: E501
253+
percentage = r"\(\d+(?:\.\d+)?%\)"
254+
expected = [
255+
rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
256+
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}", # noqa: E501
257+
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}", # noqa: E501
258+
rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
259+
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
260+
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
261+
rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
262+
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
263+
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
263264
]
265+
expected = [re.compile(s) for s in expected]
266+
assert all(regex.match(s) for s, regex in zip(actual, expected))
264267

265268
filepath = str(tmpdir / "foo.ckpt")
266269
trainer.save_checkpoint(filepath)
@@ -270,27 +273,31 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
270273
assert not has_pruning if make_pruning_permanent else has_pruning
271274

272275

273-
def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
276+
@pytest.mark.parametrize("on_train_epoch_end", (False, True))
277+
def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog, on_train_epoch_end):
274278
"""
275279
When a model is saved multiple times and make_permanent=True, we need to
276280
make sure a copy is pruned and not the trained model if we want to continue
277281
with the same pruning buffers.
278282
"""
279-
seed_everything(0)
280283

281284
class TestPruning(ModelPruning):
282285

283286
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
284287
super().on_save_checkpoint(trainer, pl_module, checkpoint)
285-
assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
286-
assert hasattr(pl_module.layer.mlp_3, "weight_orig")
288+
if not on_train_epoch_end:
289+
# these checks only work if pruning on `validation_epoch_end`
290+
# because `on_save_checkpoint` is called before `on_train_epoch_end`
291+
assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
292+
assert hasattr(pl_module.layer.mlp_3, "weight_orig")
287293

288294
model = TestModel()
289295
pruning_callback = TestPruning(
290296
"random_unstructured",
291297
parameters_to_prune=[(model.layer.mlp_3, "weight")],
292298
verbose=1,
293-
make_pruning_permanent=True
299+
make_pruning_permanent=True,
300+
prune_on_train_epoch_end=on_train_epoch_end,
294301
)
295302
ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
296303
trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
@@ -299,11 +306,14 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
299306

300307
actual = [m.strip() for m in caplog.messages]
301308
actual = [m for m in actual if m.startswith("Applied")]
302-
assert actual == [
303-
"Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
304-
"Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
305-
"Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
309+
percentage = r"\(\d+(?:\.\d+)?%\)"
310+
expected = [
311+
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
312+
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
313+
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
306314
]
315+
expected = [re.compile(s) for s in expected]
316+
assert all(regex.match(s) for s, regex in zip(actual, expected))
307317

308318
# removed on_train_end
309319
assert not hasattr(model.layer.mlp_3, "weight_orig")

0 commit comments

Comments (0)