[Community] Exception in device=TPU:0: Got unsupported ScalarType BFloat16 #898

@ssusie

Describe the bug

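I am trying to train Stable Diffusion on a TPU, with Accelerate configured to use the TPU and mixed precision changed from fp16 to bf16. All 15000 training steps complete, but saving the pipeline at the end fails with the following error: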

Exception in device=TPU:0: Got unsupported ScalarType BFloat16
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/ssusie/.local/lib/python3.8/site-packages/accelerate/utils/launch.py", line 89, in __call__
    self.launcher(*args)
  File "/home/ssusie/diffusers/examples/text_to_image/train_text_to_image.py", line 620, in main
    pipeline.save_pretrained(args.output_dir)
  File "/home/ssusie/.local/lib/python3.8/site-packages/diffusers/pipeline_utils.py", line 180, in save_pretrained
    save_method(os.path.join(save_directory, pipeline_component_name))
  File "/home/ssusie/.local/lib/python3.8/site-packages/diffusers/modeling_utils.py", line 209, in save_pretrained
    save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 589, in _save
    pickler.dump(obj)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 177, in __reduce_ex__
    return self._reduce_ex_internal(proto)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 223, in _reduce_ex_internal
    return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(),
TypeError: Got unsupported ScalarType BFloat16
Steps: 100%|██████████| 15000/15000 [9:13:57<00:00,  2.22s/it, lr=1e-5, step_loss=0.0033]
Traceback (most recent call last):
  File "/home/ssusie/.local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/ssusie/.local/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 43, in main
    args.func(args)
  File "/home/ssusie/.local/lib/python3.8/site-packages/accelerate/commands/launch.py", line 906, in launch_command
    tpu_launcher(args)
  File "/home/ssusie/.local/lib/python3.8/site-packages/accelerate/commands/launch.py", line 652, in tpu_launcher
    xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 389, in spawn
    return torch.multiprocessing.start_processes(
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 149, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with exit code 17
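
A possible workaround, not confirmed in this thread: the innermost frame shows torch.save pickling a device tensor through self.cpu().numpy(), and NumPy has no native bfloat16 dtype, so the conversion raises. The sketch below copies a state dict to CPU and upcasts any bfloat16 tensors to float32 before saving. The helper name to_float32_state_dict is hypothetical and is not part of diffusers, accelerate, or torch_xla.

```python
import torch


def to_float32_state_dict(state_dict):
    """Return a copy of state_dict with tensors on CPU and bfloat16 upcast to float32.

    Hypothetical helper for illustration; non-tensor entries are passed through unchanged.
    """
    converted = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            # Move off the accelerator first so pickling does not go through NumPy.
            value = value.detach().cpu()
            if value.dtype == torch.bfloat16:
                # float32 is representable in NumPy; bfloat16 is not.
                value = value.to(torch.float32)
        converted[name] = value
    return converted
```

Assuming a model object named unet (a placeholder), the result could be written with torch.save(to_float32_state_dict(unet.state_dict()), "unet_fp32.bin"). The saved weights take more space than bf16 ones. torch_xla also provides xm.save, which moves tensors to CPU before writing; whether that alone avoids the error in this script is untested here.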

Reproduction

I ran accelerate config and chose TPU with 4 cores. I also removed the tensor_format argument from the following argument list in the training script:

beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image.py   \
--pretrained_model_name_or_path=$MODEL_NAME   \
--dataset_name=$dataset_name   \
--use_ema   \
--resolution=512 --center_crop --random_flip   \
--train_batch_size=1   \
--gradient_accumulation_steps=1  \
--gradient_checkpointing   \
--mixed_precision="bf16"   \
--max_train_steps=15000   \
--learning_rate=1e-05  \
--max_grad_norm=1   \
--lr_scheduler="constant" --lr_warmup_steps=0   \
--output_dir="sd-pokemon-model"

Logs

No response

System Info

  • diffusers version: 0.6.0.dev0
  • Platform: Linux-5.13.0-1023-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.12.0+cu102 (False)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.23.1
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:
