
Commit bbeb28f

gau-nernst authored and msaroufim committed
[Low-bit optim] Support for dcp.save() and dcp.load() (#1217)
* support dcp.save
* add test for dcp.load()
* fix test
* typo
* implement aten.slice
* skip test
* fix checks
* run ruff
* fix formatting
* remove add safe globals in test
* sort some imports

Co-authored-by: Mark Saroufim <[email protected]>
1 parent b06c52f commit bbeb28f
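With this change, optimizer state dicts containing torchao's 8-bit, 4-bit, and FP8 state subclasses can be saved and loaded through PyTorch Distributed Checkpoint. The sketch below outlines the workflow the commit enables; it is not code from this commit. It assumes a process group is already initialized (e.g. launched with torchrun), that model is an FSDP2-sharded module, and the checkpoint path is a placeholder.

# Minimal sketch (not from this commit): save/load a low-bit optimizer state with DCP.
# Assumes torch.distributed is initialized and `model` is an FSDP2-sharded module.
import torch
import torch.distributed.checkpoint as dcp
from torchao.prototype import low_bit_optim

optim = low_bit_optim.AdamW8bit(model.parameters())
# ... run some training steps so optimizer state exists ...

# save the sharded optimizer state (quantized state tensors are handled after this change)
dcp.save(optim.state_dict(), checkpoint_id="my_checkpoint")  # path is a placeholder

# resume, possibly with a different world size (re-sharding relies on the new aten.slice support)
resumed_optim = low_bit_optim.AdamW8bit(model.parameters())
for p in model.parameters():        # dummy step to materialize optimizer state,
    p.grad = torch.zeros_like(p)    # mirroring what the new FSDP2 test does
resumed_optim.step()
dcp.load(resumed_optim.state_dict(), checkpoint_id="my_checkpoint")  # loads in place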

File tree

4 files changed, +240 -61 lines changed


test/prototype/test_low_bit_optim.py

Lines changed: 98 additions & 58 deletions
@@ -1,27 +1,33 @@
 import copy
+import shutil
 import tempfile
+from pathlib import Path
 
 import pytest
 import torch
 from packaging.version import Version
 from torch import nn
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
+
 from torchao.prototype import low_bit_optim
 from torchao.prototype.low_bit_optim.quant_utils import (
-    quantize_8bit_with_qmap,
-    quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
+    quantize_4bit_with_qmap,
+    quantize_8bit_with_qmap,
 )
+from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
+from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
+from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
 )
 
@@ -88,23 +94,15 @@ def test_bf16_stochastic_round(self, device, compile):
         x = torch.rand(32, device=device) * 100
         x_rep = x.view(-1, 1).repeat(1, 100_000)
 
-        if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
-                x_rep
-            )
-        else:
-            x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
-
+        func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
+        x_rep_bf16 = func(x_rep)
         assert x_rep_bf16.dtype is torch.bfloat16
 
         # must cast BF16 tensor back to FP32 so that .mean() is accurate
         torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
-    )
     @parametrize(
         "optim_name",
         ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
@@ -151,29 +149,46 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
+    # however, it's cumbersome to test it directly, since we would need to run distributed
+    # test 2 times with different world size, and persist checkpoint across the 2 runs.
+    # thus, we only test for the required op. note that future implementations of dcp.load()
+    # may use other ops.
+    @parametrize("subclass", [OptimState4bit, OptimState8bit, OptimStateFp8])
+    @parametrize("shape", [(4096,), (256, 256)])
+    @parametrize("device", _DEVICES)
+    def test_subclass_slice(self, subclass, shape, device):
+        if subclass == OptimStateFp8:
+            if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
+                pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
+            if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
+                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
+            if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        tensor = subclass.zeros(shape, device=device)
+        offset = shape[0] // 2
+
+        torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
+        torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize())
+
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
     @pytest.mark.skipif(
         not torch.cuda.is_available(),
         reason="bitsandbytes 8-bit Adam only works for CUDA",
     )
-    @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
-    )
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
         model2 = copy.deepcopy(model1)
 
         # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
 
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(
-            model2.parameters(), block_size=block_size
-        )
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -196,15 +211,11 @@ def test_optim_8bit_correctness(self, optim_name):
     @pytest.mark.skipif(
         not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
     )
-    @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
-    )
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
         model2 = copy.deepcopy(model1)
 
         # lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -238,12 +249,11 @@ def test_optim_4bit_correctness(self, optim_name):
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
-        model1[0].requires_grad_(
-            False
-        )  # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
+
+        # make sure it can work in the presence of non-trainable params
+        model1[0].requires_grad_(False)
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
@@ -273,12 +283,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
     )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
-        optim1 = low_bit_optim.CPUOffloadOptimizer(
-            model1.parameters(), torch.optim.AdamW
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
+        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -293,9 +300,7 @@ def test_optim_cpu_offload_save_load(self):
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW
-        )
+        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
@@ -315,16 +320,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
         model2 = copy.deepcopy(model1).bfloat16()
 
         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
         optim2 = low_bit_optim._AdamW(
-            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+            model2.parameters(),
+            lr=1e-5,
+            bf16_stochastic_round=True,
         )
 
         # overfit on this sample
@@ -350,10 +356,13 @@ def test_optim_bf16_stochastic_round_correctness(self):
         )
 
 
+_FSDP_WORLD_SIZE = 2
+
+
 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
-        return 2
+        return _FSDP_WORLD_SIZE
 
     @pytest.mark.skipif(
         not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
@@ -370,12 +379,12 @@ def test_fsdp2(self):
         )
 
     def _test_fsdp2(self, optim_cls):
+        import torch.distributed as dist
+        import torch.distributed.checkpoint as dcp
+        import torch.utils._pytree as pytree
         from torch.distributed._composable.fsdp import fully_shard
-        from torch.testing._internal.distributed._tensor.common_dtensor import (
-            ModelArgs,
-            Transformer,
-            TransformerBlock,
-        )
+        from torch.distributed.tensor import DTensor
+        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
 
         batch_size = 3
         vocab_size = 1024
@@ -413,9 +422,7 @@ def _test_fsdp2(self, optim_cls):
             base_loss.backward()
             for param in base_model.parameters():
                 if param.grad is not None:
-                    torch.distributed.all_reduce(
-                        param.grad, op=torch.distributed.ReduceOp.AVG
-                    )
+                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
             base_optim.step()
             self.assertEqual(fsdp_loss, base_loss)
 
@@ -428,6 +435,39 @@ def _test_fsdp2(self, optim_cls):
 
         self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())
 
+        # test for compatibility with dcp.save() and .load()
+        checkpoint_id = f"_fsdp_low_bit_optim_{optim_cls.__name__}"
+        if Path(checkpoint_id).exists():
+            shutil.rmtree(checkpoint_id)
+        dcp.save(fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
+
+        # normally we would want to use dcp.state_dict.get_optimizer_state_dict() to initialize optim states.
+        # however, currently it does not respect tensor-ness of LR pytorch/pytorch#139575.
+        # therefore, we have to manually initialize optim state here.
+        resumed_fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
+        for p in fsdp_model.parameters():
+            p.grad = torch.zeros_like(p)
+
+        # this will change model weights due to weight decay, but since we don't use the model anymore, it's fine.
+        resumed_fsdp_optim.step()
+
+        dcp.load(resumed_fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
+        if dist.get_rank() == 0:
+            shutil.rmtree(checkpoint_id)
+
+        subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)
+
+        for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
+            assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
+            if isinstance(v1, DTensor):
+                v1 = v1.to_local()
+                v2 = v2.to_local()
+                assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
+            if isinstance(v1, subclasses):
+                v1 = v1.dequantize()
+                v2 = v2.dequantize()
+            self.assertEqual(v1, v2)
+
 
 instantiate_parametrized_tests(TestQuantize)
 instantiate_parametrized_tests(TestOptim)
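As a side note, the invariant that test_subclass_slice checks (slicing a quantized state tensor and then dequantizing gives the same values as dequantizing and then slicing) can be reproduced standalone. A minimal sketch, assuming torchao with this change is installed; the shape and offset simply mirror the test's parameters:

# Sketch of the invariant exercised by test_subclass_slice (values mirror the test).
import torch
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit

tensor = OptimState8bit.zeros((4096,), device="cpu")
offset = 4096 // 2  # shape[0] // 2, which lands on a quantization block boundary

# slice-then-dequantize must match dequantize-then-slice (needed for DCP re-sharding)
torch.testing.assert_close(tensor[:offset].dequantize(), tensor.dequantize()[:offset])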

torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 50 additions & 1 deletion
@@ -177,13 +177,15 @@ def _(func, types, args, kwargs):
     )
 
 
-# this is needed for DTensor.full_tensor()
 @OptimState4bit.implements(
     [
+        # required by DTensor.full_tensor()
         c10d_functional.all_gather_into_tensor.default,
         _c10d_functional.all_gather_into_tensor.default,
        c10d_functional.wait_tensor.default,
         _c10d_functional.wait_tensor.default,
+        # required by torch.distributed.checkpoint.save
+        aten.detach.default,
     ]
 )
 def _(func, types, args, kwargs):
@@ -201,6 +203,53 @@ def _(func, types, args, kwargs):
     return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
 
 
+# required by torch.distributed.checkpoint.save
+# note that we don't actually implement pin memory for this tensor subclass
+# (pin_memory argument is ignored in aten._to_copy)
+@OptimState4bit.implements(aten.is_pinned.default)
+def _(func, types, args, kwargs):
+    return (
+        args[0].codes.is_pinned()
+        and args[0].scale.is_pinned()
+        and args[0].qmap.is_pinned()
+    )
+
+
+# required by torch.distributed.checkpoint.load when world size changes i.e. re-sharding
+@OptimState4bit.implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    x, dim, start, end = args[:4]
+    step = args[4] if len(args) > 4 else 1
+
+    # input validation
+    if dim != 0:
+        raise ValueError("Only support aten.slice along the first dim")
+    if step != 1:
+        raise ValueError("Only support aten.slice with step=1")
+
+    block_size = x.block_size
+    stride = math.prod(x.shape[1:])
+
+    # for 1 increment in x along the first dim,
+    # (flattened) scale will increment by stride / block_size
+    if (start * stride) % block_size != 0 or (end * stride) % block_size != 0:
+        raise ValueError(
+            f"Invalid start or end for shape={x.shape} and block_size={block_size}. "
+            f"Make sure start and end align with block boundary. "
+            f"Received start={start}, end={end}."
+        )
+
+    # note that for 4-bit, we store .codes as flattened buffer
+    # divide by 2 since we store 2x 4-bit in 1x uint8
+    codes = x.codes[start * stride // 2 : end * stride // 2]
+    scale = x.scale[start * stride // block_size : end * stride // block_size]
+
+    # adjust the first dim
+    shape = (x.shape[0] * codes.numel() // x.codes.numel(),) + x.shape[1:]
+
+    return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     from torch.serialization import add_safe_globals
 
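To make the index arithmetic in the new aten.slice handler concrete: a slice of rows [start, end) of the logical tensor maps to a byte range in the flattened codes buffer and a block range in scale. A small worked example follows; the shape and block_size values are chosen for illustration only and are not taken from the commit.

# Illustrative numbers only; the handler reads these from the OptimState4bit instance.
shape = (256, 256)
block_size = 128
stride = 256                      # math.prod(shape[1:])
start, end = 0, 128               # slice the first half of the rows

# the alignment check in the handler: slice boundaries must fall on block boundaries
assert (start * stride) % block_size == 0 and (end * stride) % block_size == 0

# codes is a flat uint8 buffer with two 4-bit values per byte, hence the divide by 2
codes_lo, codes_hi = start * stride // 2, end * stride // 2                     # 0, 16384
scale_lo, scale_hi = start * stride // block_size, end * stride // block_size   # 0, 256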