
Commit d87cc15

zetyquickly and pcuenca authored
expose polynomial:power and cosine_with_restarts:num_cycles params (#1737)
* expose polynomial:power and cosine_with_restarts:num_cycles using get_scheduler func, add it to train_dreambooth.py
* fix formatting
* fix style
* Update src/diffusers/optimization.py

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent e29dc97 commit d87cc15

File tree

examples/dreambooth/train_dreambooth.py
src/diffusers/optimization.py

2 files changed: +30 −3 lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 9 additions & 0 deletions
@@ -204,6 +204,13 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_num_cycles",
+        type=int,
+        default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -588,6 +595,8 @@ def main(args):
         optimizer=optimizer,
         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
     )

     if args.train_text_encoder:
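Usage note (not part of the commit): with these flags exposed, a DreamBooth run can pick the cosine-with-restarts schedule and its cycle count directly from the command line. A minimal, illustrative sketch, assuming the example's usual required arguments; all paths, model id, and values below are placeholders:

# Illustrative only; paths, model id, and values are placeholders.
accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --instance_data_dir="./instance_images" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="./dreambooth-out" \
  --max_train_steps=2000 \
  --lr_scheduler="cosine_with_restarts" \
  --lr_warmup_steps=500 \
  --lr_num_cycles=3

Passing --lr_scheduler="polynomial" together with --lr_power=2.0 would exercise the other newly exposed parameter instead.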

src/diffusers/optimization.py

Lines changed: 21 additions & 3 deletions
@@ -121,9 +121,9 @@ def get_cosine_schedule_with_warmup(
             The number of steps for the warmup phase.
         num_training_steps (`int`):
             The total number of training steps.
-        num_cycles (`float`, *optional*, defaults to 0.5):
-            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
-            following a half-cosine).
+        num_periods (`float`, *optional*, defaults to 0.5):
+            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+            value to 0 following a half-cosine).
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.

@@ -240,6 +240,8 @@ def get_scheduler(
     optimizer: Optimizer,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
 ):
     """
     Unified API to get any scheduler from its name.
@@ -255,6 +257,12 @@
         num_training_steps (`int``, *optional*):
             The number of training steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*):
+            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor. See `POLYNOMIAL` scheduler
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
@@ -272,4 +280,14 @@
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+        )
+
     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
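Usage note (not part of the commit): after this change, the two parameters flow straight through get_scheduler without calling the per-scheduler functions directly. A minimal sketch, assuming a placeholder model and optimizer; the warmup and step counts are illustrative only:

import torch
from diffusers.optimization import get_scheduler

# Placeholder model/optimizer for illustration only.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# cosine_with_restarts: num_cycles sets the number of hard lr restarts.
lr_scheduler = get_scheduler(
    "cosine_with_restarts",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=10_000,
    num_cycles=3,
)

# polynomial: power sets the decay exponent.
lr_scheduler = get_scheduler(
    "polynomial",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=10_000,
    power=2.0,
)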
