 import copy
+import shutil
 import tempfile
+from pathlib import Path
 
 import pytest
 import torch
 from packaging.version import Version
 from torch import nn
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
+
 from torchao.prototype import low_bit_optim
 from torchao.prototype.low_bit_optim.quant_utils import (
-    quantize_8bit_with_qmap,
-    quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
+    quantize_4bit_with_qmap,
+    quantize_8bit_with_qmap,
 )
+from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
+from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
+from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
 )
 
@@ -88,23 +94,15 @@ def test_bf16_stochastic_round(self, device, compile):
         x = torch.rand(32, device=device) * 100
         x_rep = x.view(-1, 1).repeat(1, 100_000)
 
-        if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
-                x_rep
-            )
-        else:
-            x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
-
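+        # with disable=True, torch.compile is a no-op, so one code path serves
+        # both the eager (compile=False) and compiled (compile=True) cases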
+        func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
+        x_rep_bf16 = func(x_rep)
         assert x_rep_bf16.dtype is torch.bfloat16
 
         # must cast BF16 tensor back to FP32 so that .mean() is accurate
         torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
-    )
     @parametrize(
         "optim_name",
         ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
@@ -151,29 +149,46 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
+    # however, it's cumbersome to test it directly, since we would need to run distributed
+    # test 2 times with different world size, and persist checkpoint across the 2 runs.
+    # thus, we only test for the required op. note that future implementations of dcp.load()
+    # may use other ops.
+    @parametrize("subclass", [OptimState4bit, OptimState8bit, OptimStateFp8])
+    @parametrize("shape", [(4096,), (256, 256)])
+    @parametrize("device", _DEVICES)
+    def test_subclass_slice(self, subclass, shape, device):
+        if subclass == OptimStateFp8:
+            if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
+                pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
+            if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
+                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
+            if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        tensor = subclass.zeros(shape, device=device)
+        offset = shape[0] // 2
+
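+        # dequantizing then slicing must match slicing then dequantizing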
+        torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
+        torch.testing.assert_close(tensor.dequantize()[offset : offset * 2], tensor[offset : offset * 2].dequantize())
+
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
     @pytest.mark.skipif(
         not torch.cuda.is_available(),
         reason="bitsandbytes 8-bit Adam only works for CUDA",
     )
-    @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
-    )
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
         model2 = copy.deepcopy(model1)
 
         # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
 
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(
-            model2.parameters(), block_size=block_size
-        )
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -196,15 +211,11 @@ def test_optim_8bit_correctness(self, optim_name):
     @pytest.mark.skipif(
         not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
     )
-    @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
-    )
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
         model2 = copy.deepcopy(model1)
 
         # lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -238,12 +249,11 @@ def test_optim_4bit_correctness(self, optim_name):
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
-        model1[0].requires_grad_(
-            False
-        )  # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
+
+        # make sure it can work in the presence of non-trainable params
+        model1[0].requires_grad_(False)
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
@@ -273,12 +283,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
     )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
-        optim1 = low_bit_optim.CPUOffloadOptimizer(
-            model1.parameters(), torch.optim.AdamW
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
+        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -293,9 +300,7 @@ def test_optim_cpu_offload_save_load(self):
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW
-        )
+        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
@@ -315,16 +320,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
         model2 = copy.deepcopy(model1).bfloat16()
 
         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
         optim2 = low_bit_optim._AdamW(
-            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+            model2.parameters(),
+            lr=1e-5,
+            bf16_stochastic_round=True,
         )
 
         # overfit on this sample
@@ -350,10 +356,13 @@ def test_optim_bf16_stochastic_round_correctness(self):
         )
 
 
+_FSDP_WORLD_SIZE = 2
+
+
 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
-        return 2
+        return _FSDP_WORLD_SIZE
 
     @pytest.mark.skipif(
         not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
@@ -370,12 +379,12 @@ def test_fsdp2(self):
         )
 
     def _test_fsdp2(self, optim_cls):
+        import torch.distributed as dist
+        import torch.distributed.checkpoint as dcp
+        import torch.utils._pytree as pytree
         from torch.distributed._composable.fsdp import fully_shard
-        from torch.testing._internal.distributed._tensor.common_dtensor import (
-            ModelArgs,
-            Transformer,
-            TransformerBlock,
-        )
+        from torch.distributed.tensor import DTensor
+        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
 
         batch_size = 3
         vocab_size = 1024
@@ -413,9 +422,7 @@ def _test_fsdp2(self, optim_cls):
             base_loss.backward()
             for param in base_model.parameters():
                 if param.grad is not None:
-                    torch.distributed.all_reduce(
-                        param.grad, op=torch.distributed.ReduceOp.AVG
-                    )
+                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
             base_optim.step()
             self.assertEqual(fsdp_loss, base_loss)
 
@@ -428,6 +435,39 @@ def _test_fsdp2(self, optim_cls):
 
         self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())
 
+        # test for compatibility with dcp.save() and .load()
+        checkpoint_id = f"_fsdp_low_bit_optim_{optim_cls.__name__}"
+        if Path(checkpoint_id).exists():
+            shutil.rmtree(checkpoint_id)
+        dcp.save(fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
+
+        # normally we would want to use dcp.state_dict.get_optimizer_state_dict() to initialize optim states.
+        # however, currently it does not respect tensor-ness of LR pytorch/pytorch#139575.
+        # therefore, we have to manually initialize optim state here.
+        resumed_fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
+        for p in fsdp_model.parameters():
+            p.grad = torch.zeros_like(p)
+
+        # this will change model weights due to weight decay, but since we don't use the model anymore, it's fine.
+        resumed_fsdp_optim.step()
+
+        dcp.load(resumed_fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
+        if dist.get_rank() == 0:
+            shutil.rmtree(checkpoint_id)
+
+        subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)
+
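+        # compare the two state dicts leaf by leaf, unwrapping DTensor shards and
+        # dequantizing low-bit states before checking values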
+        for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
+            assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
+            if isinstance(v1, DTensor):
+                v1 = v1.to_local()
+                v2 = v2.to_local()
+                assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
+            if isinstance(v1, subclasses):
+                v1 = v1.dequantize()
+                v2 = v2.dequantize()
+            self.assertEqual(v1, v2)
+
 
 instantiate_parametrized_tests(TestQuantize)
 instantiate_parametrized_tests(TestOptim)