@@ -121,9 +121,9 @@ def get_cosine_schedule_with_warmup(
121121 The number of steps for the warmup phase.
122122 num_training_steps (`int`):
123123 The total number of training steps.
124- num_cycles (`float`, *optional*, defaults to 0.5):
125- The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
126- following a half-cosine).
124+ num_periods (`float`, *optional*, defaults to 0.5):
125+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
126+ value to 0 following a half-cosine).
127127 last_epoch (`int`, *optional*, defaults to -1):
128128 The index of the last epoch when resuming training.
129129
@@ -240,6 +240,8 @@ def get_scheduler(
240240 optimizer : Optimizer ,
241241 num_warmup_steps : Optional [int ] = None ,
242242 num_training_steps : Optional [int ] = None ,
243+ num_cycles : int = 1 ,
244+ power : float = 1.0 ,
243245):
244246 """
245247 Unified API to get any scheduler from its name.
@@ -255,6 +257,12 @@ def get_scheduler(
255257 num_training_steps (`int`, *optional*):
256258 The number of training steps to do. This is not required by all schedulers (hence the argument being
257259 optional), the function will raise an error if it's unset and the scheduler type requires it.
260+ num_cycles (`int`, *optional*, defaults to 1):
261+ The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
262+ power (`float`, *optional*, defaults to 1.0):
263+ Power factor. See the `POLYNOMIAL` scheduler.
264+ last_epoch (`int`, *optional*, defaults to -1):
265+ The index of the last epoch when resuming training.
258266 """
259267 name = SchedulerType (name )
260268 schedule_func = TYPE_TO_SCHEDULER_FUNCTION [name ]
@@ -272,4 +280,14 @@ def get_scheduler(
272280 if num_training_steps is None :
273281 raise ValueError (f"{ name } requires `num_training_steps`, please provide that argument." )
274282
283+ if name == SchedulerType .COSINE_WITH_RESTARTS :
284+ return schedule_func (
285+ optimizer , num_warmup_steps = num_warmup_steps , num_training_steps = num_training_steps , num_cycles = num_cycles
286+ )
287+
288+ if name == SchedulerType .POLYNOMIAL :
289+ return schedule_func (
290+ optimizer , num_warmup_steps = num_warmup_steps , num_training_steps = num_training_steps , power = power
291+ )
292+
275293 return schedule_func (optimizer , num_warmup_steps = num_warmup_steps , num_training_steps = num_training_steps )
0 commit comments