From 12cae2b279201d084a85f588367aaa13495c245f Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 16 Mar 2021 20:20:07 +0100 Subject: [PATCH] Update stochastic_weight_avg.py --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index bece2ffe9f1b2..4292be0f862ff 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -17,6 +17,7 @@ """ from copy import deepcopy from typing import Callable, Optional, Union +from weakref import proxy import torch from torch import nn @@ -136,8 +137,15 @@ def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'): return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + + # temporarily detach the trainer so the deepcopy below does not hit pickling errors (e.g. with the advanced profiler) + pl_module.trainer = None # copy the model before moving it to accelerator device. self._average_model = deepcopy(pl_module) + # restore the trainer reference on both modules as a weakref proxy + trainer_ref = proxy(trainer) + self._average_model.trainer = trainer_ref + pl_module.trainer = trainer_ref def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): optimizers = trainer.optimizers