Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
import torch
from packaging.version import Version
from torch import nn
from torch.testing._internal.common_utils import (
TestCase,
Expand Down Expand Up @@ -105,8 +106,11 @@ def test_optim_8bit_correctness(self, optim_name):
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model2 = copy.deepcopy(model1)

# https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)

for _ in range(2):
x = torch.randn(4, 32, device=device)
Expand Down
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ model = ...
optim = Adam8bit(model.parameters())
```

To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 256 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.

**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.

Expand Down
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(
weight_decay=0,
amsgrad=False,
*,
block_size=2048,
block_size=256,
) -> None:
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)

Expand Down Expand Up @@ -199,7 +199,7 @@ def __init__(
weight_decay=0,
amsgrad=False,
*,
block_size=2048,
block_size=256,
) -> None:
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)

Expand All @@ -218,7 +218,7 @@ def __init__(
weight_decay=1e-2,
amsgrad=False,
*,
block_size=2048,
block_size=256,
) -> None:
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)

Expand Down Expand Up @@ -256,7 +256,7 @@ def __init__(
weight_decay=1e-2,
amsgrad=False,
*,
block_size=2048,
block_size=256,
) -> None:
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)

Expand Down
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/subclass_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def dequantize(self, output_dtype=None):
return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype)

@classmethod
def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None):
def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
codes = torch.zeros(shape, dtype=torch.uint8, device=device)
scale = torch.zeros(codes.numel() // block_size, device=device)
qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device)
Expand Down
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/subclass_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def dequantize(self, output_dtype=None):
return float_data.view(self.codes.shape).to(dtype)

@classmethod
def zeros(cls, shape, block_size: int = 2048, device=None):
def zeros(cls, shape, block_size: int = 256, device=None):
codes = torch.zeros(shape, dtype=DTYPE, device=device)
scale = torch.zeros(codes.numel() // block_size, device=device)
return cls(codes, scale)
Expand Down