14 | 14 | # See the License for the specific language governing permissions and
15 | 15 |
16 | 16 | import argparse
17 | | -import copy
18 | 17 | import logging
19 | 18 | import math
20 | 19 | import os
21 | 20 | import random
22 | 21 | from pathlib import Path
23 | | -from typing import Iterable, Optional
| 22 | +from typing import Optional
24 | 23 |
25 | 24 | import numpy as np
26 | 25 | import torch
36 | 35 | from datasets import load_dataset
37 | 36 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
38 | 37 | from diffusers.optimization import get_scheduler
| 38 | +from diffusers.training_utils import EMAModel
39 | 39 | from diffusers.utils import check_min_version
40 | 40 | from diffusers.utils.import_utils import is_xformers_available
41 | 41 | from huggingface_hub import HfFolder, Repository, whoami
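
With this change the script imports `EMAModel` from `diffusers.training_utils` instead of carrying its own copy (the local class is removed in the hunk below). For orientation, here is a minimal sketch of how such an EMA helper is typically wired into a training loop. It assumes the imported class keeps the parameter-based interface of the class being removed (a constructor over an iterable of parameters plus a `decay`, with `step()` and `copy_to()` methods); the toy model, loop, and hyperparameters are illustrative and not taken from this diff.

```python
import torch

from diffusers.training_utils import EMAModel

# Toy stand-in for the Stable Diffusion UNet that the real script trains.
model = torch.nn.Sequential(torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))
ema_model = EMAModel(model.parameters(), decay=0.9999)  # shadow copy of the weights
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(100):  # stand-in for the real dataloader loop
    batch = torch.randn(16, 8)
    loss = (model(batch) - batch).pow(2).mean()  # dummy reconstruction loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema_model.step(model.parameters())  # update the EMA shadow weights

# Before evaluation or saving, copy the averaged weights into the live model.
ema_model.copy_to(model.parameters())
```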
@@ -305,115 +305,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
305 | 305 | }
306 | 306 |
307 | 307 |
308 | | -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
309 | | -class EMAModel:
310 | | -    """
311 | | -    Exponential Moving Average of models weights
312 | | -    """
313 | | -
314 | | -    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
315 | | -        parameters = list(parameters)
316 | | -        self.shadow_params = [p.clone().detach() for p in parameters]
317 | | -
318 | | -        self.collected_params = None
319 | | -
320 | | -        self.decay = decay
321 | | -        self.optimization_step = 0
322 | | -
323 | | -    @torch.no_grad()
324 | | -    def step(self, parameters):
325 | | -        parameters = list(parameters)
326 | | -
327 | | -        self.optimization_step += 1
328 | | -
329 | | -        # Compute the decay factor for the exponential moving average.
330 | | -        value = (1 + self.optimization_step) / (10 + self.optimization_step)
331 | | -        one_minus_decay = 1 - min(self.decay, value)
332 | | -
333 | | -        for s_param, param in zip(self.shadow_params, parameters):
334 | | -            if param.requires_grad:
335 | | -                s_param.sub_(one_minus_decay * (s_param - param))
336 | | -            else:
337 | | -                s_param.copy_(param)
338 | | -
339 | | -        torch.cuda.empty_cache()
340 | | -
341 | | -    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
342 | | -        """
343 | | -        Copy current averaged parameters into given collection of parameters.
344 | | -
345 | | -        Args:
346 | | -            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
347 | | -                updated with the stored moving averages. If `None`, the
348 | | -                parameters with which this `ExponentialMovingAverage` was
349 | | -                initialized will be used.
350 | | -        """
351 | | -        parameters = list(parameters)
352 | | -        for s_param, param in zip(self.shadow_params, parameters):
353 | | -            param.data.copy_(s_param.data)
354 | | -
355 | | -    def to(self, device=None, dtype=None) -> None:
356 | | -        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
357 | | -
358 | | -        Args:
359 | | -            device: like `device` argument to `torch.Tensor.to`
360 | | -        """
361 | | -        # .to() on the tensors handles None correctly
362 | | -        self.shadow_params = [
363 | | -            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
364 | | -            for p in self.shadow_params
365 | | -        ]
366 | | -
367 | | -    def state_dict(self) -> dict:
368 | | -        r"""
369 | | -        Returns the state of the ExponentialMovingAverage as a dict.
370 | | -        This method is used by accelerate during checkpointing to save the ema state dict.
371 | | -        """
372 | | -        # Following PyTorch conventions, references to tensors are returned:
373 | | -        # "returns a reference to the state and not its copy!" -
374 | | -        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
375 | | -        return {
376 | | -            "decay": self.decay,
377 | | -            "optimization_step": self.optimization_step,
378 | | -            "shadow_params": self.shadow_params,
379 | | -            "collected_params": self.collected_params,
380 | | -        }
381 | | -
382 | | -    def load_state_dict(self, state_dict: dict) -> None:
383 | | -        r"""
384 | | -        Loads the ExponentialMovingAverage state.
385 | | -        This method is used by accelerate during checkpointing to save the ema state dict.
386 | | -        Args:
387 | | -            state_dict (dict): EMA state. Should be an object returned
388 | | -                from a call to :meth:`state_dict`.
389 | | -        """
390 | | -        # deepcopy, to be consistent with module API
391 | | -        state_dict = copy.deepcopy(state_dict)
392 | | -
393 | | -        self.decay = state_dict["decay"]
394 | | -        if self.decay < 0.0 or self.decay > 1.0:
395 | | -            raise ValueError("Decay must be between 0 and 1")
396 | | -
397 | | -        self.optimization_step = state_dict["optimization_step"]
398 | | -        if not isinstance(self.optimization_step, int):
399 | | -            raise ValueError("Invalid optimization_step")
400 | | -
401 | | -        self.shadow_params = state_dict["shadow_params"]
402 | | -        if not isinstance(self.shadow_params, list):
403 | | -            raise ValueError("shadow_params must be a list")
404 | | -        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
405 | | -            raise ValueError("shadow_params must all be Tensors")
406 | | -
407 | | -        self.collected_params = state_dict["collected_params"]
408 | | -        if self.collected_params is not None:
409 | | -            if not isinstance(self.collected_params, list):
410 | | -                raise ValueError("collected_params must be a list")
411 | | -            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
412 | | -                raise ValueError("collected_params must all be Tensors")
413 | | -            if len(self.collected_params) != len(self.shadow_params):
414 | | -                raise ValueError("collected_params and shadow_params must have the same length")
415 | | -
416 | | -
417 | 308 | def main():
418 | 309 |     args = parse_args()
419 | 310 |     logging_dir = os.path.join(args.output_dir, args.logging_dir)
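
One behavioral detail of the removed implementation is worth keeping in mind when switching to the shared class: the configured `decay` (default 0.9999) is not applied from the first step. `step()` uses an effective decay of `min(decay, (1 + step) / (10 + step))`, so early in training the shadow weights track the raw weights closely and the average only approaches 0.9999 after tens of thousands of updates. Below is a small sketch of that schedule (the helper name is made up for illustration); whether the `EMAModel` imported from `diffusers.training_utils` reproduces it exactly should be checked against the library version in use.

```python
def effective_decay(step: int, decay: float = 0.9999) -> float:
    # Decay schedule used by the removed EMAModel.step():
    # min(decay, (1 + step) / (10 + step))
    return min(decay, (1 + step) / (10 + step))

for step in (1, 10, 100, 1_000, 10_000, 100_000):
    print(step, round(effective_decay(step), 5))
# 1 -> 0.18182, 10 -> 0.55, 100 -> 0.91818,
# 1000 -> 0.99109, 10000 -> 0.9991, 100000 -> 0.9999 (capped)
```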