 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from collections import OrderedDict
 from logging import INFO
 from typing import Union
 from torch import nn
 from torch.nn import Sequential

-from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
@@ -224,7 +225,6 @@ def apply_lottery_ticket_hypothesis(self):

 @pytest.mark.parametrize("make_pruning_permanent", (False, True))
 def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):
-    seed_everything(0)
     model = TestModel()
     pruning_kwargs = {
         'parameters_to_prune': [(model.layer.mlp_1, "weight"), (model.layer.mlp_3, "weight")],
@@ -250,17 +250,22 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):

     actual = [m.strip() for m in caplog.messages]
     actual = [m for m in actual if m.startswith("Applied")]
-    assert actual == [
-        "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 503 (49.12%)",  # noqa: E501
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 41 (64.06%)",  # noqa: E501
-        "Applied `RandomUnstructured`. Pruned: 544/1122 (48.48%) -> 680/1122 (60.61%)",
-        "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 503 (49.12%) -> 629 (61.43%)",  # noqa: E501
-        "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 41 (64.06%) -> 51 (79.69%)",  # noqa: E501
-        "Applied `L1Unstructured`. Pruned: 680/1122 (60.61%) -> 884/1122 (78.79%)",
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 629 (61.43%) -> 827 (80.76%)",  # noqa: E501
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 51 (79.69%) -> 57 (89.06%)",  # noqa: E501
+    percentage = r"\(\d+(?:\.\d+)?%\)"
+    expected = [
+        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
+        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
+        rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
+        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
+        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
     ]
+    expected = [re.compile(s) for s in expected]
+    assert len(actual) == len(expected)  # zip() below would silently drop unmatched messages
+    assert all(regex.match(s) for s, regex in zip(actual, expected))

     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)
@@ -270,27 +273,32 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):
     assert not has_pruning if make_pruning_permanent else has_pruning


-def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
+@pytest.mark.parametrize("on_train_epoch_end", (False, True))
+def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog, on_train_epoch_end):
274278 """
275279 When a model is saved multiple times and make_permanent=True, we need to
276280 make sure a copy is pruned and not the trained model if we want to continue
277281 with the same pruning buffers.
278282 """
-    seed_everything(0)

     class TestPruning(ModelPruning):

         def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             super().on_save_checkpoint(trainer, pl_module, checkpoint)
-            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
-            assert hasattr(pl_module.layer.mlp_3, "weight_orig")
+            if not on_train_epoch_end:
+                # these checks only work when pruning on `validation_epoch_end`,
+                # because `on_save_checkpoint` is called before `on_train_epoch_end`
+                assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
+                assert hasattr(pl_module.layer.mlp_3, "weight_orig")

     model = TestModel()
     pruning_callback = TestPruning(
         "random_unstructured",
         parameters_to_prune=[(model.layer.mlp_3, "weight")],
         verbose=1,
-        make_pruning_permanent=True
+        make_pruning_permanent=True,
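+        # if False, pruning is instead applied at the end of the validation epoch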
+        prune_on_train_epoch_end=on_train_epoch_end,
     )
     ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
     trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
@@ -299,11 +306,16 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):

     actual = [m.strip() for m in caplog.messages]
     actual = [m for m in actual if m.startswith("Applied")]
-    assert actual == [
-        "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
-        "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
-        "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
+    percentage = r"\(\d+(?:\.\d+)?%\)"
+    expected = [
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
     ]
+    expected = [re.compile(s) for s in expected]
+    assert len(actual) == len(expected)  # zip() below would silently drop unmatched messages
+    assert all(regex.match(s) for s, regex in zip(actual, expected))

     # removed on_train_end
     assert not hasattr(model.layer.mlp_3, "weight_orig")