diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 253063e7936d..c57dca670d68 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -172,6 +172,11 @@ def parse_args():
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -473,9 +478,19 @@ def main():
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )

-    # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move vae and unet to device.
+    vae.encoder.to(device=accelerator.device, dtype=weight_dtype)
+    vae.quant_conv.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)

     # Keep vae and unet in eval model as we don't train these
     vae.eval()
@@ -513,11 +528,12 @@ def main():
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-                latents = latents * 0.18215
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                    latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn(latents.shape).to(latents.device, dtype=weight_dtype)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
@@ -526,15 +542,16 @@ def main():
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(dtype=weight_dtype)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                # Calculate loss in fp32
+                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                 accelerator.backward(loss)

                 # Zero out the gradients for all token embeddings except the newly added
diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py
index a17b1d2a5333..354f51186348 100644
--- a/src/diffusers/models/unet_blocks.py
+++ b/src/diffusers/models/unet_blocks.py
@@ -548,7 +548,7 @@ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -631,7 +631,7 @@ def forward(self, hidden_states, temb=None):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1134,7 +1134,7 @@ def forward(
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1212,7 +1212,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
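
A note on the unet_blocks.py change: textual inversion keeps the unet frozen in eval() mode, yet gradients must still flow through it to reach the learned token embedding, so checkpointing gated on self.training would never activate here. torch.utils.checkpoint does not require training mode, only that some input requires grad. A minimal standalone sketch of the recompute pattern these blocks use (the Linear stand-in module is illustrative, not from the patch):

    import torch
    from torch.utils.checkpoint import checkpoint

    def create_custom_forward(module):
        # Same wrapper shape as in unet_blocks.py: lets checkpoint() pass
        # positional inputs straight through to the wrapped module.
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    # Illustrative frozen block: eval mode, parameters not trained, but the
    # input requires grad, so backward recomputes activations instead of
    # storing them and still propagates gradients to the input.
    block = torch.nn.Linear(16, 16).eval()
    x = torch.randn(2, 16, requires_grad=True)
    out = checkpoint(create_custom_forward(block), x)
    out.sum().backward()
    assert x.grad is not None  # gradients reach the trainable input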
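
On the loss change in textual_inversion.py: with the frozen models cast to fp16, noise_pred and noise come out in half precision, and accumulating an MSE reduction in fp16 can lose precision, so both tensors are upcast before the mean. A standalone sketch of the upcast-before-reduce pattern (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    # Illustrative half-precision tensors standing in for the unet
    # prediction and the sampled noise.
    noise_pred = torch.randn(4, 4, 64, 64, dtype=torch.float16)
    noise = torch.randn(4, 4, 64, 64, dtype=torch.float16)

    # Upcast before reducing so the mean is accumulated in fp32.
    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
    assert loss.dtype == torch.float32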