diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index ccf925a3fd..496fa3659f 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -3,6 +3,7 @@
 import pytest
 import torch
+from packaging.version import Version
 from torch import nn
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -105,8 +106,11 @@ def test_optim_8bit_correctness(self, optim_name):
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
         model2 = copy.deepcopy(model1)
 
+        # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
+        block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
+
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md
index 7781386bdd..64cb536ac1 100644
--- a/torchao/prototype/low_bit_optim/README.md
+++ b/torchao/prototype/low_bit_optim/README.md
@@ -19,7 +19,7 @@ model = ...
 optim = Adam8bit(model.parameters())
 ```
 
-To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
+To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 256 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
 
 **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.
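The README hunk above documents the `block_size` knob. A minimal usage sketch, assuming torchao with the `low_bit_optim` prototype installed; the model and training step here are illustrative, not taken from the PR:

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import Adam8bit

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))

# block_size=256 is the new default for 8-bit and FP8 optimizers; a smaller
# value gives finer-grained (more accurate) quantization at the cost of a
# larger per-block scale table, a larger value the opposite.
optim = Adam8bit(model.parameters(), block_size=256)

loss = model(torch.randn(4, 32)).sum()
loss.backward()
optim.step()
optim.zero_grad()
```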
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 7f0d47854b..6c3c6996b9 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -161,7 +161,7 @@ def __init__(
         weight_decay=0,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)
 
@@ -199,7 +199,7 @@ def __init__(
         weight_decay=0,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)
 
@@ -218,7 +218,7 @@ def __init__(
         weight_decay=1e-2,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)
 
@@ -256,7 +256,7 @@ def __init__(
         weight_decay=1e-2,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)
 
diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py
index 865498a57e..9c6e641e6d 100644
--- a/torchao/prototype/low_bit_optim/subclass_8bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -53,7 +53,7 @@ def dequantize(self, output_dtype=None):
         return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype)
 
     @classmethod
-    def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None):
+    def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=torch.uint8, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
         qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device)
diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py
index 805c516f4e..146023c9f5 100644
--- a/torchao/prototype/low_bit_optim/subclass_fp8.py
+++ b/torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -60,7 +60,7 @@ def dequantize(self, output_dtype=None):
         return float_data.view(self.codes.shape).to(dtype)
 
     @classmethod
-    def zeros(cls, shape, block_size: int = 2048, device=None):
+    def zeros(cls, shape, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=DTYPE, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
         return cls(codes, scale)
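In both `zeros()` classmethods above, the scale buffer holds one element per quantization block (`codes.numel() // block_size`), so dropping the default from 2048 to 256 multiplies the number of scales by 8 while the overhead stays small relative to the 1-byte codes. A rough back-of-envelope sketch; the tensor size is illustrative, not from the PR:

```python
import torch

numel = 1024 * 128  # e.g. the last Linear weight in the test above

for block_size in (2048, 256):
    codes = torch.zeros(numel, dtype=torch.uint8)     # 1 byte per element
    scale = torch.zeros(codes.numel() // block_size)  # 1 fp32 scale per block
    extra = scale.numel() * scale.element_size() / codes.numel()
    print(f"block_size={block_size}: {scale.numel()} scales, "
          f"{extra:.4f} extra bytes per element")

# block_size=2048:  64 scales, 0.0020 extra bytes per element
# block_size=256:  512 scales, 0.0156 extra bytes per element
```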