3 changes: 3 additions & 0 deletions ruff.toml
@@ -13,4 +13,7 @@ include = [
"test/dtypes/test_affine_quantized_float.py",
"torchao/quantization/weight_tensor_linear_activation_quantization.py",
"torchao/dtypes/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"test/prototype/low_bit_optim/**.py",

]
100 changes: 76 additions & 24 deletions test/prototype/test_low_bit_optim.py
@@ -19,7 +19,11 @@
quantize_4bit_with_qmap,
_fp32_to_bf16_sr,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_6,
)

try:
import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
x_rep = x.view(-1, 1).repeat(1, 100_000)

if compile:
x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
x_rep
)
else:
x_rep_bf16 = _fp32_to_bf16_sr(x_rep)

@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):


class TestOptim(TestCase):
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize(
"optim_name",
["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
)
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
def test_optim_smoke(self, optim_name, dtype, device):
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="bitsandbytes 8-bit Adam only works for CUDA",
)
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1)

# https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
optim2 = getattr(low_bit_optim, optim_name)(
model2.parameters(), block_size=block_size
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):

# this will not run in CI because we can't install lpmm
@pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
)
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1)

# lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
)
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1[0].requires_grad_(False) # make sure it can work in the presence of non-trainable params
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model1[0].requires_grad_(
False
) # make sure it can work in the presence of non-trainable params
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
model2.parameters(),
torch.optim.AdamW,
offload_gradients=offload_grad,
)

for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
)
def test_optim_cpu_offload_save_load(self):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
optim1 = low_bit_optim.CPUOffloadOptimizer(
model1.parameters(), torch.optim.AdamW
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -253,7 +293,9 @@ def test_optim_cpu_offload_save_load(self):

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW
)
optim2.load_state_dict(state_dict)

for _ in range(2):
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
def test_optim_bf16_stochastic_round_correctness(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(2024)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1).bfloat16()

# small LR so that weight update is small
# when bf16_stochastic_round=False, the test will fail after 1 iteration
optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
optim2 = low_bit_optim._AdamW(
model2.parameters(), lr=1e-5, bf16_stochastic_round=True
)

# overfit on this sample
x = torch.randn(4, 32, device=device)
@@ -299,15 +345,19 @@ def test_optim_bf16_stochastic_round_correctness(self):
optim2.step()
optim2.zero_grad()

torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
torch.testing.assert_close(
loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
)


class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return 2

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
)
@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
base_loss.backward()
for param in base_model.parameters():
if param.grad is not None:
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
torch.distributed.all_reduce(
param.grad, op=torch.distributed.ReduceOp.AVG
)
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)

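For reference, a minimal sketch of how the low-bit optimizers exercised by `test_optim_smoke` above are driven in a plain training loop (assumptions: PyTorch >= 2.3 per the skip markers; the model shape is illustrative and chosen so the weights are large enough for quantized optimizer state, not taken from the PR):

```python
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

# 256 * 32 = 8192 weight elements: >= 4096 and a multiple of the block size,
# so the optimizer state for these weights is stored in the 8-bit subclass
model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
optim = low_bit_optim.AdamW8bit(model.parameters())

for _ in range(2):
    x = torch.randn(4, 32)
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()
```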
11 changes: 11 additions & 0 deletions torchao/prototype/low_bit_optim/__init__.py
@@ -1,2 +1,13 @@
from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW
from .cpu_offload import CPUOffloadOptimizer

__all__ = [
"Adam4bit",
"Adam8bit",
"AdamFp8",
"AdamW4bit",
"AdamW8bit",
"AdamWFp8",
"_AdamW",
"CPUOffloadOptimizer",
]
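With `__all__` in place, the public names can be imported directly from the subpackage; a quick sketch (the import path follows the file layout in this PR):

```python
from torchao.prototype.low_bit_optim import (
    Adam8bit,
    AdamW4bit,
    AdamWFp8,
    CPUOffloadOptimizer,
)
```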
30 changes: 24 additions & 6 deletions torchao/prototype/low_bit_optim/adam.py
@@ -2,18 +2,28 @@

import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer

from .subclass_8bit import OptimState8bit
from .quant_utils import _fp32_to_bf16_sr
from .subclass_4bit import OptimState4bit
from .subclass_8bit import OptimState8bit
from .subclass_fp8 import OptimStateFp8
from .quant_utils import _fp32_to_bf16_sr


class _AdamBase(Optimizer):
def __init__(
self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, bf16_stochastic_round, is_adamw
self,
params,
lr,
betas,
eps,
weight_decay,
amsgrad,
*,
block_size,
bf16_stochastic_round,
is_adamw,
) -> None:
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@@ -23,7 +33,13 @@ def __init__(
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
defaults = dict(
lr=torch.tensor(lr),
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
)
super().__init__(params, defaults)
self.block_size = block_size
self.bf16_stochastic_round = bf16_stochastic_round
@@ -45,7 +61,9 @@ def _new_buffer(self, p: Tensor, signed: bool):
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
if isinstance(p, DTensor):
out = DTensor.from_local(
local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
local_tensor=self._subclass_zeros(
p.to_local(), signed, self.block_size
),
device_mesh=p.device_mesh,
placements=p.placements,
run_check=False,
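For context, `_new_buffer` (reformatted above) only allocates quantized state when the parameter is large enough and its size is a multiple of the block size; `DTensor` parameters get a quantized local shard wrapped back into a `DTensor`. A standalone sketch of the gating rule (the helper name is hypothetical, not part of the PR):

```python
import torch

def uses_quantized_state(p: torch.Tensor, block_size: int) -> bool:
    # mirrors the condition in _AdamBase._new_buffer: small or oddly-sized
    # parameters fall back to plain high-precision optimizer state
    return p.numel() >= 4096 and p.numel() % block_size == 0

print(uses_quantized_state(torch.empty(1024, 32), block_size=256))  # True
print(uses_quantized_state(torch.empty(100), block_size=256))       # False
```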
10 changes: 8 additions & 2 deletions torchao/prototype/low_bit_optim/cpu_offload.py
@@ -25,7 +25,11 @@ def __init__(
kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
"""
# default to fused CPU AdamW
if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs:
if (
optimizer_class is torch.optim.AdamW
and TORCH_VERSION_AT_LEAST_2_4
and "fused" not in kwargs
):
kwargs.update(fused=True)

param_groups = list(params)
@@ -77,7 +81,9 @@ def backward_hook(p_cuda):
self.param_cuda2cpu_map[p_cuda] = p_cpu

p_cuda.register_post_accumulate_grad_hook(backward_hook)
self.optim_dict[p_cuda] = optimizer_class([{"params": p_cpu, **param_group}], **kwargs)
self.optim_dict[p_cuda] = optimizer_class(
[{"params": p_cpu, **param_group}], **kwargs
)

@torch.no_grad()
def step(self, closure=None):
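A minimal usage sketch of `CPUOffloadOptimizer`, matching the tests above (assumptions: a CUDA device is available, the base optimizer is passed as a class rather than an instance, and extra kwargs such as `lr` are forwarded to it):

```python
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()
optim = low_bit_optim.CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,        # base optimizer class; fused CPU AdamW is used when available
    offload_gradients=True,   # move gradients to CPU as they are accumulated
    lr=1e-4,                  # forwarded to the base optimizer
)

for _ in range(2):
    x = torch.randn(4, 32, device="cuda")
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()
```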
13 changes: 8 additions & 5 deletions torchao/prototype/low_bit_optim/quant_utils.py
@@ -122,14 +122,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
# [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
#
# we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
rand_16bit = torch.randint(0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32)
rand_16bit = torch.randint(
0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
)
x_f32_bits = x_f32.view(torch.int32)
x_fraction = x_f32_bits & 0xFFFF # lower 16 bits
x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits
x_fraction = x_f32_bits & 0xFFFF # lower 16 bits
x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits

x_f32_bits = torch.where(
rand_16bit < x_fraction, # this is True with the probability of p_fraction
x_bf16_towards_zero + 0x10000, # this might overflow, which will result in UB due to signed integer
rand_16bit < x_fraction, # this is True with the probability of p_fraction
x_bf16_towards_zero
+ 0x10000, # this might overflow, which will result in UB due to signed integer
x_bf16_towards_zero,
)
# alternative, slightly faster
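As a sanity check on the scheme above, a small self-contained sketch (not part of the PR) that reproduces the bit manipulation with plain tensor ops and shows the rounding is unbiased, i.e. the mean of many rounded samples approaches the fp32 input:

```python
import torch

# 1.0 + 2**-9 is representable in fp32 but not in bf16 (8 mantissa bits),
# so each sample must round either up or down
x = torch.full((100_000,), 1.0 + 2**-9, dtype=torch.float32)

bits = x.view(torch.int32)
fraction = bits & 0xFFFF        # discarded lower 16 bits
down_bits = bits - fraction     # truncate toward zero (keeps the upper 16 bits)
up_bits = down_bits + 0x10000   # next bf16-representable value (positive inputs)

# round up with probability fraction / 2**16, down otherwise
rand = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32)
rounded = torch.where(rand < fraction, up_bits, down_bits).view(torch.float32)

print(rounded.mean().item(), x[0].item())  # should agree to ~1e-4
```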