
Commit 7c82a16

anton-l, patil-suraj, pcuenca, and patrickvonplaten authored
Fix EMA for multi-gpu training in the unconditional example (#1930)
* improve EMA
* style
* one EMA model
* quality
* fix tests
* fix test
* Apply suggestions from code review

  Co-authored-by: Pedro Cuenca <[email protected]>

* re organise the unconditional script
* backwards compatibility
* default to init values for some args
* fix ort script
* issubclass => isinstance
* update state_dict
* docstr
* doc
* Apply suggestions from code review

  Co-authored-by: Pedro Cuenca <[email protected]>

* use .to if device is passed
* deprecate device
* make flake happy
* fix typo

Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent f354dd9 commit 7c82a16
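
This commit drops the hand-rolled `EMAModel` from `examples/text_to_image/train_text_to_image.py` and has the example scripts import the shared implementation from `diffusers.training_utils` instead, so one EMA object is maintained for the whole training run. A minimal usage sketch follows, assuming the shared class keeps the `step`/`copy_to` interface of the removed per-script class and accepts the model's parameters as its first argument (a stand-in `torch.nn` model is used here instead of the example's UNet):

```python
import torch
from diffusers.training_utils import EMAModel

# Stand-in for the UNet trained by the example scripts; EMAModel only needs
# an iterable of parameters, so any nn.Module works for illustration.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.Conv2d(8, 3, 3, padding=1),
)
ema_model = EMAModel(model.parameters())  # default decay assumed to be 0.9999
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for _ in range(10):
    loss = model(torch.randn(2, 3, 64, 64)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Update the shadow (EMA) weights once per optimizer step.
    ema_model.step(model.parameters())

# Copy the averaged weights into the live model before saving or evaluating.
ema_model.copy_to(model.parameters())
```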


6 files changed (+314, -216 lines)


examples/text_to_image/train_text_to_image.py

Lines changed: 2 additions & 111 deletions
@@ -14,13 +14,12 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Iterable, Optional
+from typing import Optional
 
 import numpy as np
 import torch
@@ -36,6 +35,7 @@
 from datasets import load_dataset
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, whoami
@@ -305,115 +305,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 }
 
 
-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.decay = decay
-        self.optimization_step = 0
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
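
In the removed `step` method (now provided by `diffusers.training_utils.EMAModel`), the effective decay is warmed up as `min(decay, (1 + step) / (10 + step))`, and each shadow parameter is updated as `s = s - (1 - decay_eff) * (s - p)`, i.e. `s = decay_eff * s + (1 - decay_eff) * p`. A small dependency-free sketch of that schedule shows how early steps track the live weights closely before settling at the configured 0.9999:

```python
def effective_decay(optimization_step: int, decay: float = 0.9999) -> float:
    # Mirrors the EMAModel.step warm-up: ramp the decay, capped at the configured maximum.
    value = (1 + optimization_step) / (10 + optimization_step)
    return min(decay, value)

for step in (1, 10, 100, 1000, 100_000):
    print(step, round(effective_decay(step), 6))
# 1 -> 0.181818, 10 -> 0.55, 100 -> 0.918182, 1000 -> 0.991089, 100000 -> 0.9999
```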

examples/unconditional_image_generation/README.md

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,7 @@ accelerate launch train_unconditional.py \
   --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
+  --use_ema \
   --learning_rate=1e-4 \
   --lr_warmup_steps=500 \
   --mixed_precision=no \
@@ -63,6 +64,7 @@ accelerate launch train_unconditional.py \
   --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
+  --use_ema \
   --learning_rate=1e-4 \
   --lr_warmup_steps=500 \
   --mixed_precision=no \
@@ -150,6 +152,7 @@ accelerate launch train_unconditional_ort.py \
   --dataset_name="huggan/flowers-102-categories" \
   --resolution=64 \
   --output_dir="ddpm-ema-flowers-64" \
+  --use_ema \
   --train_batch_size=16 \
   --num_epochs=1 \
   --gradient_accumulation_steps=1 \
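
The `--use_ema` flag added to these commands makes EMA tracking opt-in in the unconditional example. A hedged sketch of how such a flag typically gates the EMA object in a training script; the argument parsing and variable names here are illustrative, not copied from the updated `train_unconditional.py`:

```python
import argparse

import torch
from diffusers.training_utils import EMAModel

parser = argparse.ArgumentParser()
parser.add_argument("--use_ema", action="store_true", help="Track an EMA copy of the model weights.")
args = parser.parse_args()

model = torch.nn.Conv2d(3, 3, 3, padding=1)  # stand-in for the example's UNet2DModel
ema_model = EMAModel(model.parameters()) if args.use_ema else None

# After each optimizer step in the training loop:
if ema_model is not None:
    ema_model.step(model.parameters())

# Before building/saving the pipeline, swap in the averaged weights:
if ema_model is not None:
    ema_model.copy_to(model.parameters())
```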
