 
 Hacked together by Ross Wightman
 """
+import logging
 import numbers
 from typing import List, Mapping, Optional, Sequence, Tuple, Union
 
 from .adamw import adamw
 from .nadamw import nadamw
 
+_logger = logging.getLogger(__name__)
+
 # Constants from Keller Jordan's Muon
 MUON_EPS = 1e-7
 DEFAULT_NS_STEPS = 5
@@ -95,7 +98,7 @@ def zeropower_via_newtonschulz(
         steps: Number of Newton-Schulz iterations
         coefficients: Coefficients (a, b, c) for the iteration
         eps: Numerical stability epsilon for norm
-        safety_factor: Multiplicative safety factor for norm (1.01 is common safety value)
+        safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants)
         dtype: Computation dtype
 
     Returns:
@@ -161,6 +164,78 @@ def get_lr_scale(
     assert False, f'Invalid scaling function "{adjust_lr_fn}"'
 
 
+def _is_suitable_for_muon(
+        param: torch.Tensor,
+        min_dim_size: int = 4,
+        max_aspect_ratio: float = 128.,
+        return_reason: bool = False,
+) -> Union[bool, Tuple[bool, str]]:
+    """Check if a parameter is suitable for Muon optimization.
+
+    Args:
+        param: Parameter tensor
+        min_dim_size: Minimum size for non-unit dimensions
+        max_aspect_ratio: Maximum allowed aspect ratio
+        return_reason: If True, return (bool, reason_string), else just bool (faster)
+
+    Returns:
+        If return_reason=False: bool indicating suitability
+        If return_reason=True: Tuple of (is_suitable, reason_string)
+
+    Examples:
+        (64, 128) -> True (or (True, "ok") if return_reason=True)
+        (96, 3, 4, 4) -> True - will be flattened to (96, 48)
+        (4, 2048) -> False - extreme aspect ratio
+        (64,) -> False - insufficient dims
+        (1, 196, 768) -> False - leading unit dims
+
+    NOTE: these rules were created to balance complexity with covering common timm model cases.
+    Please let me know if there are non-optimal cases that you run into.
+    """
+
+    s = param.shape
+    # Must have at least 2 non-unit dimensions
+    if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2:
+        return (False, "insufficient_dims") if return_reason else False
+
+    # Unit dimension in first two positions indicates:
+    # - Position embeddings (1, seq, dim)
+    # - Depthwise convs (out, 1, h, w)
+    # - Other degenerate cases possibly not caught by first rule
+    if s[0] == 1 or s[1] == 1:
+        return (False, "leading_unit_dims") if return_reason else False
+
+    if param.ndim >= 3:
+        # For 3D+ tensors, check what dimensions will be AFTER flattening
+        # since that's what gets passed to Newton-Schulz iteration
+        # Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod)
+        out_ch = s[0]
+        in_ch_with_spatial = 1
+        for d in s[1:]:
+            in_ch_with_spatial *= d
+        check_dims = (out_ch, in_ch_with_spatial)
+    else:
+        # For 2D tensors, check as-is
+        check_dims = s
+
+    # Both dims should be >= minimum size
+    min_size = min(check_dims)
+    if min_size < min_dim_size:
+        if return_reason:
+            return False, f"min_dim_too_small:{min_size}"
+        return False
+
+    # Aspect ratio shouldn't be too extreme
+    max_size = max(check_dims)
+    aspect_ratio = max_size / min_size
+    if aspect_ratio > max_aspect_ratio:
+        if return_reason:
+            return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}"
+        return False
+
+    return (True, "ok") if return_reason else True
+
+
 def reshape_for_muon(
         tensor: torch.Tensor,
         mode: str = "flatten",
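
A quick way to sanity-check the routing rules added above is to call the helper directly on a few representative parameter shapes. This is a minimal sketch, not part of the diff: it assumes the helper stays private in `timm.optim.muon` and is importable from there, and the expected outputs simply mirror the docstring examples.

import torch
from timm.optim.muon import _is_suitable_for_muon  # assumed module path

# Shapes taken from the docstring examples above
for shape in [(64, 128), (96, 3, 4, 4), (4, 2048), (64,), (1, 196, 768)]:
    ok, reason = _is_suitable_for_muon(torch.empty(shape), return_reason=True)
    print(shape, ok, reason)
# Expected reasons: ok, ok, extreme_aspect_ratio:512.0, insufficient_dims, leading_unit_dims
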
@@ -320,6 +395,7 @@ def __init__(
             normalize_spatial: bool = True,
             adamw_lr: Optional[float] = None,
             betas: Tuple[float, float] = (0.9, 0.95),
+            verbose: bool = False,
     ):
         """ Create Muon optimizer.
         Args:
@@ -337,6 +413,7 @@ def __init__(
             normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode
             adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified
             betas: AdamW beta coefficients
+            verbose: Log parameter routing decisions (Muon vs AdamW)
 
         Example:
             ```python
@@ -375,6 +452,7 @@ def __init__(
             normalize_spatial=normalize_spatial,
             adamw_lr=adamw_lr if adamw_lr is not None else lr,
             betas=betas,
+            verbose=verbose,
         )
         super().__init__(params, defaults)
 
@@ -386,6 +464,13 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
 
+        verbose = self.defaults.get("verbose", False)
+
+        # Tracking for logging (populated on first encounter of each param)
+        muon_count = 0
+        adamw_count = 0
+        routing_reasons = {} if verbose else None
+
         for group in self.param_groups:
             # Separate params into Muon and AdamW groups
             muon_params = []
@@ -405,33 +490,58 @@ def step(self, closure=None):
                 if p.grad.is_sparse:
                     raise RuntimeError("Muon does not support sparse gradients")
 
-                # Determine if we should use Muon or AdamW fallback
-                force_adamw = p.ndim < 2 or group.get("simple", False)
-
                 state = self.state[p]
 
-                if force_adamw:
+                # Determine routing on first encounter (cache in state)
+                if "use_muon" not in state:
+                    # Check explicit simple flag first
+                    reason = None
+                    if group.get("simple", False):
+                        state["use_muon"] = False
+                        if verbose:
+                            reason = "simple_flag"
+                    else:
+                        # Check shape suitability
+                        if verbose:
+                            suitable, reason = _is_suitable_for_muon(p, return_reason=True)
+                        else:
+                            suitable = _is_suitable_for_muon(p, return_reason=False)
+                        state["use_muon"] = suitable
+
+                    # Track routing decision for logging
+                    if routing_reasons is not None and reason is not None:
+                        shape_str = "x".join(str(s) for s in p.shape)
+                        if shape_str not in routing_reasons:
+                            routing_reasons[shape_str] = []
+                        routing_reasons[shape_str].append(reason)
+
+                # Use cached routing decision
+                use_muon = state["use_muon"]
+                if use_muon:
+                    # Collect Muon params
+                    muon_params.append(p)
+                    muon_grads.append(p.grad)
+                    muon_count += 1
+
+                    # State initialization for Muon
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    muon_momentum_bufs.append(state["momentum_buffer"])
+                else:
                     # Collect AdamW/NAdamW params
                     adamw_params.append(p)
                     adamw_grads.append(p.grad)
+                    adamw_count += 1
 
-                    # State initialization
-                    if len(state) == 0:
+                    # State initialization for AdamW
+                    if "step" not in state:
                         state["step"] = torch.tensor(0.)
                         state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                         state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
                     adamw_exp_avgs.append(state["exp_avg"])
                     adamw_exp_avg_sqs.append(state["exp_avg_sq"])
                     adamw_state_steps.append(state["step"])
-                else:
-                    # Collect Muon params
-                    muon_params.append(p)
-                    muon_grads.append(p.grad)
-
-                    if len(state) == 0:
-                        state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    muon_momentum_bufs.append(state["momentum_buffer"])
 
             # Apply Muon updates
             if muon_params:
@@ -495,12 +605,41 @@ def step(self, closure=None):
                     max_lr=None,
                 )
 
+        # Log routing summary when we have new routing decisions
+        if routing_reasons and len(routing_reasons) > 0:
+            # Concise summary
+            _logger.info(f"Muon parameter routing: {muon_count} Muon, {adamw_count} AdamW")
+
+            # Group by reason for detailed breakdown
+            reason_groups = {}
+            for shape_str, reasons in sorted(routing_reasons.items()):
+                for reason in reasons:
+                    if reason not in reason_groups:
+                        reason_groups[reason] = []
+                    reason_groups[reason].append(shape_str)
+
+            # Log summary counts per reason
+            reason_summary = []
+            for reason, shapes in sorted(reason_groups.items()):
+                reason_summary.append(f"{reason}={len(shapes)}")
+            _logger.info(f"  Breakdown: {', '.join(reason_summary)}")
+
+            # Detailed breakdown at INFO level
+            if _logger.isEnabledFor(logging.INFO):
+                for reason, shapes in sorted(reason_groups.items()):
+                    optimizer_name = "Muon" if reason == "ok" else "AdamW"
+                    _logger.info(f"  {reason} -> {optimizer_name}:")
+                    for shape in shapes[:10]:
+                        _logger.info(f"    {shape}")
+                    if len(shapes) > 10:
+                        _logger.info(f"      ... and {len(shapes) - 10} more")
+
         return loss
 
 
 def resolve_ns_coefficients(
-    value: Union[str, Sequence[float], Sequence[Sequence[float]]],
-    presets: Mapping[str, Sequence[Sequence[float]]]
+        value: Union[str, Sequence[float], Sequence[Sequence[float]]],
+        presets: Mapping[str, Sequence[Sequence[float]]]
 ) -> List[Tuple[float, float, float]]:
     # tiny helpers (kept inline for succinctness)
     is_seq = lambda x: isinstance(x, Sequence) and not isinstance(x, (str, bytes))
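
As a usage note for the new verbose flag: routing decisions are cached per parameter in the optimizer state on the first step() and summarized through the module logger, so enabling INFO logging is enough to see the Muon/AdamW split. The following is a minimal sketch, assuming the optimizer class defined in this file is exposed as Muon from timm.optim.muon and using a throwaway two-layer model for illustration.

import logging
import torch
from timm.optim.muon import Muon  # assumed import path

logging.basicConfig(level=logging.INFO)  # routing summary is emitted at INFO

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 96, 4),        # weight (96, 3, 4, 4) flattens to (96, 48) -> Muon
    torch.nn.Flatten(),
    torch.nn.Linear(96 * 5 * 5, 10),  # weight (10, 2400): aspect ratio 240 -> AdamW fallback
)
opt = Muon(model.parameters(), lr=0.02, verbose=True)

loss = model(torch.randn(2, 3, 8, 8)).sum()
loss.backward()
opt.step()  # first step caches use_muon per param and logs the routing breakdown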