From 83eec73261ae4eab26b27c7ce7dbd8a71f58d994 Mon Sep 17 00:00:00 2001 From: duongna21 Date: Mon, 17 Oct 2022 20:44:00 +0700 Subject: [PATCH 01/10] add textual inversion flax --- .../textual_inversion_flax.py | 622 ++++++++++++++++++ 1 file changed, 622 insertions(+) create mode 100644 examples/textual_inversion/textual_inversion_flax.py diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py new file mode 100644 index 000000000000..23b57821e86c --- /dev/null +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -0,0 +1,622 @@ +import argparse +import logging +import math +import os +import random +from pathlib import Path +from typing import Optional + +import numpy as np +import jax +import jax.numpy as jnp +from flax import jax_utils +from flax.training import train_state +from flax.training.common_utils import shard + +import optax + +import torch +import torch.utils.checkpoint +from torch.utils.data import Dataset + +import PIL +from diffusers import ( + FlaxAutoencoderKL, + FlaxDDPMScheduler, + FlaxPNDMScheduler, + FlaxStableDiffusionPipeline, + FlaxUNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +import transformers +from transformers import CLIPFeatureExtractor, FlaxCLIPTextModel, CLIPTokenizer, set_seed + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." + ) + parser.add_argument( + "--placeholder_token", + type=str, + default=None, + required=True, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_token", type=str, default=None, required=True, help="A token to use as the initializer word." + ) + parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") + parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images; all the images in the train/validation dataset will be resized to this" + " resolution." + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to the resolution" + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=5000, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=True, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--use_auth_token", + action="store_true", + help=( + "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" + " private models)." + ), + ) + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 
+ ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.train_data_dir is None: + raise ValueError("You must specify a train data directory.") + + return args + + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +class TextualInversionDataset(Dataset): + def __init__( + self, + data_root, + tokenizer, + learnable_property="object", # [object, style] + size=512, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + center_crop=False, + ): + self.data_root = data_root + self.tokenizer = tokenizer + self.learnable_property = learnable_property + self.size = size + self.placeholder_token = placeholder_token + self.center_crop = center_crop + self.flip_p = flip_p + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + self.num_images = len(self.image_paths) + self._length = self.num_images + + if set == "train": + self._length = self.num_images * repeats + + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + + self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small + self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + text = random.choice(self.templates).format(placeholder_string) + + example["input_ids"] = self.tokenizer( + text, + padding="max_length", + truncation=True, + 
max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip_transform(image) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + + example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng): + if model.config.vocab_size == new_num_tokens or new_num_tokens is None: + return + model.config.vocab_size = new_num_tokens + + params = model.params + old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"] + old_num_tokens, emb_dim = old_embeddings.shape + + initializer = jax.nn.initializers.normal() + + new_embeddings = initializer(rng, (new_num_tokens, emb_dim)) + new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings) + new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id]) + params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings + + model.params = params + return model + + +def get_params_to_save(params): + return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) + + +def main(): + args = parse_args() + + if args.seed is not None: + set_seed(args.seed) + + if jax.process_index() == 0: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. 
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # Load the tokenizer and add the placeholder token as an additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + # Add the placeholder token to the tokenizer + num_added_tokens = tokenizer.add_tokens(args.placeholder_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + # Convert the initializer_token, placeholder_token to ids + token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + + # Load models and create wrapper for stable diffusion + text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae, state_vae = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet, state_unet = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + + # Create sampling rng + rng = jax.random.PRNGKey(args.seed) + rng, _ = jax.random.split(rng) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder = resize_token_embeddings( + text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng + ) + original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"] + + train_dataset = TextualInversionDataset( + data_root=args.train_data_dir, + tokenizer=tokenizer, + size=args.resolution, + placeholder_token=args.placeholder_token, + repeats=args.repeats, + learnable_property=args.learnable_property, + center_crop=args.center_crop, + set="train", + ) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + input_ids = torch.stack([example["input_ids"] for example in examples]) + + batch = {"pixel_values": pixel_values, "input_ids": input_ids} + batch = {k: v.numpy() for k, v in batch.items()} + + return batch + + total_train_batch_size = args.train_batch_size * jax.local_device_count() + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn + ) + + # Optimization + if args.scale_lr: + args.learning_rate = args.learning_rate * total_train_batch_size + + constant_scheduler = optax.constant_schedule(args.learning_rate) + + optimizer = optax.adamw( + learning_rate=constant_scheduler, + b1=args.adam_beta1, + b2=args.adam_beta2, + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + def create_mask(params, label_fn): + def _map(params, mask, label_fn): + for k in params: + if label_fn(k): + mask[k] = "token_embedding" + else: + if isinstance(params[k], dict): + mask[k] = {} + _map(params[k], mask[k], label_fn) + else: + mask[k] =
"zero" + + mask = {} + _map(params, mask, label_fn) + return mask + + def zero_grads(): + # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491 + def init_fn(_): + return () + + def update_fn(updates, state, params=None): + return jax.tree_util.tree_map(jnp.zeros_like, updates), () + + return optax.GradientTransformation(init_fn, update_fn) + + # Zero out gradients of layers other than the token embedding layer + tx = optax.multi_transform( + {"token_embedding": optimizer, "zero": zero_grads()}, + create_mask(text_encoder.params, lambda s: s == "token_embedding"), + ) + + state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx) + + noise_scheduler = FlaxDDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + ) + + # Initialize our training + train_rngs = jax.random.split(rng, jax.local_device_count()) + + # Define gradient train step fn + def train_step(state, batch, train_rng): + dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) + + def compute_loss(params): + vae_outputs = vae.apply( + {"params": state_vae}, batch["pixel_values"], deterministic=True, method=vae.encode + ) + latents = vae_outputs.latent_dist.sample(sample_rng) + # (NHWC) -> (NCHW) + latents = jnp.transpose(latents, (0, 3, 1, 2)) + latents = latents * 0.18215 + + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal(noise_rng, latents.shape) + bsz = latents.shape[0] + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + noise_scheduler.config.num_train_timesteps, + ) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + encoder_hidden_states = state.apply_fn( + batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True + )[0] + unet_outputs = unet.apply( + {"params": state_unet}, noisy_latents, timesteps, encoder_hidden_states, train=False + ) + noise_pred = unet_outputs.sample + loss = (noise - noise_pred) ** 2 + loss = loss.mean() + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + new_state = state.apply_gradients(grads=grad) + + # Keep the token embeddings fixed except the newly added embeddings for the concept, + # as we only want to optimize the concept embeddings + token_embeds = original_token_embeds.at[placeholder_token_id].set( + new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id] + ) + new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds + + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + return new_state, metrics, new_train_rng + + # Create parallel version of the train and eval step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + + # Train! + num_update_steps_per_epoch = math.ceil(len(train_dataloader)) + + # Scheduler and math around the number of training steps. 
+ if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + logger.info("***** Running training *****") + logger.info(f"  Num examples = {len(train_dataset)}") + logger.info(f"  Num Epochs = {args.num_train_epochs}") + logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}") + logger.info(f"  Total optimization steps = {args.max_train_steps}") + + epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0) + for epoch in epochs: + # ======================== Training ================================ + + train_metrics = [] + + steps_per_epoch = len(train_dataset) // total_train_batch_size + train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) + # train + for batch in train_dataloader: + batch = shard(batch) + state, train_metric, train_rngs = p_train_step(state, batch, train_rngs) + train_metrics.append(train_metric) + + train_step_progress_bar.update(1) + + train_metric = jax_utils.unreplicate(train_metric) + + train_step_progress_bar.close() + epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") + + # Create the pipeline using the trained modules and save it. + if jax.process_index() == 0: + scheduler = FlaxPNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", from_pt=True + ) + pipeline = FlaxStableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ) + + pipeline.save_pretrained( + args.output_dir, + params={ + "text_encoder": get_params_to_save(state.params), + "vae": state_vae, + "unet": state_unet, + "safety_checker": safety_checker.params, + }, + ) + + # Also save the newly trained embeddings + learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][ + "embedding" + ][placeholder_token_id] + learned_embeds_dict = {args.placeholder_token: learned_embeds} + jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + +if __name__ == "__main__": + main() \ No newline at end of file From e319ea0fb85cc8bb63d671dee41ec68e4bb62bb1 Mon Sep 17 00:00:00 2001 From: duongna21 Date: Mon, 17 Oct 2022 22:12:13 +0700 Subject: [PATCH 02/10] make style --- examples/textual_inversion/textual_inversion_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 23b57821e86c..15186ba28e63 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -619,4 +619,4 @@ def compute_loss(params): if __name__ == "__main__": - main() \ No newline at end of file + main() From 04721a3a0efe78959765b5efe2b6934a1d5e9058 Mon Sep 17 00:00:00 2001 From: duongna21 Date: Mon, 17 Oct 2022 22:14:47 +0700 Subject:
[PATCH 03/10] make style --- .../textual_inversion_flax.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 15186ba28e63..37a5b6ebe906 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -7,19 +7,15 @@ from typing import Optional import numpy as np -import jax -import jax.numpy as jnp -from flax import jax_utils -from flax.training import train_state -from flax.training.common_utils import shard - -import optax - import torch import torch.utils.checkpoint from torch.utils.data import Dataset +import jax +import jax.numpy as jnp +import optax import PIL +import transformers from diffusers import ( FlaxAutoencoderKL, FlaxDDPMScheduler, @@ -28,12 +24,15 @@ FlaxUNet2DConditionModel, ) from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker +from flax import jax_utils +from flax.training import train_state +from flax.training.common_utils import shard from huggingface_hub import HfFolder, Repository, whoami from PIL import Image from torchvision import transforms from tqdm.auto import tqdm -import transformers -from transformers import CLIPFeatureExtractor, FlaxCLIPTextModel, CLIPTokenizer, set_seed +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed + logger = logging.getLogger(__name__) From 7b129b15382baee106ef13590d39b247c80e752b Mon Sep 17 00:00:00 2001 From: duongna21 Date: Tue, 25 Oct 2022 15:14:58 +0700 Subject: [PATCH 04/10] replicate vae and unet params --- .../textual_inversion_flax.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 37a5b6ebe906..db22f2527e50 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -391,9 +391,10 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion - text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - vae, state_vae = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") - unet, state_unet = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + # text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + text_encoder = FlaxCLIPTextModel.from_pretrained("duongna/text_encoder_flax") + vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") # Create sampling rng rng = jax.random.PRNGKey(args.seed) @@ -485,12 +486,12 @@ def update_fn(updates, state, params=None): train_rngs = jax.random.split(rng, jax.local_device_count()) # Define gradient train step fn - def train_step(state, batch, train_rng): + def train_step(state, vae_params, unet_params, batch, train_rng): dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) def compute_loss(params): vae_outputs = vae.apply( - {"params": state_vae}, batch["pixel_values"], deterministic=True, method=vae.encode + {"params": vae_params}, 
batch["pixel_values"], deterministic=True, method=vae.encode ) latents = vae_outputs.latent_dist.sample(sample_rng) # (NHWC) -> (NCHW) @@ -511,7 +512,7 @@ def compute_loss(params): batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True )[0] unet_outputs = unet.apply( - {"params": state_unet}, noisy_latents, timesteps, encoder_hidden_states, train=False + {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False ) noise_pred = unet_outputs.sample loss = (noise - noise_pred) ** 2 @@ -540,6 +541,8 @@ def compute_loss(params): # Replicate the train state on each device state = jax_utils.replicate(state) + vae_params = jax_utils.replicate(vae_params) + unet_params = jax_utils.replicate(unet_params) # Train! num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -568,7 +571,7 @@ def compute_loss(params): # train for batch in train_dataloader: batch = shard(batch) - state, train_metric, train_rngs = p_train_step(state, batch, train_rngs) + state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs) train_metrics.append(train_metric) train_step_progress_bar.update(1) @@ -600,8 +603,8 @@ def compute_loss(params): args.output_dir, params={ "text_encoder": get_params_to_save(state.params), - "vae": state_vae, - "unet": state_unet, + "vae": get_params_to_save(vae_params), + "unet": get_params_to_save(unet_params), "safety_checker": safety_checker.params, }, ) @@ -618,4 +621,4 @@ def compute_loss(params): if __name__ == "__main__": - main() + main() \ No newline at end of file From 6f4012f2c230404ca853d481c07ce059e51ad919 Mon Sep 17 00:00:00 2001 From: duongna21 Date: Tue, 25 Oct 2022 15:17:51 +0700 Subject: [PATCH 05/10] make style --- examples/textual_inversion/textual_inversion_flax.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index db22f2527e50..7096c96c7e22 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -560,6 +560,8 @@ def compute_loss(params): logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... 
(1/{args.num_train_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ train_metrics = [] steps_per_epoch = len(train_dataset) // total_train_batch_size train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) # train for batch in train_dataloader: batch = shard(batch) @@ -575,6 +577,10 @@ def compute_loss(params): train_metrics.append(train_metric) train_step_progress_bar.update(1) + global_step += 1 + + if global_step >= args.max_train_steps: + break train_metric = jax_utils.unreplicate(train_metric) @@ -621,4 +627,4 @@ def compute_loss(params): if __name__ == "__main__": - main() \ No newline at end of file + main() From dbba3804183aa11d12ac5503560f912aa893f93d Mon Sep 17 00:00:00 2001 From: duongna21 Date: Tue, 25 Oct 2022 15:18:28 +0700 Subject: [PATCH 06/10] minor --- examples/textual_inversion/textual_inversion_flax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 7096c96c7e22..2a4ad768aefb 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -391,8 +391,7 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion - # text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - text_encoder = FlaxCLIPTextModel.from_pretrained("duongna/text_encoder_flax") + text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") From 99b2d72d6983b1a81e8d6d85d6f22d48bb734ca5 Mon Sep 17 00:00:00 2001 From: duongna21 Date: Wed, 26 Oct 2022 11:01:29 +0700 Subject: [PATCH 07/10] save after end of training --- .../textual_inversion_flax.py | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 2a4ad768aefb..e68940928d30 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -586,43 +586,43 @@ def compute_loss(params): train_step_progress_bar.close() epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") - # Create the pipeline using the trained modules and save it. - if jax.process_index() == 0: - scheduler = FlaxPNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ) - safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", from_pt=True - ) - pipeline = FlaxStableDiffusionPipeline( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), - ) - pipeline.save_pretrained( - args.output_dir, - params={ - "text_encoder": get_params_to_save(state.params), - "vae": get_params_to_save(vae_params), - "unet": get_params_to_save(unet_params), - "safety_checker": safety_checker.params, - }, - ) - # Also save the newly trained embeddings - learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][ - "embedding" - ][placeholder_token_id] - learned_embeds_dict = {args.placeholder_token: learned_embeds} - jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) - if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0: + scheduler = FlaxPNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", from_pt=True + ) + pipeline = FlaxStableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ) - pipeline.save_pretrained( - args.output_dir, - params={ - "text_encoder": get_params_to_save(state.params), - "vae": get_params_to_save(vae_params), - "unet": get_params_to_save(unet_params), - "safety_checker": safety_checker.params, - }, - ) + pipeline.save_pretrained( + args.output_dir, + params={ + "text_encoder": get_params_to_save(state.params), + "vae": get_params_to_save(vae_params), + "unet": get_params_to_save(unet_params), + "safety_checker": safety_checker.params, + }, + ) - # Also save the newly trained embeddings - learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][ - "embedding" - ][placeholder_token_id] - learned_embeds_dict = {args.placeholder_token: learned_embeds} - jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) + # Also save the newly trained embeddings + learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][ + "embedding" + ][placeholder_token_id] + learned_embeds_dict = {args.placeholder_token: learned_embeds} + jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) - if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) if __name__ == "__main__": From fd22de82f153d7ce02a6e46953238eede9963640 Mon Sep 17 00:00:00 2001 From: duongna21 Date: Wed, 26 Oct 2022 11:03:00 +0700 Subject: [PATCH 08/10] style --- examples/textual_inversion/textual_inversion_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index e68940928d30..be2b7ffb5490 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -615,9 +615,9 @@ def compute_loss(params): ) # Also save the newly trained embeddings - learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][ - "embedding" - ][placeholder_token_id] + learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][ + placeholder_token_id + ] learned_embeds_dict = {args.placeholder_token: learned_embeds} jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict) From 98a21a42956783693d222ebfe1d7c3630462dcf9 Mon Sep 17 00:00:00 2001 From: "Duong A. 
Nguyen" <38061659+duongna21@users.noreply.github.com> Date: Thu, 27 Oct 2022 00:28:34 +0700 Subject: [PATCH 09/10] Temporary fix Co-authored-by: Suraj Patil --- examples/textual_inversion/textual_inversion_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index be2b7ffb5490..84ff97c39a96 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -391,7 +391,7 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion - text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") From f9aba406d0c8ece53427c581fbac7eb8019dc410 Mon Sep 17 00:00:00 2001 From: "Duong A. Nguyen" <38061659+duongna21@users.noreply.github.com> Date: Thu, 27 Oct 2022 01:07:38 +0700 Subject: [PATCH 10/10] Add Flax instruction --- examples/textual_inversion/README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index 05d8ffb8c9f2..4f4b01e270d6 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -68,6 +68,24 @@ accelerate launch textual_inversion.py \ A full training run takes ~1 hour on one V100 GPU. +If you want to speed it up even more, Flax implementation is available: + +```bash +export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" +export DATA_DIR="path-to-dir-containing-images" + +python textual_inversion_flax.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATA_DIR \ + --learnable_property="object" \ + --placeholder_token="" --initializer_token="toy" \ + --resolution=512 \ + --train_batch_size=1 \ + --max_train_steps=3000 \ + --learning_rate=5.0e-04 --scale_lr \ + --output_dir="textual_inversion_cat" +``` +It should be at least 70% faster than the PyTorch script with the same configuration. ### Inference