Commit 838841b

reformatted

1 parent f9a1031 commit 838841b

File tree

1 file changed (+97, -86 lines)

examples/dreambooth/train_dreambooth_colossalai.py

Lines changed: 97 additions & 86 deletions
@@ -1,27 +1,14 @@
 import argparse
 import hashlib
-import itertools
 import math
 import os
 from pathlib import Path
 from typing import Optional
 
-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from copy import deepcopy
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from packaging import version
-from PIL import Image
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
@@ -30,10 +17,20 @@
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor, ProcessGroup
+from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+
 disable_existing_loggers()
 logger = get_dist_logger()
 
@@ -118,8 +115,10 @@ def parse_args(input_args=None):
         "--num_class_images",
         type=int,
         default=100,
-        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
-              " class_data_dir, additional images will be sampled with class_prompt."),
+        help=(
+            "Minimal class images for prior preservation loss. If there are not enough images already present in"
+            " class_data_dir, additional images will be sampled with class_prompt."
+        ),
     )
     parser.add_argument(
         "--output_dir",
@@ -132,23 +131,26 @@ def parse_args(input_args=None):
         "--resolution",
         type=int,
         default=512,
-        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
-              " resolution"),
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
     )
     parser.add_argument(
         "--placement",
         type=str,
-        default='cpu',
+        default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
-    parser.add_argument("--train_batch_size",
-                        type=int,
-                        default=4,
-                        help="Batch size (per device) for the training dataloader.")
-    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
+    parser.add_argument(
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+    )
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
         "--max_train_steps",
@@ -184,16 +186,17 @@ def parse_args(input_args=None):
         "--lr_scheduler",
         type=str,
         default="constant",
-        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-              ' "constant", "constant_with_warmup"]'),
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
-    parser.add_argument("--lr_warmup_steps",
-                        type=int,
-                        default=500,
-                        help="Number of steps for the warmup in the lr scheduler.")
-    parser.add_argument("--use_8bit_adam",
-                        action="store_true",
-                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
 
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -208,8 +211,10 @@ def parse_args(input_args=None):
         "--logging_dir",
         type=str,
         default="logs",
-        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
     )
     parser.add_argument(
         "--mixed_precision",
@@ -219,7 +224,8 @@ def parse_args(input_args=None):
         help=(
             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
             " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
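For reference, a minimal sketch of exercising the parser above with the flags touched in this commit. The values are illustrative, the list argument relies on the `parse_args(input_args=None)` signature shown in the hunk headers, and any required data flags not visible in this diff would still need to be supplied:

    # Hypothetical invocation sketch; flag names come from the parser above and
    # main() below, values are examples only.
    args = parse_args(
        [
            "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
            "--output_dir", "./dreambooth-out",
            "--resolution", "512",
            "--placement", "cpu",
            "--train_batch_size", "4",
            "--lr_scheduler", "constant",
            "--lr_warmup_steps", "500",
        ]
    )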
@@ -285,12 +291,14 @@ def __init__(
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose([
-            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ])
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
         return self._length
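As a quick sanity check of the reformatted transform pipeline, a usage sketch assuming the PIL and torchvision imports at the top of the file; the file name is a placeholder, and size/center_crop mirror the constructor arguments above:

    size, center_crop = 512, False
    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # maps [0, 1] pixels to [-1, 1]
        ]
    )
    example = image_transforms(Image.open("instance_image.jpg").convert("RGB"))
    # example: float tensor of shape (3, 512, 512), values in [-1, 1]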
@@ -355,20 +363,24 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
     cai_version = colossalai.__version__
     if version.parse(cai_version) > version.parse("0.1.10"):
         from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placememt_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+
+        model = GeminiDDP(
+            model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+        )
+
+    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse(
+        "0.1.9"
+    ):
         from colossalai.gemini import ChunkManager, GeminiManager
+
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
         gemini_manager = GeminiManager(placememt_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
+        chunk_manager = ChunkManager(
+            chunk_size,
+            pg,
+            enable_distributed_storage=True,
+            init_device=GeminiManager.get_default_device(placememt_policy),
+        )
         model = ZeroDDP(model, gemini_manager)
     else:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
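A minimal sketch of how this helper is applied; it mirrors the calls in main() below, and the GeminiDDP branch is taken on ColossalAI versions newer than 0.1.10:

    # Initialize weights under ColoInitContext, then wrap the module for
    # ZeRO/Gemini data parallelism (placement policy as parsed above).
    with ColoInitContext():
        unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=False
        )
    pg = ProcessGroup()
    unet = gemini_zero_dpp(unet, pg, args.placement)  # e.g. the "cpu" default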
@@ -383,7 +395,7 @@ def main(args):
         "gradient_accumulation_steps": args.gradient_accumulation_steps,
         "clip_grad_norm": args.max_grad_norm,
     }
-
+
     colossalai.launch_from_torch(config=config)
     pg = ProcessGroup()
 
@@ -414,9 +426,11 @@ def main(args):
 
         pipeline.to(get_current_device())
 
-        for example in tqdm(sample_dataloader,
-                            desc="Generating class images",
-                            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0):
+        for example in tqdm(
+            sample_dataloader,
+            desc="Generating class images",
+            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+        ):
             images = pipeline(example["prompt"]).images
 
             for i, image in enumerate(images):
@@ -466,23 +480,24 @@ def main(args):
 
     logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
 
-    text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="text_encoder",
-                                                    revision=args.revision,)
+    text_encoder = text_encoder_cls.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
 
     logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
-                                        subfolder="vae",
-                                        revision=args.revision,)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
 
-
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
     with ColoInitContext():
-        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
-                                                    subfolder="unet",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
-
+        unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
+        )
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -491,7 +506,7 @@ def main(args):
     unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
 
     unet = gemini_zero_dpp(unet, pg, args.placement)
 
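Dropping the parentheses does not change the arithmetic; a worked example with illustrative values (the base learning-rate default is not shown in this diff):

    learning_rate = 5e-6  # illustrative base value
    gradient_accumulation_steps = 1
    train_batch_size = 4
    learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size * 2
    # learning_rate is now 4e-05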
@@ -527,9 +542,7 @@ def collate_fn(examples):
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad(
-            {
-                "input_ids": input_ids
-            },
+            {"input_ids": input_ids},
             padding="max_length",
             max_length=tokenizer.model_max_length,
             return_tensors="pt",
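A self-contained sketch of what the condensed tokenizer.pad call produces, assuming a CLIP-style tokenizer loaded via the AutoTokenizer import at the top of the file; the checkpoint name and token ids are illustrative:

    tokenizer = AutoTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    input_ids = [[49406, 320, 1125, 49407], [49406, 1125, 49407]]  # ragged per-example id lists
    padded = tokenizer.pad(
        {"input_ids": input_ids},
        padding="max_length",
        max_length=tokenizer.model_max_length,  # 77 for CLIP tokenizers
        return_tensors="pt",
    )
    # padded.input_ids: LongTensor of shape (2, 77)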
@@ -541,11 +554,9 @@ def collate_fn(examples):
         }
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(train_dataset,
-                                                   batch_size=args.train_batch_size,
-                                                   shuffle=True,
-                                                   collate_fn=collate_fn,
-                                                   num_workers=1)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+    )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
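And a sketch of consuming the reformatted dataloader, assuming the batch dict built by collate_fn above carries "input_ids" and "pixel_values" keys:

    for step, batch in enumerate(train_dataloader):
        pixel_values = batch["pixel_values"]  # (train_batch_size, 3, resolution, resolution), in [-1, 1]
        input_ids = batch["input_ids"]  # (train_batch_size, tokenizer.model_max_length)
        # ...forward through vae/text_encoder/unet as in the training loop below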
@@ -662,8 +673,8 @@ def collate_fn(examples):
                 global_step += 1
                 logs = {
                     "loss": loss.detach().item(),
-                    "lr": optimizer.param_groups[0]['lr']
-                } #lr_scheduler.get_last_lr()[0]}
+                    "lr": optimizer.param_groups[0]["lr"],
+                }  # lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
                 if global_step % args.save_steps == 0:
@@ -681,15 +692,15 @@ def collate_fn(examples):
                 break
 
     torch.cuda.synchronize()
-    unet=convert_to_torch_module(unet)
-
+    unet = convert_to_torch_module(unet)
+
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,
             revision=args.revision,
         )
-
+
         pipeline.save_pretrained(args.output_dir)
         logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
 
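Once saved, the checkpoint can presumably be reloaded with the same DiffusionPipeline API the script already uses for sampling; an inference sketch in which the prompt, dtype, and step count are illustrative:

    import torch
    from diffusers import DiffusionPipeline

    pipeline = DiffusionPipeline.from_pretrained(args.output_dir, torch_dtype=torch.float16).to("cuda")
    image = pipeline("a photo of sks dog in a bucket", num_inference_steps=50).images[0]
    image.save("dreambooth_sample.png")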