 import torch
+import unittest
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal import common_utils
 from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
 from torchao.quantization.observer import PerRow, PerTensor
+import torch.distributed as dist
+from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+    NUM_DEVICES,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.dtypes import AffineQuantizedTensor
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
     QUANT_METHOD_FN = staticmethod(int8_weight_only)
@@ -16,17 +28,131 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
 
 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8dqTensorAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    class TestFloat8dqAffineQuantizedTensorParallel(DTensorTestBase):
+        """Basic test case for tensor subclasses
+        """
+        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
+        TENSOR_SUBCLASS = AffineQuantizedTensor
+        QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
+        QUANT_METHOD_KWARGS = {}
+
+        @staticmethod
+        def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+            """
+            Shard linear layer of the model in column-wise fashion
+            """
+            # Column-wise is w.r.t. A^T, so for A it is row-wise.
+            # Number of rows per rank
+            orig_weight = m.linear.weight
+            n_local_rows = orig_weight.size(0) // mesh.size()
+            rank = mesh.get_local_rank()
+            local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+            # Construct DTensor from local shard
+            dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+            # Replace parameter in module
+            m.linear.weight = torch.nn.Parameter(
+                dtensor, requires_grad=False
+            )
+            return m
+
+        @staticmethod
+        def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+            """
+            Shard linear layer of the model in row-wise fashion
+            """
+            # Row-wise is w.r.t. A^T, so for A it is column-wise.
+            # Number of columns per rank
+            orig_weight = m.linear.weight
+            n_local_cols = orig_weight.size(1) // mesh.size()
+            rank = mesh.get_local_rank()
+            local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
+            # Construct DTensor from local shard
+            dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
+            # Replace parameter in module
+            m.linear.weight = torch.nn.Parameter(
+                dtensor, requires_grad=False
+            )
+            return m
+
+        def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
+            """
+            Quantize the model
+            """
+            quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
+            return m
+
+        def _test_tp(self, dtype):
+            device = "cuda"
+            # To make sure different ranks create the same module
+            torch.manual_seed(5)
+
+            class M(torch.nn.Module):
+                def __init__(self, in_features, out_features, **kwargs) -> None:
+                    super().__init__(**kwargs)
+                    self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return self.linear(x)
+
+            # Get rank and device
+            device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
+
+            # Original model
+            proj_up = M(1024, 2048).to(device).to(dtype)
+            proj_dn = M(2048, 1024).to(device).to(dtype)
+            example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+            y = proj_dn(proj_up(example_input))
+            # Quantize the model
+            up_quant = self.quantize(proj_up)
+            dn_quant = self.quantize(proj_dn)
+            y_q = dn_quant(up_quant(example_input))
+
+            mesh = self.build_device_mesh()
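+            # build_device_mesh() comes from DTensorTestBase; the device type is pinned
+            # to CUDA below because this float8 path is GPU-only (gated on H100 above).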
+            mesh.device_type = "cuda"
+
+            # Shard the models
+            up_dist = self.colwise_shard(up_quant, mesh)
+            dn_dist = self.rowwise_shard(dn_quant, mesh)
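+            # Column-parallel up-projection followed by row-parallel down-projection is the
+            # standard TP MLP layout, so no all-gather should be needed between the two.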
+
+            # We need to turn inputs into DTensor form as well -- just a format change
+            input_dtensor = DTensor.from_local(
+                example_input, mesh, [Replicate()]
+            )
+
+            y_d = dn_dist(up_dist(input_dtensor))
+
+            if not TORCH_VERSION_AT_LEAST_2_5:
+                # Need torch 2.5 to support compiled tensor parallelism
+                return
+
+            up_compiled = torch.compile(up_dist)
+            y_up = up_compiled(input_dtensor)
+            dn_compiled = torch.compile(dn_dist)
+            y_dn = dn_compiled(y_up)
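+            # No numerical assertions here; this exercises the eager and compiled TP paths
+            # end-to-end as a smoke test.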
+
+    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
         QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
         QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
-    copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqTensorAffineQuantizedTensorParallel, "fp8dqt_tp")
+        COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
 
-# Run only on H100
-if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
-    class TestFloat8dqRowAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
+    class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
         QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
         QUANT_METHOD_KWARGS = {"granularity": PerRow()}
-    copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqRowAffineQuantizedTensorParallel, "fp8dqr_tp")
+        COMMON_DTYPES = [torch.bfloat16]
 
+        @common_utils.parametrize("dtype", COMMON_DTYPES)
+        @with_comms
+        @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+        def test_tp(self, dtype):
+            return self._test_tp(dtype)
+
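+    # instantiate_parametrized_tests expands the parametrized test_tp into one concrete
+    # test per dtype in each subclass's COMMON_DTYPES.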
+    common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
+    common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
 if __name__ == "__main__":
     run_tests()