 import torch
 import torch.nn.functional as F

-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel


 def main(args):
-    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
-        kwargs_handlers=[ddp_unused_params],
     )

-    model = UNetModel(
-        attn_resolutions=(16,),
-        ch=128,
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
+    model = UNetUnconditionalModel(
+        image_size=args.resolution,
+        in_channels=3,
+        out_channels=3,
         num_res_blocks=2,
-        resamp_with_conv=True,
-        resolution=args.resolution,
+        block_channels=(128, 128, 256, 256, 512, 512),
+        down_blocks=(
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResAttnDownBlock2D",
+            "UNetResDownBlock2D",
+        ),
+        up_blocks=(
+            "UNetResUpBlock2D",
+            "UNetResAttnUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+        ),
     )
-    noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
+    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -92,65 +104,44 @@ def transforms(examples):
         run = os.path.split(__file__)[-1].split(".")[0]
         accelerator.init_trackers(run)

-    # Train!
-    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size() if is_distributed else 1
-    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
-    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
-    logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
-    logger.info(f"  Num Epochs = {args.num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {max_steps}")
-
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
-            noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
+            # Sample noise that we'll add to the images
+            noise = torch.randn(clean_images.shape).to(clean_images.device)
             bsz = clean_images.shape[0]
-            timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
+            # Sample a random timestep for each image
+            timesteps = torch.randint(
+                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
+            ).long()

-            # add noise onto the clean images according to the noise magnitude at each timestep
+            # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
-
-            if step % args.gradient_accumulation_steps != 0:
-                with accelerator.no_sync(model):
-                    output = model(noisy_images, timesteps)
-                    # predict the noise residual
-                    loss = F.mse_loss(output, noise_samples)
-                    loss = loss / args.gradient_accumulation_steps
-                    accelerator.backward(loss)
-            else:
-                output = model(noisy_images, timesteps)
-                # predict the noise residual
-                loss = F.mse_loss(output, noise_samples)
-                loss = loss / args.gradient_accumulation_steps
+            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+            with accelerator.accumulate(model):
+                # Predict the noise residual
+                noise_pred = model(noisy_images, timesteps)["sample"]
+                loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
-                ema_model.step(model)
+                if args.use_ema:
+                    ema_model.step(model)
                 optimizer.zero_grad()
+
             progress_bar.update(1)
-            progress_bar.set_postfix(
-                loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
-            )
-            accelerator.log(
-                {
-                    "train_loss": loss.detach().item(),
-                    "epoch": epoch,
-                    "ema_decay": ema_model.decay,
-                    "step": global_step,
-                },
-                step=global_step,
-            )
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+            if args.use_ema:
+                logs["ema_decay"] = ema_model.decay
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
             global_step += 1
         progress_bar.close()

@@ -159,26 +150,27 @@ def transforms(examples):
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDIMPipeline(
-                    unet=accelerator.unwrap_model(ema_model.averaged_model),
-                    noise_scheduler=noise_scheduler,
+                pipeline = DDPMPipeline(
+                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
+                    scheduler=noise_scheduler,
                 )

                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
+                images = pipeline(generator=generator, batch_size=args.eval_batch_size)

             # denormalize the images and save to tensorboard
             images_processed = (images.cpu() + 1.0) * 127.5
             images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()

             accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)

-            # save the model
-            if args.push_to_hub:
-                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
-            else:
-                pipeline.save_pretrained(args.output_dir)
+            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+                # save the model
+                if args.push_to_hub:
+                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                else:
+                    pipeline.save_pretrained(args.output_dir)
         accelerator.wait_for_everyone()

     accelerator.end_training()
@@ -188,12 +180,13 @@ def transforms(examples):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model")
+    parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
     parser.add_argument("--overwrite_output_dir", action="store_true")
     parser.add_argument("--resolution", type=int, default=64)
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_model_epochs", type=int, default=5)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--learning_rate", type=float, default=1e-4)
     parser.add_argument("--lr_scheduler", type=str, default="cosine")
@@ -202,6 +195,7 @@ def transforms(examples):
202195 parser .add_argument ("--adam_beta2" , type = float , default = 0.999 )
203196 parser .add_argument ("--adam_weight_decay" , type = float , default = 1e-6 )
204197 parser .add_argument ("--adam_epsilon" , type = float , default = 1e-3 )
198+ parser .add_argument ("--use_ema" , action = "store_true" , default = True )
205199 parser .add_argument ("--ema_inv_gamma" , type = float , default = 1.0 )
206200 parser .add_argument ("--ema_power" , type = float , default = 3 / 4 )
207201 parser .add_argument ("--ema_max_decay" , type = float , default = 0.9999 )