Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions timm/optim/kron.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ class Kron(torch.optim.Optimizer):
mu_dtype: Dtype of the momentum accumulator.
precond_dtype: Dtype of the preconditioner.
decoupled_decay: AdamW style decoupled weight decay
flatten_dim: Flatten dim >= 2 instead of relying on expressions
flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
flatten_end_dim: End of flatten range, defaults to -1.
stochastic_weight_decay: Enable random modulation of weight decay
deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
"""

Expand All @@ -114,7 +117,10 @@ def __init__(
mu_dtype: Optional[torch.dtype] = None,
precond_dtype: Optional[torch.dtype] = None,
decoupled_decay: bool = False,
flatten_dim: bool = False,
flatten: bool = False,
flatten_start_dim: int = 2,
flatten_end_dim: int = -1,
stochastic_weight_decay: bool = False,
deterministic: bool = False,
):
if not has_opt_einsum:
Expand All @@ -141,7 +147,10 @@ def __init__(
mu_dtype=mu_dtype,
precond_dtype=precond_dtype,
decoupled_decay=decoupled_decay,
flatten_dim=flatten_dim,
flatten=flatten,
flatten_start_dim=flatten_start_dim,
flatten_end_dim=flatten_end_dim,
stochastic_weight_decay=stochastic_weight_decay,
)
super(Kron, self).__init__(params, defaults)

Expand Down Expand Up @@ -229,8 +238,11 @@ def step(self, closure=None):

grad = p.grad
state = self.state[p]
if group['flatten_dim']:
grad = grad.view(grad.size(0), -1)

flattened = False
if group['flatten']:
grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
flattened = True

if len(state) == 0:
state["step"] = 0
Expand Down Expand Up @@ -341,15 +353,19 @@ def step(self, closure=None):

# RMS of pre_grad should be 1.0, so let's cap at 1.1
pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
if group['flatten_dim']:
if flattened:
pre_grad = pre_grad.view(p.shape)

# Apply weight decay
if group["weight_decay"] != 0:
weight_decay = group["weight_decay"]
if weight_decay != 0:
if group["stochastic_weight_decay"]:
weight_decay = 2 * self.rng.random() * weight_decay

if group["decoupled_decay"]:
p.mul_(1. - group["lr"] * group["weight_decay"])
p.mul_(1. - group["lr"] * weight_decay)
else:
pre_grad.add_(p, alpha=group["weight_decay"])
pre_grad.add_(p, alpha=weight_decay)

# Update parameters
p.add_(pre_grad, alpha=-group["lr"])
Expand All @@ -361,6 +377,20 @@ def step(self, closure=None):
return loss


def safe_flatten(tensor, start_dim=0, end_dim=-1):
ndim = tensor.ndim

# Convert negative end_dim to positive and clip to end
end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)

# If tensor has fewer dims than start_dim or start > end, return tensor as is
if ndim <= start_dim or start_dim > end_dim:
return tensor

# Now safe to flatten
return tensor.flatten(start_dim, end_dim)


def _init_Q_exprs(
t,
scale,
Expand Down