     quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 try:
     import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x_rep = x.view(-1, 1).repeat(1, 100_000)
 
         if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
+            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
+                x_rep
+            )
         else:
             x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
 
@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
-    @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
+    @parametrize(
+        "optim_name",
+        ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
+    )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
     def test_optim_smoke(self, optim_name, dtype, device):
@@ -120,7 +131,7 @@ def test_optim_smoke(self, optim_name, dtype, device):
         # test serialization. also test the case CUDA optim loads CPU state dict
         with tempfile.NamedTemporaryFile() as f:
             torch.save(optim.state_dict(), f.name)
-            state_dict = torch.load(f.name, map_location="cpu")
+            state_dict = torch.load(f.name, map_location="cpu", weights_only=True)
 
         model2 = copy.deepcopy(model)
         optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
             torch.testing.assert_close(p2, p1)
 
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="bitsandbytes 8-bit Adam only works for CUDA",
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
 
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):
 
     # this will not run in CI because we can't install lpmm
     @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        model1[0].requires_grad_(False)  # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        model1[0].requires_grad_(
+            False
+        )  # make sure it can work in the presence of non-trainable params
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
         optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
+            model2.parameters(),
+            torch.optim.AdamW,
+            offload_gradients=offload_grad,
         )
 
         for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -249,11 +289,13 @@ def test_optim_cpu_offload_save_load(self):
         # save checkpoint. make sure it can be serialized by torch.save()
         with tempfile.NamedTemporaryFile() as file:
             torch.save(optim1.state_dict(), file.name)
-            state_dict = torch.load(file.name, map_location="cpu")
+            state_dict = torch.load(file.name, map_location="cpu", weights_only=True)
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1).bfloat16()
 
         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
-        optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
+        optim2 = low_bit_optim._AdamW(
+            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+        )
 
         # overfit on this sample
         x = torch.randn(4, 32, device=device)
@@ -299,15 +345,19 @@ def test_optim_bf16_stochastic_round_correctness(self):
             optim2.step()
             optim2.zero_grad()
 
-            torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
+            torch.testing.assert_close(
+                loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
+            )
 
 
 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
         return 2
 
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
+    )
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
         optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
         base_loss.backward()
         for param in base_model.parameters():
             if param.grad is not None:
-                torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
+                torch.distributed.all_reduce(
+                    param.grad, op=torch.distributed.ReduceOp.AVG
+                )
         base_optim.step()
         self.assertEqual(fsdp_loss, base_loss)
 