Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
from copy import deepcopy
from typing import Callable, Optional, Union
from weakref import proxy

import torch
from torch import nn
Expand Down Expand Up @@ -136,8 +137,15 @@ def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
    """Snapshot ``pl_module`` into ``self._average_model`` before device placement.

    The module's ``trainer`` attribute is temporarily set to ``None`` so the
    trainer is not traversed by ``deepcopy``. Afterwards both the original
    module and the copy receive a weak proxy to ``trainer``.

    Args:
        trainer: The trainer driving the run; must support weak references.
        pl_module: The module to copy. Its ``trainer`` attribute is replaced
            by a weak proxy to ``trainer`` when this hook returns or raises.

    Raises:
        TypeError: If ``trainer`` cannot be weakly referenced. Raised before
            ``pl_module`` is modified.
    """
    # Build the proxy first: if this fails, pl_module has not been touched yet.
    trainer_ref = proxy(trainer)

    # this bypasses pickling errors with advanced profiler
    pl_module.trainer = None
    try:
        # copy the model before moving it to accelerator device.
        self._average_model = deepcopy(pl_module)
    finally:
        # Always restore the trainer ref, so a failed copy does not leave the
        # live module detached from its trainer.
        pl_module.trainer = trainer_ref

    # The copy was taken while ``trainer`` was None; attach the same weak ref.
    self._average_model.trainer = trainer_ref

def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
optimizers = trainer.optimizers
Expand Down