Commit ba083ea

change block_size
1 parent 653efe9 commit ba083ea

4 files changed: +7 -7 lines changed

torchao/prototype/low_bit_optim/README.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ model = ...
optim = Adam8bit(model.parameters())
```

-To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
+To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 256 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.

**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.
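For reference, a minimal usage sketch of what the updated README text describes, assuming the optimizers are importable from `torchao.prototype.low_bit_optim` as the directory above suggests; the model and hyperparameters are placeholders:

```python
import torch
from torchao.prototype.low_bit_optim import Adam8bit, Adam4bit  # import path assumed from the directory above

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model; low-bit optimizer states are typically kept on GPU

# 8-bit Adam: optimizer state is quantized with the new default block_size=256
optim = Adam8bit(model.parameters(), lr=1e-3)

# 4-bit Adam with an explicit block size (128 is already its default)
optim = Adam4bit(model.parameters(), lr=1e-3, block_size=128)
```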

torchao/prototype/low_bit_optim/adam.py

Lines changed: 4 additions & 4 deletions
@@ -161,7 +161,7 @@ def __init__(
        weight_decay=0,
        amsgrad=False,
        *,
-        block_size=2048,
+        block_size=256,
    ) -> None:
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)

@@ -199,7 +199,7 @@ def __init__(
        weight_decay=0,
        amsgrad=False,
        *,
-        block_size=2048,
+        block_size=256,
    ) -> None:
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)

@@ -218,7 +218,7 @@ def __init__(
        weight_decay=1e-2,
        amsgrad=False,
        *,
-        block_size=2048,
+        block_size=256,
    ) -> None:
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)

@@ -256,7 +256,7 @@ def __init__(
        weight_decay=1e-2,
        amsgrad=False,
        *,
-        block_size=2048,
+        block_size=256,
    ) -> None:
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)
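The `block_size` keyword these constructors forward to the base class sets the granularity of block-wise quantization: every contiguous run of `block_size` state values shares one scale, so the new default of 256 produces 8x more scales than 2048 did. A rough illustrative sketch of that relationship, using plain absmax scaling as a stand-in for the library's actual quantization:

```python
import torch

def blockwise_absmax_scales(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # One scale per contiguous block of `block_size` elements
    # (assumes x.numel() is divisible by block_size).
    return x.flatten().view(-1, block_size).abs().amax(dim=1)

state = torch.randn(4096, 1024)                      # 4,194,304 elements
print(blockwise_absmax_scales(state, 256).numel())   # 16384 scales with the new default
print(blockwise_absmax_scales(state, 2048).numel())  # 2048 scales with the old default
```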

torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def dequantize(self, output_dtype=None):
        return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype)

    @classmethod
-    def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None):
+    def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
        codes = torch.zeros(shape, dtype=torch.uint8, device=device)
        scale = torch.zeros(codes.numel() // block_size, device=device)
        qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device)
torchao/prototype/low_bit_optim/subclass_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def dequantize(self, output_dtype=None):
        return float_data.view(self.codes.shape).to(dtype)

    @classmethod
-    def zeros(cls, shape, block_size: int = 2048, device=None):
+    def zeros(cls, shape, block_size: int = 256, device=None):
        codes = torch.zeros(shape, dtype=DTYPE, device=device)
        scale = torch.zeros(codes.numel() // block_size, device=device)
        return cls(codes, scale)
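The FP8 state uses the same layout, one fp32 scale per `block_size` codes, but needs no qmap since the FP8 bit pattern encodes the value directly. A hedged sketch of block-wise dequantization in this style, meant only to illustrate the shape bookkeeping rather than reproduce the exact `dequantize` above:

```python
import torch

def dequant_blockwise(codes: torch.Tensor, scale: torch.Tensor, block_size: int = 256) -> torch.Tensor:
    # Upcast the FP8 codes and apply one scale per contiguous block of `block_size` elements.
    float_data = codes.float().view(-1, block_size) * scale.view(-1, 1)
    return float_data.view(codes.shape)

codes = torch.randn(1024, 256).to(torch.float8_e4m3fn)  # stand-in for the stored codes (the file's DTYPE may differ)
scale = torch.rand(codes.numel() // 256)
print(dequant_blockwise(codes, scale).shape)             # torch.Size([1024, 256])
```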
