 import argparse
 import hashlib
-import itertools
 import math
 import os
 from pathlib import Path
 from typing import Optional
 
-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from copy import deepcopy
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from packaging import version
-from PIL import Image
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor, ProcessGroup
+from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+
 disable_existing_loggers()
 logger = get_dist_logger()
 
@@ -118,8 +115,10 @@ def parse_args(input_args=None):
118115 "--num_class_images" ,
119116 type = int ,
120117 default = 100 ,
121- help = ("Minimal class images for prior preservation loss. If there are not enough images already present in"
122- " class_data_dir, additional images will be sampled with class_prompt." ),
118+ help = (
119+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
120+ " class_data_dir, additional images will be sampled with class_prompt."
121+ ),
123122 )
124123 parser .add_argument (
125124 "--output_dir" ,
@@ -132,23 +131,26 @@ def parse_args(input_args=None):
132131 "--resolution" ,
133132 type = int ,
134133 default = 512 ,
135- help = ("The resolution for input images, all the images in the train/validation dataset will be resized to this"
136- " resolution" ),
134+ help = (
135+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
136+ " resolution"
137+ ),
137138 )
138139 parser .add_argument (
139140 "--placement" ,
140141 type = str ,
141- default = ' cpu' ,
142+ default = " cpu" ,
142143 help = "Placement Policy for Gemini. Valid when using colossalai as dist plan." ,
143144 )
144- parser .add_argument ("--center_crop" ,
145- action = "store_true" ,
146- help = "Whether to center crop images before resizing to resolution" )
147- parser .add_argument ("--train_batch_size" ,
148- type = int ,
149- default = 4 ,
150- help = "Batch size (per device) for the training dataloader." )
151- parser .add_argument ("--sample_batch_size" , type = int , default = 4 , help = "Batch size (per device) for sampling images." )
145+ parser .add_argument (
146+ "--center_crop" , action = "store_true" , help = "Whether to center crop images before resizing to resolution"
147+ )
148+ parser .add_argument (
149+ "--train_batch_size" , type = int , default = 4 , help = "Batch size (per device) for the training dataloader."
150+ )
151+ parser .add_argument (
152+ "--sample_batch_size" , type = int , default = 4 , help = "Batch size (per device) for sampling images."
153+ )
152154 parser .add_argument ("--num_train_epochs" , type = int , default = 1 )
153155 parser .add_argument (
154156 "--max_train_steps" ,
@@ -184,16 +186,17 @@ def parse_args(input_args=None):
184186 "--lr_scheduler" ,
185187 type = str ,
186188 default = "constant" ,
187- help = ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
188- ' "constant", "constant_with_warmup"]' ),
189+ help = (
190+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
191+ ' "constant", "constant_with_warmup"]'
192+ ),
193+ )
194+ parser .add_argument (
195+ "--lr_warmup_steps" , type = int , default = 500 , help = "Number of steps for the warmup in the lr scheduler."
196+ )
197+ parser .add_argument (
198+ "--use_8bit_adam" , action = "store_true" , help = "Whether or not to use 8-bit Adam from bitsandbytes."
189199 )
190- parser .add_argument ("--lr_warmup_steps" ,
191- type = int ,
192- default = 500 ,
193- help = "Number of steps for the warmup in the lr scheduler." )
194- parser .add_argument ("--use_8bit_adam" ,
195- action = "store_true" ,
196- help = "Whether or not to use 8-bit Adam from bitsandbytes." )
197200
198201 parser .add_argument ("--max_grad_norm" , default = 1.0 , type = float , help = "Max gradient norm." )
199202 parser .add_argument ("--push_to_hub" , action = "store_true" , help = "Whether or not to push the model to the Hub." )
@@ -208,8 +211,10 @@ def parse_args(input_args=None):
208211 "--logging_dir" ,
209212 type = str ,
210213 default = "logs" ,
211- help = ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
212- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ),
214+ help = (
215+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
216+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
217+ ),
213218 )
214219 parser .add_argument (
215220 "--mixed_precision" ,
@@ -219,7 +224,8 @@ def parse_args(input_args=None):
         help=(
             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
             " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
@@ -285,12 +291,14 @@ def __init__(
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose([
-            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
         return self._length
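As a quick aside, the transform stack above resizes, crops, converts each image to a tensor in [0, 1], and then `Normalize([0.5], [0.5])` shifts it to [-1, 1], the input range the Stable Diffusion VAE expects. A tiny self-contained check of that normalization (illustrative only):

```python
import torch
from torchvision import transforms

norm = transforms.Normalize([0.5], [0.5])
x = torch.tensor([[[0.0, 0.5, 1.0]]])  # a 1x1x3 single-channel "image" already scaled to [0, 1]
print(norm(x))                         # tensor([[[-1.,  0.,  1.]]])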
@@ -355,20 +363,24 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
     cai_version = colossalai.__version__
     if version.parse(cai_version) > version.parse("0.1.10"):
         from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placememt_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+
+        model = GeminiDDP(
+            model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+        )
+
+    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse(
+        "0.1.9"
+    ):
         from colossalai.gemini import ChunkManager, GeminiManager
+
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
         gemini_manager = GeminiManager(placememt_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
+        chunk_manager = ChunkManager(
+            chunk_size,
+            pg,
+            enable_distributed_storage=True,
+            init_device=GeminiManager.get_default_device(placememt_policy),
+        )
         model = ZeroDDP(model, gemini_manager)
     else:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
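Two pre-existing issues in this helper are left untouched by the reformatting: in the 0.1.9-0.1.10 branch, `chunk_manager` is passed to `GeminiManager` before it is assigned (a `NameError` if that path ever runs), and the `else` branch raises `NotImplemented` (a constant) where the `NotImplementedError` exception is presumably intended. A hedged sketch of that branch with the ordering fixed, assuming the intent is simply to build the `ChunkManager` first (not part of the actual patch):

```python
from colossalai.gemini import ChunkManager, GeminiManager

# Build the chunk manager first, then hand it to GeminiManager.
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(
    chunk_size,
    pg,
    enable_distributed_storage=True,
    init_device=GeminiManager.get_default_device(placememt_policy),
)
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
```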
@@ -383,7 +395,7 @@ def main(args):
383395 "gradient_accumulation_steps" : args .gradient_accumulation_steps ,
384396 "clip_grad_norm" : args .max_grad_norm ,
385397 }
386-
398+
387399 colossalai .launch_from_torch (config = config )
388400 pg = ProcessGroup ()
389401
@@ -414,9 +426,11 @@ def main(args):
 
             pipeline.to(get_current_device())
 
-            for example in tqdm(sample_dataloader,
-                                desc="Generating class images",
-                                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0):
+            for example in tqdm(
+                sample_dataloader,
+                desc="Generating class images",
+                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+            ):
                 images = pipeline(example["prompt"]).images
 
                 for i, image in enumerate(images):
@@ -466,23 +480,24 @@ def main(args):
 
     logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
 
-    text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
-                                                     subfolder="text_encoder",
-                                                     revision=args.revision,)
+    text_encoder = text_encoder_cls.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
 
     logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
-                                        subfolder="vae",
-                                        revision=args.revision,)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
 
-
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
     with ColoInitContext():
-        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="unet",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
-
+        unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
+        )
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -491,7 +506,7 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
 
     unet = gemini_zero_dpp(unet, pg, args.placement)
 
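The hunks above follow the usual ColossalAI Gemini setup order: construct the UNet under `ColoInitContext` so its parameters are created as ColoTensors, wrap it with `gemini_zero_dpp`, then build a Gemini-aware optimizer on the wrapped module (`GeminiAdamOptimizer` is imported at the top of the script). A minimal sketch of that flow; the `GeminiAdamOptimizer` call and its keyword arguments are an assumption based on ColossalAI 0.1.x examples, not something shown in this diff:

```python
# Sketch only: assumes the imports and the `args`/`pg` objects defined earlier in the script.
with ColoInitContext():
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
    )

unet = gemini_zero_dpp(unet, pg, args.placement)              # Gemini/ZeRO-wrapped module
optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate)  # assumed signature (ColossalAI 0.1.x)
```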
@@ -527,9 +542,7 @@ def collate_fn(examples):
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad(
-            {
-                "input_ids": input_ids
-            },
+            {"input_ids": input_ids},
             padding="max_length",
             max_length=tokenizer.model_max_length,
             return_tensors="pt",
@@ -541,11 +554,9 @@ def collate_fn(examples):
         }
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(train_dataset,
-                                                   batch_size=args.train_batch_size,
-                                                   shuffle=True,
-                                                   collate_fn=collate_fn,
-                                                   num_workers=1)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+    )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -662,8 +673,8 @@ def collate_fn(examples):
             global_step += 1
             logs = {
                 "loss": loss.detach().item(),
-                "lr": optimizer.param_groups[0]['lr']
-            }    # lr_scheduler.get_last_lr()[0]}
+                "lr": optimizer.param_groups[0]["lr"],
+            }  # lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
             if global_step % args.save_steps == 0:
@@ -681,15 +692,15 @@ def collate_fn(examples):
                 break
 
     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
-
+    unet = convert_to_torch_module(unet)
+
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,
             revision=args.revision,
         )
-
+
         pipeline.save_pretrained(args.output_dir)
         logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
 