@@ -442,20 +442,25 @@ def collate_fn(examples):
442442 input_ids = [example ["instance_prompt_ids" ] for example in examples ]
443443 pixel_values = [example ["instance_images" ] for example in examples ]
444444
445- # concat class and instance examples for prior preservation
446- if args .with_prior_preservation :
447- input_ids += [example ["class_prompt_ids" ] for example in examples ]
448- pixel_values += [example ["class_images" ] for example in examples ]
449-
450- pixel_values = torch .stack (pixel_values )
451- pixel_values = pixel_values .to (memory_format = torch .contiguous_format ).float ()
452-
445+ pixel_values = torch .stack (pixel_values ).to (memory_format = torch .contiguous_format ).float ()
453446 input_ids = tokenizer .pad ({"input_ids" : input_ids }, padding = True , return_tensors = "pt" ).input_ids
454447
455448 batch = {
456449 "input_ids" : input_ids ,
457450 "pixel_values" : pixel_values ,
458451 }
452+
453+ if args .with_prior_preservation :
454+ class_input_ids = [example ["class_prompt_ids" ] for example in examples ]
455+ class_pixel_values = [example ["class_images" ] for example in examples ]
456+
457+ class_pixel_values = torch .stack (class_pixel_values ).to (memory_format = torch .contiguous_format ).float ()
458+ class_input_ids = tokenizer .pad (
459+ {"input_ids" : class_input_ids }, padding = True , return_tensors = "pt"
460+ ).input_ids
461+ batch ["class_input_ids" ] = class_input_ids
462+ batch ["class_pixel_values" ] = class_pixel_values
463+
459464 return batch
460465
461466 train_dataloader = torch .utils .data .DataLoader (
@@ -516,33 +521,41 @@ def collate_fn(examples):
516521 unet .train ()
517522 for step , batch in enumerate (train_dataloader ):
518523 with accelerator .accumulate (unet ):
519- # Convert images to latent space
520- with torch .no_grad ():
521- latents = vae .encode (batch ["pixel_values" ]).latent_dist .sample ()
522- latents = latents * 0.18215
523-
524- # Sample noise that we'll add to the latents
525- noise = torch .randn (latents .shape ).to (latents .device )
526- bsz = latents .shape [0 ]
527- # Sample a random timestep for each image
528- timesteps = torch .randint (
529- 0 , noise_scheduler .config .num_train_timesteps , (bsz ,), device = latents .device
530- ).long ()
531-
532- # Add noise to the latents according to the noise magnitude at each timestep
533- # (this is the forward diffusion process)
534- noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
535-
536- # Get the text embedding for conditioning
537- with torch .no_grad ():
538- encoder_hidden_states = text_encoder (batch ["input_ids" ])[0 ]
539-
540- # Predict the noise residual
541- noise_pred = unet (noisy_latents , timesteps , encoder_hidden_states ).sample
542-
543- loss = F .mse_loss (noise_pred , noise , reduction = "none" ).mean ([1 , 2 , 3 ]).mean ()
544- accelerator .backward (loss )
545524
525+ def _forward (input_ids , pixel_values ):
526+ # Convert images to latent space
527+ with torch .no_grad ():
528+ latents = vae .encode (pixel_values ).latent_dist .sample ()
529+ latents = latents * 0.18215
530+
531+ # Sample noise that we'll add to the latents
532+ noise = torch .randn (latents .shape ).to (latents .device )
533+ bsz = latents .shape [0 ]
534+ # Sample a random timestep for each image
535+ timesteps = torch .randint (
536+ 0 , noise_scheduler .config .num_train_timesteps , (bsz ,), device = latents .device
537+ ).long ()
538+
539+ # Add noise to the latents according to the noise magnitude at each timestep
540+ # (this is the forward diffusion process)
541+ noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
542+
543+ # Get the text embedding for conditioning
544+ with torch .no_grad ():
545+ encoder_hidden_states = text_encoder (input_ids )[0 ]
546+
547+ # Predict the noise residual
548+ noise_pred = unet (noisy_latents , timesteps , encoder_hidden_states ).sample
549+ loss = F .mse_loss (noise_pred , noise , reduction = "none" ).mean ([1 , 2 , 3 ]).mean ()
550+ return loss
551+
552+ loss = _forward (batch ["input_ids" ], batch ["pixel_values" ])
553+
554+ if args .with_prior_preservation :
555+ prior_loss = _forward (batch ["class_input_ids" ], batch ["class_pixel_values" ])
556+ loss = loss + prior_loss
557+
558+ accelerator .backward (loss )
546559 optimizer .step ()
547560 lr_scheduler .step ()
548561 optimizer .zero_grad ()
0 commit comments