diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index d345ebb391e3..71e9dcf8e09a 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -41,7 +41,7 @@
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
-from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast
+from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast, PretrainedConfig, T5EncoderModel, T5TokenizerFast

 import diffusers
 from diffusers import (
@@ -69,6 +69,18 @@
 logger = get_logger(__name__)


+def safe_info(msg, *args, **kwargs):
+    # Logging uses %-style interpolation; extra kwargs (e.g. accelerate's
+    # main_process_only) are dropped, since tqdm.write does not accept them.
+    if isinstance(msg, str) and args:
+        msg = msg % args
+    else:
+        msg = str(msg)
+    tqdm.write(msg)
+
+
+# Route logger.info through tqdm.write so log lines don't break the progress bar.
+logger.info = safe_info
+
+
 def save_model_card(
     repo_id: str,
     images=None,
@@ -92,7 +104,8 @@ def save_model_card(
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
         widget_dict.append(
-            {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+            {"text": validation_prompt if validation_prompt else " ",
+             "output": {"url": f"image_{i}.png"}}
         )

     model_description = f"""
@@ -172,22 +185,26 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    # Cast the whole validation pipeline to float32 to avoid dtype mismatches between modules/tensors.
+    pipeline = pipeline.to(device=accelerator.device, dtype=torch.float32)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+    generator = torch.Generator(device=accelerator.device).manual_seed(
+        args.seed) if args.seed is not None else None
     # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
     autocast_ctx = nullcontext()

     with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+        images = [pipeline(**pipeline_args, generator=generator).images[0]
+                  for _ in range(args.num_validation_images)]

     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
         if tracker.name == "tensorboard":
             np_images = np.stack([np.asarray(img) for img in images])
-            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+            tracker.writer.add_images(
+                phase_name, np_images, epoch, dataformats="NHWC")
         if tracker.name == "wandb":
             tracker.log(
                 {
@@ -223,7 +240,8 @@ def import_model_class_from_model_name_or_path(


 def parse_args(input_args=None):
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script.")
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -289,7 +307,8 @@ def parse_args(input_args=None):
         help="The column of the dataset containing the instance prompt for each image",
     )

-    parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+    parser.add_argument("--repeats", type=int, default=1,
+                        help="How many times to repeat the training data.")

     parser.add_argument(
         "--class_data_dir",
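An aside on the `safe_info` shim above: `tqdm.write` prints above an active progress bar instead of through it, which is what makes the monkeypatch worthwhile. A minimal self-contained sketch (the loop and `time.sleep` are illustrative stand-ins for training steps, not part of the patch):

import time

from tqdm.auto import tqdm

for step in tqdm(range(50), desc="train"):
    if step % 20 == 0:
        # print() here would splice text into the bar; tqdm.write keeps it intact
        tqdm.write(f"step {step}: validation would run here")
    time.sleep(0.01)
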
action="store_true", help="Flag to add prior preservation loss.", ) - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument("--prior_loss_weight", type=float, + default=1.0, help="The weight of prior preservation loss.") parser.add_argument( "--num_class_images", type=int, @@ -360,7 +380,8 @@ def parse_args(input_args=None): default="sd3-dreambooth", help="The output directory where the model predictions and checkpoints will be written.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--seed", type=int, default=None, + help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, @@ -475,7 +496,8 @@ def parse_args(input_args=None): default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument("--lr_power", type=float, default=1.0, + help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, @@ -513,7 +535,8 @@ def parse_args(input_args=None): "--optimizer", type=str, default="AdamW", - help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + help=( + 'The optimizer type to use. Choose between ["AdamW", "prodigy"]'), ) parser.add_argument( @@ -535,8 +558,10 @@ def parse_args(input_args=None): help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " "uses the value of square root of beta2. Ignored if optimizer is adamW", ) - parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument("--prodigy_decouple", type=bool, default=True, + help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, + default=1e-04, help="Weight decay to use for unet params") parser.add_argument( "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) @@ -561,9 +586,12 @@ def parse_args(input_args=None): help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " "Ignored if optimizer is adamW", ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument("--max_grad_norm", default=1.0, + type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", + help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, + help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, @@ -617,7 +645,8 @@ def parse_args(input_args=None): " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 
@@ -617,7 +645,8 @@ def parse_args(input_args=None):
             " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
         ),
     )
-    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="For distributed training: local_rank")

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -625,10 +654,12 @@ def parse_args(input_args=None):
         args = parser.parse_args()

     if args.dataset_name is None and args.instance_data_dir is None:
-        raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+        raise ValueError(
+            "Specify either `--dataset_name` or `--instance_data_dir`")

     if args.dataset_name is not None and args.instance_data_dir is not None:
-        raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+        raise ValueError(
+            "Specify only one of `--dataset_name` or `--instance_data_dir`")

     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
@@ -636,15 +667,18 @@ def parse_args(input_args=None):

     if args.with_prior_preservation:
         if args.class_data_dir is None:
-            raise ValueError("You must specify a data directory for class images.")
+            raise ValueError(
+                "You must specify a data directory for class images.")
         if args.class_prompt is None:
             raise ValueError("You must specify prompt for class images.")
     else:
         # logger is not available yet
         if args.class_data_dir is not None:
-            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+            warnings.warn(
+                "You need not use --class_data_dir without --with_prior_preservation.")
         if args.class_prompt is not None:
-            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+            warnings.warn(
+                "You need not use --class_prompt without --with_prior_preservation.")

     return args

@@ -723,13 +757,15 @@ def __init__(
                 # create final list of captions according to --repeats
                 self.custom_instance_prompts = []
                 for caption in custom_instance_prompts:
-                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+                    self.custom_instance_prompts.extend(
+                        itertools.repeat(caption, repeats))
         else:
             self.instance_data_root = Path(instance_data_root)
             if not self.instance_data_root.exists():
                 raise ValueError("Instance images root doesn't exists.")

-            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+            instance_images = [Image.open(path) for path in list(
+                Path(instance_data_root).iterdir())]
             self.custom_instance_prompts = None

         self.instance_images = []
@@ -737,8 +773,10 @@ def __init__(
             self.instance_images.extend(itertools.repeat(img, repeats))

         self.pixel_values = []
-        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
-        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_resize = transforms.Resize(
+            size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(
+            size) if center_crop else transforms.RandomCrop(size)
         train_flip = transforms.RandomHorizontalFlip(p=1.0)
         train_transforms = transforms.Compose(
             [
@@ -759,7 +797,8 @@ def __init__(
                 x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
                 image = train_crop(image)
             else:
-                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                y1, x1, h, w = train_crop.get_params(
+                    image, (args.resolution, args.resolution))
                 image = crop(image, y1, x1, h, w)
             image = train_transforms(image)
             self.pixel_values.append(image)
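For context on the crop bookkeeping rewrapped above: `RandomCrop.get_params` returns `(top, left, height, width)`, which lets the dataset replay the exact crop via `torchvision.transforms.functional.crop`. A toy sketch (the tensor image and 512 px size are assumptions):

import torch
from torchvision import transforms
from torchvision.transforms.functional import crop

image = torch.zeros(3, 600, 800)  # dummy C x H x W image
train_crop = transforms.RandomCrop(512)
y1, x1, h, w = train_crop.get_params(image, (512, 512))
patch = crop(image, y1, x1, h, w)  # same crop, replayable elsewhere
print(patch.shape)  # torch.Size([3, 512, 512])
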
@@ -772,7 +811,8 @@ def __init__(
             self.class_data_root.mkdir(parents=True, exist_ok=True)
             self.class_images_path = list(self.class_data_root.iterdir())
             if class_num is not None:
-                self.num_class_images = min(len(self.class_images_path), class_num)
+                self.num_class_images = min(
+                    len(self.class_images_path), class_num)
             else:
                 self.num_class_images = len(self.class_images_path)
             self._length = max(self.num_class_images, self.num_instance_images)
@@ -781,8 +821,10 @@ def __init__(

         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.Resize(
+                    size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(
+                    size) if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -797,7 +839,8 @@ def __getitem__(self, index):
         example["instance_images"] = instance_image

         if self.custom_instance_prompts:
-            caption = self.custom_instance_prompts[index % self.num_instance_images]
+            caption = self.custom_instance_prompts[
+                index % self.num_instance_images]
             if caption:
                 example["instance_prompt"] = caption
             else:
@@ -807,7 +850,8 @@ def __getitem__(self, index):
             example["instance_prompt"] = self.instance_prompt

         if self.class_data_root:
-            class_image = Image.open(self.class_images_path[index % self.num_class_images])
+            class_image = Image.open(
+                self.class_images_path[index % self.num_class_images])
             class_image = exif_transpose(class_image)

             if not class_image.mode == "RGB":
@@ -829,7 +873,8 @@ def collate_fn(examples, with_prior_preservation=False):
         prompts += [example["class_prompt"] for example in examples]

     pixel_values = torch.stack(pixel_values)
-    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+    pixel_values = pixel_values.to(
+        memory_format=torch.contiguous_format).float()

     batch = {"pixel_values": pixel_values, "prompts": prompts}
     return batch
@@ -871,19 +916,25 @@ def _encode_prompt_with_t5(
     prompt=None,
     num_images_per_prompt=1,
     device=None,
+    text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
+    batch_size = len(prompt) if prompt is not None else (
+        text_input_ids.shape[0] if text_input_ids is not None else 1)
+
+    if tokenizer is not None and text_input_ids is None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    elif text_input_ids is None:
+        raise ValueError(
+            "text_input_ids must be provided when the tokenizer is not specified")

-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        add_special_tokens=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]

     dtype = text_encoder.dtype
@@ -893,7 +944,8 @@ def _encode_prompt_with_t5(

     # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+    prompt_embeds = prompt_embeds.view(
+        batch_size * num_images_per_prompt, seq_len, -1)

     return prompt_embeds

@@ -907,7 +959,8 @@ def _encode_prompt_with_clip(
     num_images_per_prompt: int = 1,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
+    batch_size = len(prompt) if prompt is not None else (
+        text_input_ids.shape[0] if text_input_ids is not None else 1)

     if tokenizer is not None:
         text_inputs = tokenizer(
@@ -921,9 +974,11 @@ def _encode_prompt_with_clip(
         text_input_ids = text_inputs.input_ids
     else:
         if text_input_ids is None:
-            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+            raise ValueError(
+                "text_input_ids must be provided when the tokenizer is not specified")

-    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+    prompt_embeds = text_encoder(
+        text_input_ids.to(device), output_hidden_states=True)

     pooled_prompt_embeds = prompt_embeds[0]
     prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -932,7 +987,8 @@ def _encode_prompt_with_clip(
     _, seq_len, _ = prompt_embeds.shape
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+    prompt_embeds = prompt_embeds.view(
+        batch_size * num_images_per_prompt, seq_len, -1)

     return prompt_embeds, pooled_prompt_embeds

@@ -947,8 +1003,13 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
+    if tokenizers is None:
+        if (text_input_ids_list is None
+                or not isinstance(text_input_ids_list, (list, tuple))
+                or len(text_input_ids_list) != 3):
+            raise ValueError(
+                "When tokenizers is None, text_input_ids_list must be a list/tuple of length 3: [clip1_ids, clip2_ids, t5_ids].")

-    clip_tokenizers = tokenizers[:2]
+    clip_tokenizers = tokenizers[:2] if tokenizers is not None else [None, None]
     clip_text_encoders = text_encoders[:2]

     clip_prompt_embeds_list = []
@@ -970,15 +1031,17 @@ def encode_prompt(

     t5_prompt_embed = _encode_prompt_with_t5(
         text_encoders[-1],
-        tokenizers[-1],
+        tokenizers[-1] if tokenizers is not None else None,
         max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[-1].device,
+        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
     )

     clip_prompt_embeds = torch.nn.functional.pad(
-        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+        clip_prompt_embeds,
+        (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
     )
     prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

@@ -1000,7 +1063,8 @@ def main(args):

     logging_dir = Path(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+    accelerator_project_config = ProjectConfiguration(
+        project_dir=args.output_dir, logging_dir=logging_dir)
     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -1016,7 +1080,8 @@ def main(args):

     if args.report_to == "wandb":
         if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+            raise ImportError(
+                "Make sure to install wandb if you want to use it for logging during training.")

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
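The hunks above establish a contract for precomputed token ids: with `tokenizers=None`, `encode_prompt` now requires a 3-element `text_input_ids_list` of `[clip1_ids, clip2_ids, t5_ids]`. A hedged sketch of the expected shapes (the batch size, the 77-token CLIP length, and the T5 `max_sequence_length` are illustrative assumptions; the call itself is commented out because it needs loaded encoders):

import torch

batch = 2
tokens_one = torch.randint(0, 49408, (batch, 77))    # CLIP tokenizer 1 ids
tokens_two = torch.randint(0, 49408, (batch, 77))    # CLIP tokenizer 2 ids
tokens_three = torch.randint(0, 32100, (batch, 77))  # T5 ids

# prompt_embeds, pooled_prompt_embeds = encode_prompt(
#     text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
#     tokenizers=None,
#     prompt=None,
#     max_sequence_length=77,
#     text_input_ids_list=[tokens_one, tokens_two, tokens_three],
# )
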
@@ -1044,7 +1109,8 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            has_supported_fp16_accelerator = (
+                torch.cuda.is_available() or torch.backends.mps.is_available())
             torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
@@ -1064,7 +1130,8 @@ def main(args):
             logger.info(f"Number of class images to sample: {num_new_images}.")

             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+            sample_dataloader = torch.utils.data.DataLoader(
+                sample_dataset, batch_size=args.sample_batch_size)

             sample_dataloader = accelerator.prepare(sample_dataloader)
             pipeline.to(accelerator.device)
@@ -1075,8 +1142,10 @@ def main(args):
                 images = pipeline(example["prompt"]).images

                 for i, image in enumerate(images):
-                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
-                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+                    hash_image = insecure_hashlib.sha1(
+                        image.tobytes()).hexdigest()
+                    image_filename = class_images_dir / (
+                        f"{example['index'][i] + cur_class_images}-{hash_image}.jpg")
                     image.save(image_filename)

         del pipeline
@@ -1094,12 +1163,12 @@ def main(args):
         ).repo_id

     # Load the tokenizers
-    tokenizer_one = CLIPTokenizer.from_pretrained(
+    tokenizer_one = CLIPTokenizerFast.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="tokenizer",
         revision=args.revision,
     )
-    tokenizer_two = CLIPTokenizer.from_pretrained(
+    tokenizer_two = CLIPTokenizerFast.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="tokenizer_2",
         revision=args.revision,
@@ -1187,16 +1256,20 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             for i, model in enumerate(models):
                 if isinstance(unwrap_model(model), SD3Transformer2DModel):
-                    unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer"))
+                    unwrap_model(model).save_pretrained(
+                        os.path.join(output_dir, "transformer"))
                 elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
                     if isinstance(unwrap_model(model), CLIPTextModelWithProjection):
                         hidden_size = unwrap_model(model).config.hidden_size
                         if hidden_size == 768:
-                            unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder"))
+                            unwrap_model(model).save_pretrained(
+                                os.path.join(output_dir, "text_encoder"))
                         elif hidden_size == 1280:
-                            unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2"))
+                            unwrap_model(model).save_pretrained(
+                                os.path.join(output_dir, "text_encoder_2"))
                     else:
-                        unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3"))
+                        unwrap_model(model).save_pretrained(
+                            os.path.join(output_dir, "text_encoder_3"))
                 else:
                     raise ValueError(f"Wrong model supplied: {type(model)=}.")

@@ -1210,27 +1283,32 @@ def load_model_hook(models, input_dir):

             # load diffusers style into model
             if isinstance(unwrap_model(model), SD3Transformer2DModel):
-                load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer")
+                load_model = SD3Transformer2DModel.from_pretrained(
+                    input_dir, subfolder="transformer")
                 model.register_to_config(**load_model.config)

                 model.load_state_dict(load_model.state_dict())
             elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
                 try:
-                    load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
+                    load_model = CLIPTextModelWithProjection.from_pretrained(
+                        input_dir, subfolder="text_encoder")
                     model(**load_model.config)
                     model.load_state_dict(load_model.state_dict())
                 except Exception:
                     try:
-                        load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2")
+                        load_model = CLIPTextModelWithProjection.from_pretrained(
+                            input_dir, subfolder="text_encoder_2")
                         model(**load_model.config)
                         model.load_state_dict(load_model.state_dict())
                     except Exception:
                         try:
-                            load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3")
+                            load_model = T5EncoderModel.from_pretrained(
+                                input_dir, subfolder="text_encoder_3")
                             model(**load_model.config)
                             model.load_state_dict(load_model.state_dict())
                         except Exception:
-                            raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
+                            raise ValueError(
+                                f"Couldn't load the model of type: ({type(model)}).")
             else:
                 raise ValueError(f"Unsupported model found: {type(model)=}")

@@ -1246,11 +1324,13 @@ def load_model_hook(models, input_dir):

     if args.scale_lr:
         args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
         )

     # Optimization parameters
-    transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate}
+    transformer_parameters_with_lr = {
+        "params": transformer.parameters(), "lr": args.learning_rate}
     if args.train_text_encoder:
         # different learning rate for text encoder and unet
         text_parameters_one_with_lr = {
@@ -1315,7 +1395,8 @@ def load_model_hook(models, input_dir):
         try:
             import prodigyopt
         except ImportError:
-            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

         optimizer_class = prodigyopt.Prodigy

@@ -1362,13 +1443,15 @@ def load_model_hook(models, input_dir):
         train_dataset,
         batch_size=args.train_batch_size,
         shuffle=True,
-        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+        collate_fn=lambda examples: collate_fn(
+            examples, args.with_prior_preservation),
         num_workers=args.dataloader_num_workers,
     )

     if not args.train_text_encoder:
         tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
-        text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
+        text_encoders = [text_encoder_one,
+                         text_encoder_two, text_encoder_three]

         def compute_text_embeddings(prompt, text_encoders, tokenizers):
             with torch.no_grad():
@@ -1376,7 +1459,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     text_encoders, tokenizers, prompt, args.max_sequence_length
                 )
                 prompt_embeds = prompt_embeds.to(accelerator.device)
-                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+                pooled_prompt_embeds = pooled_prompt_embeds.to(
+                    accelerator.device)
             return prompt_embeds, pooled_prompt_embeds

     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
@@ -1410,25 +1494,33 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         prompt_embeds = instance_prompt_hidden_states
         pooled_prompt_embeds = instance_pooled_prompt_embeds
         if args.with_prior_preservation:
-            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
-            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
+            prompt_embeds = torch.cat(
+                [prompt_embeds, class_prompt_hidden_states], dim=0)
+            pooled_prompt_embeds = torch.cat(
+                [pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)

     # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
     # batch prompts on all training steps
     else:
         tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
         tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
-        tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)
+        tokens_three = tokenize_prompt(
+            tokenizer_three, args.instance_prompt)
         if args.with_prior_preservation:
-            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
-            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
-            class_tokens_three = tokenize_prompt(tokenizer_three, args.class_prompt)
+            class_tokens_one = tokenize_prompt(
+                tokenizer_one, args.class_prompt)
+            class_tokens_two = tokenize_prompt(
+                tokenizer_two, args.class_prompt)
+            class_tokens_three = tokenize_prompt(
+                tokenizer_three, args.class_prompt)
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
-            tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
+            tokens_three = torch.cat(
+                [tokens_three, class_tokens_three], dim=0)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(
+        len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -1467,11 +1559,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(
+        len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
-    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    args.num_train_epochs = math.ceil(
+        args.max_train_steps / num_update_steps_per_epoch)

     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
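A worked example of the step arithmetic rewrapped above (all numbers illustrative): the scheduler counts optimizer updates, not micro-batches.

import math

batches_per_epoch = 10              # len(train_dataloader)
gradient_accumulation_steps = 4
num_update_steps_per_epoch = math.ceil(
    batches_per_epoch / gradient_accumulation_steps)  # 3
num_train_epochs = 5
max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 15
print(num_update_steps_per_epoch, max_train_steps)
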
@@ -1480,15 +1574,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         accelerator.init_trackers(tracker_name, config=vars(args))

     # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    total_batch_size = (args.train_batch_size *
+                        accelerator.num_processes * args.gradient_accumulation_steps)

     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len(train_dataset)}")
     logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(
+        f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(
+        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(
+        f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
     global_step = 0
     first_epoch = 0
@@ -1530,10 +1628,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     )

     def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
-        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
-        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+        sigmas = noise_scheduler_copy.sigmas.to(
+            device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler_copy.timesteps.to(
+            accelerator.device)
         timesteps = timesteps.to(accelerator.device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        step_indices = [(schedule_timesteps == t).nonzero().item()
+                        for t in timesteps]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < n_dim:
@@ -1550,7 +1651,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
             if args.train_text_encoder:
-                models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three])
+                models_to_accumulate.extend(
+                    [text_encoder_one, text_encoder_two, text_encoder_three])

             with accelerator.accumulate(models_to_accumulate):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
@@ -1564,11 +1666,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 else:
                     tokens_one = tokenize_prompt(tokenizer_one, prompts)
                     tokens_two = tokenize_prompt(tokenizer_two, prompts)
-                    tokens_three = tokenize_prompt(tokenizer_three, prompts)
+                    tokens_three = tokenize_prompt(
+                        tokenizer_three, prompts)

                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+                model_input = ((model_input - vae.config.shift_factor)
+                               * vae.config.scaling_factor)
                 model_input = model_input.to(dtype=weight_dtype)

                 # Sample noise that we'll add to the latents
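How `get_sigmas` above resolves sigmas: each sampled timestep is matched back to its index in the scheduler's timestep table, and the corresponding sigma is broadcast up to the latent's rank. A toy sketch with stand-in tables:

import torch

schedule_timesteps = torch.tensor([1000.0, 750.0, 500.0, 250.0])
sigmas_table = torch.tensor([1.0, 0.75, 0.5, 0.25])
timesteps = torch.tensor([500.0, 1000.0])

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas_table[step_indices].flatten()
while sigma.dim() < 4:  # n_dim of the 4-D latents
    sigma = sigma.unsqueeze(-1)
print(sigma.shape)  # torch.Size([2, 1, 1, 1])
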
@@ -1584,13 +1688,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     logit_std=args.logit_std,
                     mode_scale=args.mode_scale,
                 )
-                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
-                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+                indices = (
+                    u * noise_scheduler_copy.config.num_train_timesteps).long()
+                timesteps = noise_scheduler_copy.timesteps[indices].to(
+                    device=model_input.device)

                 # Add noise according to flow matching.
                 # zt = (1 - texp) * x + texp * z1
-                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+                sigmas = get_sigmas(
+                    timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+                noisy_model_input = ((1.0 - sigmas) * model_input
+                                     + sigmas * noise)

                 # Predict the noise residual
                 if not args.train_text_encoder:
@@ -1603,10 +1711,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     )[0]
                 else:
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
-                        text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
+                        text_encoders=[text_encoder_one,
+                                       text_encoder_two, text_encoder_three],
                         tokenizers=None,
                         prompt=None,
-                        text_input_ids_list=[tokens_one, tokens_two, tokens_three],
+                        max_sequence_length=args.max_sequence_length,
+                        text_input_ids_list=[
+                            tokens_one, tokens_two, tokens_three],
                     )
                     model_pred = transformer(
                         hidden_states=noisy_model_input,
@@ -1623,7 +1734,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+                weighting = compute_loss_weighting_for_sd3(
+                    weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                 # flow matching loss
                 if args.precondition_outputs:
@@ -1633,7 +1745,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                 if args.with_prior_preservation:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                    model_pred, model_pred_prior = torch.chunk(
+                        model_pred, 2, dim=0)
                     target, target_prior = torch.chunk(target, 2, dim=0)

                     # Compute prior loss
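The noising step above is the flow-matching straight-line interpolation zt = (1 - sigma) * x + sigma * z1, with sigma broadcast per sample. A toy sketch (the 16-channel SD3 latent shape is an assumption):

import torch

model_input = torch.randn(2, 16, 64, 64)  # stand-in latents
noise = torch.randn_like(model_input)
sigmas = torch.tensor([0.25, 0.9]).view(-1, 1, 1, 1)  # one sigma per sample

noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
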
@@ -1647,7 +1760,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                 # Compute regular loss.
                 loss = torch.mean(
-                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+                    (weighting.float()
+                     * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                     1,
                 )
                 loss = loss.mean()
@@ -1668,7 +1782,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         if args.train_text_encoder
                         else transformer.parameters()
                     )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(
+                        params_to_clip, args.max_grad_norm)

                 optimizer.step()
                 lr_scheduler.step()
@@ -1684,28 +1799,35 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
                             checkpoints = os.listdir(args.output_dir)
-                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
-                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+                            checkpoints = [
+                                d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(
+                                checkpoints, key=lambda x: int(x.split("-")[1]))

                             # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                             if len(checkpoints) >= args.checkpoints_total_limit:
-                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                num_to_remove = len(
+                                    checkpoints) - args.checkpoints_total_limit + 1
                                 removing_checkpoints = checkpoints[0:num_to_remove]

                                 logger.info(
                                     f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                 )
-                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+                                logger.info(
+                                    f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                 for removing_checkpoint in removing_checkpoints:
-                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    removing_checkpoint = os.path.join(
+                                        args.output_dir, removing_checkpoint)
                                     shutil.rmtree(removing_checkpoint)

-                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        save_path = os.path.join(
+                            args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")

-            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"loss": loss.detach().item(),
+                    "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)

@@ -1727,7 +1849,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             vae=vae,
             text_encoder=accelerator.unwrap_model(text_encoder_one),
             text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-            text_encoder_3=accelerator.unwrap_model(text_encoder_three),
+            text_encoder_3=accelerator.unwrap_model(
+                text_encoder_three),
             transformer=accelerator.unwrap_model(transformer),
             revision=args.revision,
             variant=args.variant,
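A standalone sketch of the `checkpoints_total_limit` rotation rewrapped above, with a fake directory listing standing in for `os.listdir` (values illustrative):

checkpoints = ["checkpoint-500", "logs", "checkpoint-1500", "checkpoint-1000"]
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

checkpoints_total_limit = 2
if len(checkpoints) >= checkpoints_total_limit:
    # leave at most (limit - 1) behind so the new save lands within the limit
    num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
    removing = checkpoints[:num_to_remove]  # oldest first
    print(removing)  # ['checkpoint-500', 'checkpoint-1000']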