
Example unconditional training script does not work #499

@c-goldschmidt

Description

Describe the bug

When running the example unconditional training script (train_unconditional.py)
on a GPU with --mixed_precision=fp16, accelerate's autocasting corrupts the model output and training crashes:

Traceback (most recent call last):
  File "<path>\train_unconditional.py", line 251, in <module>
    main(args)
  File "<path>\train_unconditional.py", line 145, in main
    loss = F.mse_loss(noise_pred, noise)
  File "<lib-path>\lib\site-packages\torch\nn\functional.py", line 3269, in mse_loss
    if not (target.size() == input.size()):
AttributeError: 'dict' object has no attribute 'size'

This happens because accelerate.utils.operations.recursively_apply handles any "Mapping" by rebuilding it through a dict initialization (=> https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/operations.py#L87). That breaks the UNet2DOutput class: since UNet2DOutput inherits from OrderedDict, it counts as a Mapping, so it is reconstructed from a plain dict, and that dict is passed in as the value of the first field.
i.e. UNet2DOutput(sample=sample) effectively becomes UNet2DOutput(sample={'sample': sample}).
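The mechanism can be reproduced without torch or accelerate. The sketch below uses simplified, hypothetical stand-ins: SimpleOutput roughly mimics diffusers' BaseOutput (an OrderedDict subclass whose dataclass fields are mirrored as keys), FakeUNetOutput stands in for UNet2DOutput, and recursively_apply_sketch only imitates the Mapping branch of accelerate's recursively_apply, assumed here to rebuild the object as type(data)({...}). None of these names come from either library, and a float replaces the tensor.

```python
from collections import OrderedDict
from collections.abc import Mapping
from dataclasses import dataclass, fields


class SimpleOutput(OrderedDict):
    """Hypothetical, simplified stand-in for diffusers' BaseOutput."""

    def __post_init__(self):
        # Mirror each dataclass field into the mapping so the output supports
        # both attribute access (out.sample) and key access (out["sample"]).
        for f in fields(self):
            self[f.name] = getattr(self, f.name)


@dataclass
class FakeUNetOutput(SimpleOutput):
    sample: object


def recursively_apply_sketch(func, data):
    """Simplified mimic: rebuild any Mapping via type(data)({...}), apply func to leaves."""
    if isinstance(data, Mapping):
        return type(data)({k: recursively_apply_sketch(func, v) for k, v in data.items()})
    return func(data)


out = FakeUNetOutput(sample=1.5)
rebuilt = recursively_apply_sketch(lambda x: x, out)
print(type(rebuilt).__name__)  # FakeUNetOutput
print(rebuilt.sample)          # {'sample': 1.5}
```

The class survives the round trip, but its sample is now the dict {'sample': 1.5}. Whether the training script reads the prediction as out.sample or out["sample"], it gets a dict instead of a tensor, which is consistent with the "'dict' object has no attribute 'size'" error raised by F.mse_loss.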

This could be worked around in train_unconditional.py itself. Unfortunately, that just moves the problem: the script then fails at a later stage, inside library code that the user cannot change.
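Since a script-level change only moves the problem, one possible direction is to make the output class itself tolerate being rebuilt from a mapping. This is a hypothetical sketch, not an existing diffusers or accelerate API: TolerantOutput, RobustUNetOutput and rebuild_like_mapping are illustrative names, rebuild_like_mapping assumes the same type(data)({...}) rebuild described above, and a float replaces the tensor so it runs with the standard library only.

```python
from collections import OrderedDict
from dataclasses import dataclass, fields


class TolerantOutput(OrderedDict):
    """Hypothetical output base class that survives being rebuilt from a mapping."""

    def __post_init__(self):
        names = [f.name for f in fields(self)]
        first = getattr(self, names[0])
        rest_unset = all(getattr(self, n) is None for n in names[1:])
        # Cls({"sample": x}) puts the whole mapping into the first field.
        # If that mapping's keys are all field names, spread it back out.
        if isinstance(first, dict) and rest_unset and first and set(first) <= set(names):
            for n in names:
                setattr(self, n, first.get(n))
        # Mirror non-None fields as keys so out["sample"] keeps working.
        for n in names:
            value = getattr(self, n)
            if value is not None:
                self[n] = value


@dataclass
class RobustUNetOutput(TolerantOutput):
    sample: object = None


def rebuild_like_mapping(data):
    """Mimics the assumed type(data)({...}) rebuild of a Mapping."""
    return type(data)({k: v for k, v in data.items()})


out = rebuild_like_mapping(RobustUNetOutput(sample=1.5))
print(out.sample, out["sample"])  # 1.5 1.5
```

The check is a heuristic: a first field that legitimately holds a dict whose keys happen to match the field names would be unpacked as well, so a real fix would need to weigh that trade-off.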

Reproduction

Run the example script with mixed precision set to fp16:

train_unconditional.py --dataset_name="huggan/pokemon" --output_dir="ddpm-test-64" --mixed_precision=fp16


System Info

  • diffusers version: 0.3.0
  • Platform: Windows-10-10.0.22000-SP0
  • Python version: 3.9.13
  • PyTorch version (GPU?): 1.12.1+cu116 (True)
  • Huggingface_hub version: 0.9.1
  • Transformers version: 4.21.3
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Additional notes:

  • I'm running Windows 11, which diffusers-cli env incorrectly reports as Windows 10
  • accelerate version: 0.12.0
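Because diffusers-cli env mislabels Windows 11 and the accelerate version had to be reported separately, a small standard-library script can collect both. This is a hypothetical helper, not part of diffusers. It assumes that Windows 11 reports platform release "10" with a build number of 22000 or higher (the build in the output above, 10.0.22000, is the first Windows 11 release).

```python
import platform
from importlib import metadata


def package_version(name):
    """Return the installed version of a distribution, or a placeholder."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def windows_label(release, version):
    """Hypothetical helper: map (platform.release(), platform.version()) to 10 or 11."""
    parts = version.split(".")
    if release == "10" and len(parts) >= 3 and parts[2].isdigit():
        # Windows 11 still reports 10.0; builds >= 22000 are Windows 11.
        return "11" if int(parts[2]) >= 22000 else "10"
    return release


if __name__ == "__main__":
    for pkg in ("diffusers", "accelerate", "torch", "transformers", "huggingface_hub"):
        print(f"{pkg}: {package_version(pkg)}")
    if platform.system() == "Windows":
        print(f"Windows {windows_label(platform.release(), platform.version())}")
```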

Labels: bug (Something isn't working)
