Commit 8b58d10

Fix for weights-only load
stack-info: PR: #1228, branch: drisspg/stack/19
1 parent 6fd77d5 commit 8b58d10
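
Context for this change: under torch.load(..., weights_only=True), which became the default in PyTorch 2.6, checkpoints are deserialized with a restricted unpickler that only reconstructs an allow-list of types. Checkpoints containing custom tensor subclasses, such as the low-bit optimizer states below, fail to load unless those classes are first registered via torch.serialization.add_safe_globals. A minimal sketch of the mechanism (the Point class and file name are illustrative, not from this repo):

# Illustrative sketch, not part of the commit: how weights_only loading
# interacts with add_safe_globals. "Point" stands in for a tensor subclass.
import torch
from torch.serialization import add_safe_globals

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

torch.save({"p": Point(1, 2)}, "ckpt.pt")

# The restricted unpickler rejects globals it does not know about, so loading
# with weights_only=True would fail here. Registering the class fixes that:
add_safe_globals([Point])
ckpt = torch.load("ckpt.pt", weights_only=True)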

File tree

4 files changed: 169 additions & 56 deletions


test/prototype/test_low_bit_optim.py

Lines changed: 78 additions & 26 deletions
@@ -19,7 +19,11 @@
     quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 try:
     import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x_rep = x.view(-1, 1).repeat(1, 100_000)
 
         if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
+            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
+                x_rep
+            )
         else:
             x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
 
@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
-    @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
+    @parametrize(
+        "optim_name",
+        ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
+    )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
     def test_optim_smoke(self, optim_name, dtype, device):
@@ -120,7 +131,7 @@ def test_optim_smoke(self, optim_name, dtype, device):
         # test serialization. also test the case CUDA optim loads CPU state dict
         with tempfile.NamedTemporaryFile() as f:
             torch.save(optim.state_dict(), f.name)
-            state_dict = torch.load(f.name, map_location="cpu")
+            state_dict = torch.load(f.name, map_location="cpu", weights_only=True)
 
         model2 = copy.deepcopy(model)
         optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
             torch.testing.assert_close(p2, p1)
 
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="bitsandbytes 8-bit Adam only works for CUDA",
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
 
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):
 
     # this will not run in CI because we can't install lpmm
    @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        model1[0].requires_grad_(False)  # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        model1[0].requires_grad_(
+            False
+        )  # make sure it can work in the presence of non-trainable params
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
         optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
+            model2.parameters(),
+            torch.optim.AdamW,
+            offload_gradients=offload_grad,
         )
 
         for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -249,11 +289,13 @@ def test_optim_cpu_offload_save_load(self):
         # save checkpoint. make sure it can be serialized by torch.save()
         with tempfile.NamedTemporaryFile() as file:
             torch.save(optim1.state_dict(), file.name)
-            state_dict = torch.load(file.name, map_location="cpu")
+            state_dict = torch.load(file.name, map_location="cpu", weights_only=True)
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
        model2 = copy.deepcopy(model1).bfloat16()
 
         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
-        optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
+        optim2 = low_bit_optim._AdamW(
+            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+        )
 
         # overfit on this sample
         x = torch.randn(4, 32, device=device)
@@ -299,15 +345,19 @@ def test_optim_bf16_stochastic_round_correctness(self):
             optim2.step()
             optim2.zero_grad()
 
-            torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
+            torch.testing.assert_close(
+                loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
+            )
 
 
 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
         return 2
 
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
+    )
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
         optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
         base_loss.backward()
         for param in base_model.parameters():
             if param.grad is not None:
-                torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
+                torch.distributed.all_reduce(
+                    param.grad, op=torch.distributed.ReduceOp.AVG
+                )
         base_optim.step()
         self.assertEqual(fsdp_loss, base_loss)
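
Taken together, the test updates above exercise one round trip: quantized optimizer state must survive torch.save followed by torch.load(..., weights_only=True). Roughly, as a standalone sketch built from the same APIs the tests use:

# Sketch of the round trip the updated tests exercise.
import copy
import tempfile

import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(32, 128).to(device)
optim = low_bit_optim.AdamW8bit(model.parameters())

model(torch.randn(4, 32, device=device)).sum().backward()
optim.step()  # materializes quantized (OptimState8bit) state for large enough params

with tempfile.NamedTemporaryFile() as f:
    torch.save(optim.state_dict(), f.name)
    # map_location="cpu" also covers the "CUDA optim loads CPU state dict" case
    state_dict = torch.load(f.name, map_location="cpu", weights_only=True)

model2 = copy.deepcopy(model)
optim2 = low_bit_optim.AdamW8bit(model2.parameters())
optim2.load_state_dict(state_dict)  # resume training from the checkpoint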

torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 36 additions & 11 deletions
@@ -3,9 +3,18 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import (
+    TorchAOBaseTensor,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
 
-from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap
+from .quant_utils import (
+    create_dynamic_map,
+    scale_tensor,
+    quantize_4bit_with_qmap,
+    dequant_with_qmap,
+)
 
 
 aten = torch.ops.aten
@@ -55,8 +64,12 @@ def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed, self._shape]
 
     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
-        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+        )
 
     def dequantize(self, output_dtype=None):
         codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1)  # unpack
@@ -85,6 +98,7 @@ def __repr__(self):
     # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
     # dtype is the same but device is different. thus, we must override .to() method instead.
     if not TORCH_VERSION_AT_LEAST_2_4:
+
         def _to(self, *args, **kwargs):
             # ignore other args/kwargs
             device = kwargs.pop("device", None)
@@ -158,16 +172,20 @@ def _(func, types, args, kwargs):
     if len(shape) == 1 and shape[0] == -1:
         return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
 
-    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+    raise ValueError(
+        f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
+    )
 
 
 # this is needed for DTensor.full_tensor()
-@OptimState4bit.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimState4bit.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState4bit):
@@ -181,3 +199,10 @@ def _(func, types, args, kwargs):
 
     # assume tensors from all ranks have the same signedness
     return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # needed to load OptimState4bit with weights_only=True
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimState4bit])
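
The registration runs once at import time and is guarded to PyTorch >= 2.5, where the torch.serialization safelist helpers are available. Callers who prefer not to grow the process-wide allow-list can use the scoped context-manager variant instead; a sketch (the checkpoint file name is hypothetical):

# Scoped alternative to add_safe_globals: allow the subclass for one load only.
import torch
from torch.serialization import safe_globals
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit

with safe_globals([OptimState4bit]):
    state_dict = torch.load("optim.pt", map_location="cpu", weights_only=True)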

torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 33 additions & 10 deletions
@@ -1,9 +1,18 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import (
+    TorchAOBaseTensor,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
 
-from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap
+from .quant_utils import (
+    create_dynamic_map,
+    scale_tensor,
+    quantize_8bit_with_qmap,
+    dequant_with_qmap,
+)
 
 
 aten = torch.ops.aten
@@ -46,8 +55,12 @@ def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed]
 
     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
-        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+        )
 
     def dequantize(self, output_dtype=None):
         float_data = dequant_with_qmap(self.codes, self.qmap, self.scale)
@@ -72,6 +85,7 @@ def __repr__(self):
     # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
     # dtype is the same but device is different. thus, we must override .to() method instead.
     if not TORCH_VERSION_AT_LEAST_2_4:
+
         def _to(self, *args, **kwargs):
             # ignore other args/kwargs
             device = kwargs.pop("device", None)
@@ -136,12 +150,14 @@ def _(func, types, args, kwargs):
 
 
 # this is needed for DTensor.full_tensor()
-@OptimState8bit.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimState8bit.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState8bit):
@@ -154,3 +170,10 @@ def _(func, types, args, kwargs):
         x.qmap.clone(),
         x.signed,
     )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # needed to load OptimState8bit with weights_only=True
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimState8bit])
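
subclass_8bit.py mirrors the 4-bit change. Once both modules have been imported, the two classes should appear on the serialization allow-list; a quick sanity check, assuming torch.serialization.get_safe_globals is available in your PyTorch build (2.5+):

# Importing the subclass modules triggers the add_safe_globals calls above.
from torch.serialization import get_safe_globals
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit

assert OptimState4bit in get_safe_globals()
assert OptimState8bit in get_safe_globals()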
