
train_dreambooth_flax.py 😭  #1088

@camenduru


Describe the bug

I tried the flax and bf16 branches of CompVis/stable-diffusion-v1-4 with train_dreambooth_flax.py, but the script fails, even though I can generate images on the same VM with this code:

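    import random
    import jax
    import numpy as np
    from flax.training.common_utils import shard

    # Assumed to be defined earlier: `pipe` and `params` (e.g. from
    # FlaxStableDiffusionPipeline.from_pretrained, params replicated for jit=True),
    # plus `prompt`, `num_inference_steps`, `height`, `width`, `guidance_scale`.
    real_seed = random.randint(0, 2147483647)
    prng_seed = jax.random.PRNGKey(real_seed)
    num_samples = jax.device_count()  # one prompt per device
    prompt_n = num_samples * [prompt]
    prompt_ids = pipe.prepare_inputs(prompt_n)
    prng_seed = jax.random.split(prng_seed, jax.device_count())
    prompt_ids = shard(prompt_ids)  # batch axis split evenly across devices
    images = pipe(prompt_ids, params, prng_seed, num_inference_steps=num_inference_steps, height=height, width=width, guidance_scale=guidance_scale, jit=True).images
    images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))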
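
The snippet above shards cleanly because it builds exactly one prompt per device. A minimal NumPy sketch of that shape logic, assuming 8 devices as on a v3-8 and the 77-token length seen in the error message. shard_like_flax is a hypothetical stand-in that mirrors the reshape line shown in the traceback, not the real Flax function, and it runs without any TPU:

```python
import numpy as np


def shard_like_flax(x: np.ndarray, local_device_count: int) -> np.ndarray:
    # Hypothetical stand-in for flax.training.common_utils.shard, copied from
    # the reshape in the traceback: split axis 0 into (devices, per-device batch).
    return x.reshape((local_device_count, -1) + x.shape[1:])


NUM_DEVICES = 8  # assumed: jax.device_count() on a v3-8
SEQ_LEN = 77     # token length shown in the error message

# The working snippet creates num_samples == device_count prompts: 8 x 77.
prompt_ids = np.zeros((NUM_DEVICES, SEQ_LEN), dtype=np.int32)
sharded = shard_like_flax(prompt_ids, NUM_DEVICES)
print(sharded.shape)  # (8, 1, 77)
```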

From the error log:

    File "train_dreambooth_flax.py", line 370, in main
        prompt_ids = shard(prompt_ids)

        lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
    ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)

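
The numbers in the error point to a batch-size mismatch: 308 = 4 × 77, so the array being sharded held 4 tokenized prompts, while the target shape asks for 8 device slices. Because 4 rows cannot be split evenly into 8 groups, no value for the -1 axis exists. A NumPy sketch reproducing only that reshape; shard_like_flax is a hypothetical helper mirroring the traceback line, and the 4-prompt batch is inferred from the array size rather than read from the script:

```python
import numpy as np


def shard_like_flax(x: np.ndarray, local_device_count: int) -> np.ndarray:
    # Hypothetical helper: same reshape as the line in the traceback.
    return x.reshape((local_device_count, -1) + x.shape[1:])


NUM_DEVICES = 8  # assumed device count on a v3-8
SEQ_LEN = 77

# 308 elements = 4 prompts x 77 tokens: fewer prompts than devices.
prompt_ids = np.zeros((4, SEQ_LEN), dtype=np.int32)

try:
    shard_like_flax(prompt_ids, NUM_DEVICES)
    failed = False
except ValueError as err:
    failed = True
    print(err)
```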

Reproduction

    pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    git clone https://github.com/huggingface/diffusers
    cd diffusers/examples/dreambooth
    mkdir instance-images class-images save-model
    pip install -U -r requirements_flax.txt
    sudo apt-get install git-lfs
    git lfs install
    git clone -b flax https://user:<token>@huggingface.co/CompVis/stable-diffusion-v1-4
    wget http://1.jpg http://2.jpg http://3.jpg -P instance-images
    python3 train_dreambooth_flax.py --pretrained_model_name_or_path="/home/camenduru/diffusers/examples/dreambooth/stable-diffusion-v1-4" \
      --instance_data_dir="instance-images" \
      --class_data_dir="class-images" \
      --output_dir="save-model" \
      --with_prior_preservation \
      --prior_loss_weight=1.0 \
      --instance_prompt="parkminyoung" \
      --class_prompt="person" \
      --resolution=512 \
      --train_batch_size=1 \
      --learning_rate=5e-6 \
      --num_class_images=12 \
      --max_train_steps=650

Logs

    INFO:__main__:Number of class images to sample: 12.
    Generating class images:   0%|          | 0/3 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train_dreambooth_flax.py", line 665, in <module>
        main()
      File "train_dreambooth_flax.py", line 370, in main
        prompt_ids = shard(prompt_ids)
      File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 37, in shard
        return jax.tree_util.tree_map(
      File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in tree_map
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
      File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in <genexpr>
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
      File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 38, in <lambda>
        lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
    ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)
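
The progress bar reports 3 batches for 12 class images, which is consistent with 4 prompts per batch, fewer than the 8 devices on a v3-8. One possible workaround, sketched under assumptions and not taken from the script: pad the prompt batch up to a multiple of the device count before sharding, then drop the outputs that belong to padding rows. pad_to_multiple is a hypothetical helper, and repeating the last row as filler is an arbitrary choice. Type hints use typing.Tuple so the sketch also runs on the Python 3.8 shown in the traceback:

```python
from typing import Tuple

import numpy as np


def pad_to_multiple(x: np.ndarray, multiple: int) -> Tuple[np.ndarray, int]:
    """Hypothetical helper: pad axis 0 of x to a multiple of `multiple`.

    Repeats the last row as filler and returns the original batch size,
    so results for the filler rows can be discarded later.
    """
    n = x.shape[0]
    pad = (-n) % multiple  # rows needed to reach the next multiple
    if pad:
        x = np.concatenate([x, np.repeat(x[-1:], pad, axis=0)], axis=0)
    return x, n


NUM_DEVICES = 8  # assumed device count on a v3-8
prompt_ids = np.arange(4 * 77, dtype=np.int32).reshape(4, 77)

padded, n_real = pad_to_multiple(prompt_ids, NUM_DEVICES)
sharded = padded.reshape((NUM_DEVICES, -1) + padded.shape[1:])
print(padded.shape, sharded.shape)  # (8, 77) (8, 1, 77)

# After the sharded computation, flatten the device axis and keep real rows only.
outputs = sharded.reshape((-1,) + sharded.shape[2:])[:n_real]
```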


System Info

gcloud TPU VM | v3-8 (8 TPU cores) | tpu-vm-base
