import torch
import unittest
-from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal import common_utils
-from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
+from torchao.quantization import (
+    int4_weight_only,
+    int8_weight_only,
+    float8_weight_only,
+    float8_dynamic_activation_float8_weight,
+)
from torchao.quantization.observer import PerRow, PerTensor
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
from torchao.dtypes import AffineQuantizedTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+class TestAffineQuantizedTensorParallel(DTensorTestBase):
+    """Basic test case for tensor subclasses
+    """
    QUANT_METHOD_FN = staticmethod(int8_weight_only)
-copy_tests(TorchAOTensorParallelTestCase, TestInt8woAffineQuantizedTensorParallel, "int8wo_tp")
+    QUANT_METHOD_KWARGS = {}
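+    # Subclasses override QUANT_METHOD_FN and QUANT_METHOD_KWARGS to select the quantization scheme under test.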

-# Run only on H100
-if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
-        QUANT_METHOD_FN = staticmethod(float8_weight_only)
-    copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")
+    @staticmethod
+    def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+        """
+        Shard the linear layer of the model in a column-wise fashion
+        """
+        # Column-wise is w.r.t. A^T, so for A it is row-wise.
+        # Number of rows per rank
+        orig_weight = m.linear.weight
+        n_local_rows = orig_weight.size(0) // mesh.size()
+        rank = mesh.get_local_rank()
+        local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+        # Construct DTensor from local shard
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+        # Replace parameter in module
+        m.linear.weight = torch.nn.Parameter(
+            dtensor, requires_grad=False
+        )
+        return m
+
+    @staticmethod
+    def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+        """
+        Shard the linear layer of the model in a row-wise fashion
+        """
+        # Row-wise is w.r.t. A^T, so for A it is column-wise.
+        # Number of columns per rank
+        orig_weight = m.linear.weight
+        n_local_cols = orig_weight.size(1) // mesh.size()
+        rank = mesh.get_local_rank()
+        local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
+        # Construct DTensor from local shard
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
+        # Replace parameter in module
+        m.linear.weight = torch.nn.Parameter(
+            dtensor, requires_grad=False
+        )
+        return m
+
+    def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
+        """
+        Quantize the model
+        """
+        quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
+        return m
+
+    def _test_tp(self, dtype):
+        device = "cuda"
+        # To make sure different ranks create the same module
+        torch.manual_seed(5)
+
+        class M(torch.nn.Module):
+            def __init__(self, in_features, out_features, **kwargs) -> None:
+                super().__init__(**kwargs)
+                self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.linear(x)
+
+        # Get rank and device
+        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
+
+        # Original model
+        proj_up = M(1024, 2048).to(device).to(dtype)
+        proj_dn = M(2048, 1024).to(device).to(dtype)
+        example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+        y = proj_dn(proj_up(example_input))
+        # Quantize the model
+        up_quant = self.quantize(proj_up)
+        dn_quant = self.quantize(proj_dn)
+        y_q = dn_quant(up_quant(example_input))
+
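+        # Build the device mesh shared by all ranks and pin its device type to CUDA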
+        mesh = self.build_device_mesh()
+        mesh.device_type = "cuda"
+
+        # Shard the models
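+        # proj_up is sharded column-wise and proj_dn row-wise, the usual pairing for a two-layer MLP under tensor parallelism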
+        up_dist = self.colwise_shard(up_quant, mesh)
+        dn_dist = self.rowwise_shard(dn_quant, mesh)
+
+        # We need to turn inputs into DTensor form as well -- just a format change
+        input_dtensor = DTensor.from_local(
+            example_input, mesh, [Replicate()]
+        )
+
+        y_d = dn_dist(up_dist(input_dtensor))
+
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            # Need torch 2.5 to support compiled tensor parallelism
+            return
+
+        up_compiled = torch.compile(up_dist)
+        y_up = up_compiled(input_dtensor)
+        dn_compiled = torch.compile(dn_dist)
+        y_dn = dn_compiled(y_up)
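+        # The test exercises the eager and compiled tensor-parallel paths end to end; outputs are not compared numerically here.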
+
+
+class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+    COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+
+class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    QUANT_METHOD_FN = staticmethod(int4_weight_only)
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        return self._test_tp(dtype)
+
+common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
+common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase):
-        """Basic test case for tensor subclasses
-        """
+    class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+        QUANT_METHOD_FN = staticmethod(float8_weight_only)
        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
-        TENSOR_SUBCLASS = AffineQuantizedTensor
-        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
-        QUANT_METHOD_KWARGS = {}
-
-        @staticmethod
-        def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
-            """
-            Shard linear layer of the model in column-wise fashion
-            """
-            # Column-wise is wrt to A^T, so for A it is row-wise.
-            # Number of rows per rank
-            orig_weight = m.linear.weight
-            n_local_rows = orig_weight.size(0) // mesh.size()
-            rank = mesh.get_local_rank()
-            local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
-            # Construct DTensor from local shard
-            dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
-            # Replace parameter in module
-            m.linear.weight = torch.nn.Parameter(
-                dtensor, requires_grad=False
-            )
-            return m
-
-        @staticmethod
-        def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
-            """
-            Shard linear layer of the model in row-wise fashion
-            """
-            # Row-wise is wrt to A^T, so for A it is column-wise.
-            # Number of rows per rank
-            orig_weight = m.linear.weight
-            n_local_cols = orig_weight.size(1) // mesh.size()
-            rank = mesh.get_local_rank()
-            local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
-            # Construct DTensor from local shard
-            dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
-            # Replace parameter in module
-            m.linear.weight = torch.nn.Parameter(
-                dtensor, requires_grad=False
-            )
-            return m
-
-        def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
-            """
-            Quantize the model
-            """
-            quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
-            return m
-
-        def _test_tp(self, dtype):
-            device = "cuda"
-            # To make sure different ranks create the same module
-            torch.manual_seed(5)
-
-            class M(torch.nn.Module):
-                def __init__(self, in_features, out_features, **kwargs) -> None:
-                    super().__init__(**kwargs)
-                    self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
-
-                def forward(self, x: torch.Tensor) -> torch.Tensor:
-                    return self.linear(x)
-
-            # Get rank and device
-            device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
-
-            # Original model
-            proj_up = M(1024, 2048).to(device).to(dtype)
-            proj_dn = M(2048, 1024).to(device).to(dtype)
-            example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
-            y = proj_dn(proj_up(example_input))
-            # Quantize the model
-            up_quant = self.quantize(proj_up)
-            dn_quant = self.quantize(proj_dn)
-            y_q = dn_quant(up_quant(example_input))
-
-            mesh = self.build_device_mesh()
-            mesh.device_type = "cuda"
-
-            # Shard the models
-            up_dist = self.colwise_shard(up_quant, mesh)
-            dn_dist = self.rowwise_shard(dn_quant, mesh)
-
-            # We need to turn inputs into DTensor form as well -- just a format change
-            input_dtensor = DTensor.from_local(
-                example_input, mesh, [Replicate()]
-            )
-
-            y_d = dn_dist(up_dist(input_dtensor))
-
-            if not TORCH_VERSION_AT_LEAST_2_5:
-                # Need torch 2.5 to support compiled tensor parallelism
-                return
-
-            up_compiled = torch.compile(up_dist)
-            y_up = up_compiled(input_dtensor)
-            dn_compiled = torch.compile(dn_dist)
-            y_dn = dn_compiled(y_up)
-
-    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
+
+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
+    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
        QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
@@ -141,7 +168,7 @@ class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantize
        def test_tp(self, dtype):
            return self._test_tp(dtype)

-    class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
+    class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
        QUANT_METHOD_KWARGS = {"granularity": PerRow()}
        COMMON_DTYPES = [torch.bfloat16]
@@ -151,7 +178,7 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTe
        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
        def test_tp(self, dtype):
            return self._test_tp(dtype)
-
+
    common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
    common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
if __name__ == "__main__":