Closed
Labels: bug (Something isn't working)
Description
### Describe the bug
I tried the `flax` and `bf16` branches of CompVis/stable-diffusion-v1-4 with train_dreambooth_flax.py, but it does not work. On the same VM, inference works fine: I can generate images with this code:
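```python
import random
import jax
import numpy as np
from flax.training.common_utils import shard

# Assumes `pipe` (a Flax Stable Diffusion pipeline), replicated `params`, `prompt`,
# `num_inference_steps`, `height`, `width` and `guidance_scale` are defined earlier.
real_seed = random.randint(0, 2147483647)
prng_seed = jax.random.PRNGKey(real_seed)
num_samples = jax.device_count()  # one prompt per device
prompt_n = num_samples * [prompt]
prompt_ids = pipe.prepare_inputs(prompt_n)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)  # batch axis -> (devices, per-device batch, ...)
images = pipe(prompt_ids, params, prng_seed, num_inference_steps=num_inference_steps, height=height, width=width, guidance_scale=guidance_scale, jit=True).images
images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```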
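The inference snippet works because `num_samples = jax.device_count()`, so the prompt batch always has exactly one row per device. Flax's `shard` splits the leading axis into `(local_device_count, -1)`, as the lambda in the traceback shows. A minimal NumPy sketch of that reshape, runnable without a TPU; `shard_like` is a hypothetical stand-in that only mirrors that lambda, 77 is the token length from the error message, and 8 is the core count of a v3-8:

```python
import numpy as np

def shard_like(x: np.ndarray, local_device_count: int) -> np.ndarray:
    """Hypothetical stand-in for flax.training.common_utils.shard on one array:
    split the leading (batch) axis into (devices, per-device batch)."""
    return x.reshape((local_device_count, -1) + x.shape[1:])

num_devices = 8  # a v3-8 TPU exposes 8 cores
seq_len = 77     # token length taken from the error message

# One prompt per device, as in the inference snippet: 8 x 77 token ids.
prompt_ids = np.zeros((num_devices, seq_len), dtype=np.int32)
sharded = shard_like(prompt_ids, num_devices)
print(sharded.shape)  # (8, 1, 77)
```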
From the error log:
```
File "train_dreambooth_flax.py", line 370, in main
prompt_ids = shard(prompt_ids)
lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)
```
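The sizes in the error pin down the mismatch: 308 = 4 × 77, so `shard` received 4 prompts of 77 tokens, while it needs a batch that splits evenly across 8 devices. Four rows cannot fill 8 devices, so `reshape` has no valid size for the `-1` axis. A NumPy sketch reproducing the failure, with shapes taken from the error message; `shard_like` is again a hypothetical stand-in for Flax's `shard`:

```python
import numpy as np

def shard_like(x, local_device_count):
    # Hypothetical stand-in for flax.training.common_utils.shard (single array).
    return x.reshape((local_device_count, -1) + x.shape[1:])

seq_len = 77
prompt_ids = np.zeros((4, seq_len), dtype=np.int32)  # 4 * 77 = 308 elements

failed, message = False, ""
try:
    shard_like(prompt_ids, 8)  # 4 rows cannot be split across 8 devices
except ValueError as err:
    failed, message = True, str(err)

print(failed)  # True
```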
### Reproduction
```shell
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
git clone https://github.com/huggingface/diffusers
cd diffusers/examples/dreambooth
mkdir instance-images class-images save-model
pip install -U -r requirements_flax.txt
sudo apt-get install git-lfs
git lfs install
git clone -b flax https://user:<token>@huggingface.co/CompVis/stable-diffusion-v1-4
wget http://1.jpg http://2.jpg http://3.jpg -P instance-images
python3 train_dreambooth_flax.py --pretrained_model_name_or_path="/home/camenduru/diffusers/examples/dreambooth/stable-diffusion-v1-4" \
  --instance_data_dir="instance-images" \
  --class_data_dir="class-images" \
  --output_dir="save-model" \
  --with_prior_preservation \
  --prior_loss_weight=1.0 \
  --instance_prompt="parkminyoung" \
  --class_prompt="person" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --num_class_images=12 \
  --max_train_steps=650
```
### Logs
```
INFO:__main__:Number of class images to sample: 12.
Generating class images: 0%| | 0/3 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_dreambooth_flax.py", line 665, in <module>
    main()
  File "train_dreambooth_flax.py", line 370, in main
    prompt_ids = shard(prompt_ids)
  File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 37, in shard
    return jax.tree_util.tree_map(
  File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/camenduru/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 200, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/camenduru/.local/lib/python3.8/site-packages/flax/training/common_utils.py", line 38, in <lambda>
    lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs)
ValueError: cannot reshape array of size 308 into shape (8,newaxis,77)
```
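The log is consistent with that arithmetic: 12 class images in 3 batches (the `0/3` progress bar) means 4 prompts per batch, matching 308 = 4 × 77. One possible workaround, offered only as a sketch and not as the upstream fix, is to pad each sampling batch up to a multiple of the device count before sharding and discard the extra outputs afterwards. The helper `pad_to_multiple` is hypothetical and not part of train_dreambooth_flax.py; this assumes generating a few duplicate images per batch is acceptable:

```python
from typing import Tuple

import numpy as np

def pad_to_multiple(x: np.ndarray, multiple: int) -> Tuple[np.ndarray, int]:
    """Hypothetical helper: repeat the last row until the batch size is a
    multiple of `multiple`. Returns the padded array and the original size."""
    n = x.shape[0]
    pad = (-n) % multiple  # rows needed to reach the next multiple
    if pad:
        x = np.concatenate([x, np.repeat(x[-1:], pad, axis=0)], axis=0)
    return x, n

num_class_images, num_batches = 12, 3        # from the log above
per_batch = num_class_images // num_batches  # 4 prompts per batch
prompt_ids = np.zeros((per_batch, 77), dtype=np.int32)

padded, n_real = pad_to_multiple(prompt_ids, 8)  # 8 cores on a v3-8
sharded = padded.reshape((8, -1) + padded.shape[1:])  # same reshape as Flax's shard
print(sharded.shape)  # (8, 1, 77)
# After generation, keep only the first n_real images and drop the padded ones.
```

Choosing a per-batch prompt count that is already a multiple of 8 would avoid the padding altogether.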
### System Info
Google Cloud TPU VM | v3-8 | tpu-vm-base