@@ -149,6 +149,9 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -401,7 +404,19 @@ def main():
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

-    optimizer = torch.optim.AdamW(
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    optimizer = optimizer_class(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
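For reference, a minimal, self-contained sketch of the optional 8-bit Adam pattern this diff introduces. The argument parser and the tiny placeholder model below are illustrative stand-ins, not the training script's actual objects; only the --use_8bit_adam flag, the bitsandbytes import fallback, and bnb.optim.AdamW8bit come from the diff itself.

import argparse

import torch

parser = argparse.ArgumentParser()
# Same opt-in flag the diff adds to parse_args().
parser.add_argument("--use_8bit_adam", action="store_true")
parser.add_argument("--learning_rate", type=float, default=1e-4)
args = parser.parse_args()

if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb  # optional dependency
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

# Placeholder module standing in for the UNet; any nn.Module's parameters work here.
model = torch.nn.Linear(8, 8)
optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)

Passing --use_8bit_adam switches the run to bitsandbytes' 8-bit AdamW, which keeps optimizer state in 8-bit precision and reduces optimizer memory; omitting the flag leaves the stock torch.optim.AdamW path unchanged.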