From 1494636ca31670e1bc7e25e9c0c86fca7a97b0b2 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 24 Sep 2024 15:45:01 -0700 Subject: [PATCH 01/14] [WIP] Supporting tensor parallelism for int8 weight only quant Summary: following https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/tensor_parallel.py we can support tensor parallelism for int8 weight only quant, this is needed for torchchat Test Plan: python test/dtypes/test_affine_quantized_tensor_parallel.py Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_affine_quantized.py | 1 + .../test_affine_quantized_tensor_parallel.py | 12 ++ torchao/__init__.py | 2 + torchao/testing/utils.py | 128 ++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 test/dtypes/test_affine_quantized_tensor_parallel.py diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 158d003a36..2265be31ef 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -135,5 +135,6 @@ def test_print_quantized_module(self, apply_quant): common_utils.instantiate_parametrized_tests(TestAffineQuantized) + if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py new file mode 100644 index 0000000000..9992c2c9fb --- /dev/null +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -0,0 +1,12 @@ +from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase +from torch.testing._internal.common_utils import run_tests +from torchao.quantization import int8_weight_only + +class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): + QUANT_METHOD_FN = int8_weight_only + + +copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp") + +if __name__ == "__main__": + run_tests() diff --git a/torchao/__init__.py b/torchao/__init__.py index dce378411c..0d1230ba93 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -33,11 +33,13 @@ quantize_, ) from . import dtypes +from . import testing __all__ = [ "dtypes", "autoquant", "quantize_", + "testing", ] # test-pytorchbot diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 48a171a75f..1f3edf0587 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -3,11 +3,13 @@ import copy import torch import torchao +import os from torch.testing._internal import common_utils from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.quant_primitives import MappingType +from torchao.quantization import quantize_, int4_weight_only """ How to use: @@ -213,10 +215,136 @@ def test_linear_compile(self, device, dtype): lp_res = torch.compile(l)(hp_act_tensor) self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) +import torch.distributed as dist +from torch.distributed import DeviceMesh +from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + NUM_DEVICES, +) + +class TorchAOTensorParallelTestCase(DTensorTestBase): + """Basic test case for tensor subclasses + """ + COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + TENSOR_SUBCLASS = AffineQuantizedTensor + QUANT_METHOD_FN = int4_weight_only + QUANT_METHOD_KWARGS = {} + + # def setUp(self) -> None: + # # Create a device mesh + # world_size = int(os.environ["WORLD_SIZE"]) + # if not dist.is_initialized(): + # dist.init_process_group(backend="nccl") + # self.mesh = dist.init_device_mesh("cuda", (world_size,)) + + # def tearDown(self) -> None: + # dist.destroy_process_group() + @staticmethod + def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in column-wise fashion + """ + # Column-wise is wrt to A^T, so for A it is row-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_rows = orig_weight.size(0) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + @staticmethod + def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in row-wise fashion + """ + # Row-wise is wrt to A^T, so for A it is column-wise. + # Number of rows per rank + orig_weight = m.linear.weight + n_local_cols = orig_weight.size(1) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + def quantize(self, m: torch.nn.Module) -> torch.nn.Module: + """ + Quantize the model + """ + quantize_(m, self.QUANT_METHOD_FN(**(self.QUANT_METHOD_KWARGS))) + return m + + # @common_utils.parametrize("device", COMMON_DEVICES) + # @common_utils.parametrize("dtype", COMMON_DTYPES) + def test_tp(self): + device = "cuda" + dtype = torch.bfloat16 + # To make sure different ranks create the same module + torch.manual_seed(5) + + class M(torch.nn.Module): + def __init__(self, in_features, out_features, **kwargs) -> None: + super().__init__(**kwargs) + self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Get rank and device + rank = int(os.environ["RANK"]) + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + + # Original model + proj_up = M(1024, 2048).to(device) + proj_dn = M(2048, 1024).to(device) + example_input = 100 * torch.randn(128, 1024, device=device) + y = proj_dn(proj_up(example_input)) + + # Quantize the model + up_quant = self.quantize(proj_up) + dn_quant = self.quantize(proj_dn) + y_q = dn_quant(up_quant(example_input)) + print("Quantization works!") + + mesh = self.build_device_mesh() + # Shard the models + up_dist = colwise_shard(up_quant, mesh) + dn_dist = rowwise_shard(dn_quant, mesh) + + # We need to turn inputs into DTensor form as well -- just a format change + input_dtensor = DTensor.from_local( + example_input, mesh, [Replicate()] + ) + + y_d = dn_dist(up_dist(input_dtensor)) + print("Distributed result:", y_d) + print("Distributed works!") + + up_compiled = torch.compile(up_dist) + y_up = up_compiled(input_dtensor) + dn_compiled = torch.compile(dn_dist) + y_dn = dn_compiled(y_up) + print("compiled result:", y_dn) + print("torch.compile works!") common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) +common_utils.instantiate_parametrized_tests(TorchAOTensorParallelTestCase) if __name__ == "__main__": unittest.main() From cc938f7fde55946f30bf13309d4a8635aa8b6750 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 24 Sep 2024 19:14:26 -0700 Subject: [PATCH 02/14] implement tp for aqt --- .../test_affine_quantized_tensor_parallel.py | 2 +- torchao/dtypes/affine_quantized_tensor.py | 50 ++++++++++++++++++- torchao/testing/utils.py | 26 ++++------ torchao/utils.py | 24 +++++++++ .../my_dtype_tensor_subclass.py | 33 ++---------- .../developer_api_guide/tensor_parallel.py | 3 +- 6 files changed, 91 insertions(+), 47 deletions(-) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 9992c2c9fb..0ca8793521 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -3,7 +3,7 @@ from torchao.quantization import int8_weight_only class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): - QUANT_METHOD_FN = int8_weight_only + QUANT_METHOD_FN = staticmethod(int8_weight_only) copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp") diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e00576263f..952777c966 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -38,7 +38,8 @@ find_multiple, TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, - _is_float8_type + _is_float8_type, + fill_defaults, ) import logging @@ -599,13 +600,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) - if func is aten.t.default: + elif func is aten.t.default: tensor = args[0] new = tensor.__class__( tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type ) return return_and_correct_aliasing(func, args, kwargs, new) + elif func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + ) + elif dim == 1: + assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" + return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) + else: + raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" ) @@ -1595,6 +1608,39 @@ def _(func, types, args, kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new) +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + shape = list(self.shape) + shape[dim] = end - start + block_size = self.block_size + assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" + # with slice, some shape dimension might be smaller than block_size dimension, so + # we need to make sure there is no overflow + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return return_and_correct_aliasing(func, args, kwargs, new) + +# this is needed for DTensor.from_local() and for flattening tensor +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, shape = args + + if tuple(self.shape) == tuple(shape): + return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + + if len(shape) == 1 and shape[0] == -1: + assert len(self.block_size) == 2 and self.block_size[0] == 1 + block_size = (block_size[1],) + return self.__class__(self.layout_tensor, block_size, (x.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + + raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") + + to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 1f3edf0587..ce682bc254 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -9,7 +9,7 @@ from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.quant_primitives import MappingType -from torchao.quantization import quantize_, int4_weight_only +from torchao.quantization import quantize_, int8_weight_only """ How to use: @@ -231,7 +231,8 @@ class TorchAOTensorParallelTestCase(DTensorTestBase): COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor - QUANT_METHOD_FN = int4_weight_only + # QUANT_METHOD_FN = staticmethod(int4_weight_only) + QUANT_METHOD_FN = staticmethod(int8_weight_only) QUANT_METHOD_KWARGS = {} # def setUp(self) -> None: @@ -286,11 +287,12 @@ def quantize(self, m: torch.nn.Module) -> torch.nn.Module: """ Quantize the model """ - quantize_(m, self.QUANT_METHOD_FN(**(self.QUANT_METHOD_KWARGS))) + quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) return m # @common_utils.parametrize("device", COMMON_DEVICES) # @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms def test_tp(self): device = "cuda" dtype = torch.bfloat16 @@ -306,25 +308,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) # Get rank and device - rank = int(os.environ["RANK"]) - device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") # Original model - proj_up = M(1024, 2048).to(device) - proj_dn = M(2048, 1024).to(device) - example_input = 100 * torch.randn(128, 1024, device=device) + proj_up = M(1024, 2048).to(device).to(dtype) + proj_dn = M(2048, 1024).to(device).to(dtype) + example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) y = proj_dn(proj_up(example_input)) # Quantize the model up_quant = self.quantize(proj_up) dn_quant = self.quantize(proj_dn) y_q = dn_quant(up_quant(example_input)) - print("Quantization works!") mesh = self.build_device_mesh() # Shard the models - up_dist = colwise_shard(up_quant, mesh) - dn_dist = rowwise_shard(dn_quant, mesh) + up_dist = self.colwise_shard(up_quant, mesh) + dn_dist = self.rowwise_shard(dn_quant, mesh) # We need to turn inputs into DTensor form as well -- just a format change input_dtensor = DTensor.from_local( @@ -332,15 +332,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) y_d = dn_dist(up_dist(input_dtensor)) - print("Distributed result:", y_d) - print("Distributed works!") up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) y_dn = dn_compiled(y_up) - print("compiled result:", y_dn) - print("torch.compile works!") common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) diff --git a/torchao/utils.py b/torchao/utils.py index 1f4f66e1f4..f1248f67a6 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -493,6 +493,30 @@ def _get_to_kwargs(self, *args, **kwargs): } return kwargs +def fill_defaults(args, n, defaults_tail): + """ + __torch_dispatch__ doesn't guarantee the number of arguments you are + passed (e.g., defaulted arguments are not passed); but usually it is + convenient to pad out the arguments list with defaults. This function + helps you do that. + Args: + args: the list of positional arguments passed to __torch_dispatch__ + n: the number of arguments you are expecting to get + defaults_tail: default values for the arguments, starting from the + end of the list + Example: + >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) + [1, 2, 3, 4, 5] + >>> fill_defaults([1, 2, 3], 5, [None, None, None]) + [1, 2, 3, None, None]] + """ + if n - len(defaults_tail) > len(args): + raise RuntimeError("not enough defaults to fill arguments") + r = list(args) + for i in range(len(args), n): + r.append(defaults_tail[i - n + len(defaults_tail)]) + return r + ## Deprecated, will be deleted in the future def _torch_version_at_least(min_version): diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index 03b0d31590..bc85d26f5d 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -25,36 +25,13 @@ LayoutType, PlainLayoutType, ) -from torchao.utils import TorchAOBaseTensor +from torchao.utils import ( + TorchAOBaseTensor, + fill_defaults, +) aten = torch.ops.aten -# TODO: move to torchao/utils.py -def fill_defaults(args, n, defaults_tail): - """ - __torch_dispatch__ doesn't guarantee the number of arguments you are - passed (e.g., defaulted arguments are not passed); but usually it is - convenient to pad out the arguments list with defaults. This function - helps you do that. - Args: - args: the list of positional arguments passed to __torch_dispatch__ - n: the number of arguments you are expecting to get - defaults_tail: default values for the arguments, starting from the - end of the list - Example: - >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) - [1, 2, 3, 4, 5] - >>> fill_defaults([1, 2, 3], 5, [None, None, None]) - [1, 2, 3, None, None]] - """ - if n - len(defaults_tail) > len(args): - raise RuntimeError("not enough defaults to fill arguments") - r = list(args) - for i in range(len(args), n): - r.append(defaults_tail[i - n + len(defaults_tail)]) - return r - - ############################### # Base Layout Tensor Subclass # ############################### @@ -327,7 +304,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) ) elif dim == 1: - return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type) + return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) else: raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") elif func is aten.t.default: diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py index db610a71fa..0ed3bc9a29 100644 --- a/tutorials/developer_api_guide/tensor_parallel.py +++ b/tutorials/developer_api_guide/tensor_parallel.py @@ -5,7 +5,8 @@ from torch.distributed import DeviceMesh from torch.distributed.tensor import DTensor, Replicate, Shard, Placement from torch.utils._python_dispatch import return_and_correct_aliasing -from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults +from my_dtype_tensor_subclass import MyDTypeTensor +from torchao.utils import fill_defaults # a tensor subclass that supports tensor parallelism with DTensor class MyDTypeTensorTP(MyDTypeTensor): From 27e4238f6ad2957155506b61bddd07f0b4cdd2a3 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 24 Sep 2024 21:33:00 -0700 Subject: [PATCH 03/14] fixes --- .../test_affine_quantized_tensor_parallel.py | 2 +- torchao/dtypes/affine_quantized_tensor.py | 2 +- torchao/testing/utils.py | 19 +++---------------- 3 files changed, 5 insertions(+), 18 deletions(-) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 0ca8793521..8c1301226f 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -3,7 +3,7 @@ from torchao.quantization import int8_weight_only class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): - QUANT_METHOD_FN = staticmethod(int8_weight_only) + pass copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp") diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 952777c966..67ce606184 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1638,7 +1638,7 @@ def _(func, types, args, kwargs): block_size = (block_size[1],) return self.__class__(self.layout_tensor, block_size, (x.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) - raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") + raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index ce682bc254..611a6eaf36 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -231,20 +231,9 @@ class TorchAOTensorParallelTestCase(DTensorTestBase): COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor - # QUANT_METHOD_FN = staticmethod(int4_weight_only) QUANT_METHOD_FN = staticmethod(int8_weight_only) QUANT_METHOD_KWARGS = {} - # def setUp(self) -> None: - # # Create a device mesh - # world_size = int(os.environ["WORLD_SIZE"]) - # if not dist.is_initialized(): - # dist.init_process_group(backend="nccl") - # self.mesh = dist.init_device_mesh("cuda", (world_size,)) - - # def tearDown(self) -> None: - # dist.destroy_process_group() - @staticmethod def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: """ @@ -290,12 +279,10 @@ def quantize(self, m: torch.nn.Module) -> torch.nn.Module: quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) return m - # @common_utils.parametrize("device", COMMON_DEVICES) - # @common_utils.parametrize("dtype", COMMON_DTYPES) + @common_utils.parametrize("device", COMMON_DEVICES) + @common_utils.parametrize("dtype", COMMON_DTYPES) @with_comms - def test_tp(self): - device = "cuda" - dtype = torch.bfloat16 + def test_tp(self, device, dtype): # To make sure different ranks create the same module torch.manual_seed(5) From 1b6e42cba5428fdd79e1b44155c19b168869535e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 10:51:46 -0700 Subject: [PATCH 04/14] import fix --- torchao/testing/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 611a6eaf36..b467ac7cc7 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -216,8 +216,7 @@ def test_linear_compile(self, device, dtype): self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) import torch.distributed as dist -from torch.distributed import DeviceMesh -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, From 4553cd990200a914da806df915aa4a2dcff0ca38 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 11:05:14 -0700 Subject: [PATCH 05/14] remove cpu test --- torchao/testing/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index b467ac7cc7..34600c7b61 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -226,7 +226,7 @@ def test_linear_compile(self, device, dtype): class TorchAOTensorParallelTestCase(DTensorTestBase): """Basic test case for tensor subclasses """ - COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + COMMON_DEVICES = (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor From 771868fba402d8dde8ddbdfb0e77035ef510fdf2 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 11:08:01 -0700 Subject: [PATCH 06/14] fix --- torchao/dtypes/affine_quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 67ce606184..508f22f14c 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1636,7 +1636,7 @@ def _(func, types, args, kwargs): if len(shape) == 1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 block_size = (block_size[1],) - return self.__class__(self.layout_tensor, block_size, (x.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") From 25e19a9709eb36de27c30b8e79e2f6e5aecc5b73 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 11:10:59 -0700 Subject: [PATCH 07/14] fix --- torchao/dtypes/affine_quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 508f22f14c..85b5ab3cf3 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1635,7 +1635,7 @@ def _(func, types, args, kwargs): if len(shape) == 1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 - block_size = (block_size[1],) + block_size = (self.block_size[1],) return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") From 3db5d9a472269cfdc9c8c317dc8d55df969a65e1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 11:42:31 -0700 Subject: [PATCH 08/14] fix test --- torchao/testing/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 34600c7b61..96637d3dd5 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -223,6 +223,7 @@ def test_linear_compile(self, device, dtype): NUM_DEVICES, ) +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") class TorchAOTensorParallelTestCase(DTensorTestBase): """Basic test case for tensor subclasses """ From cba6848ca0176ab95429d5af9977e92a9764f3e2 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 12:00:33 -0700 Subject: [PATCH 09/14] device --- torchao/testing/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 96637d3dd5..e61a9fe8de 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -223,11 +223,9 @@ def test_linear_compile(self, device, dtype): NUM_DEVICES, ) -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") class TorchAOTensorParallelTestCase(DTensorTestBase): """Basic test case for tensor subclasses """ - COMMON_DEVICES = (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor @@ -279,10 +277,11 @@ def quantize(self, m: torch.nn.Module) -> torch.nn.Module: quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) return m - @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) @with_comms - def test_tp(self, device, dtype): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + device = "cuda" # To make sure different ranks create the same module torch.manual_seed(5) From 713946690797967f17a170bf60c4808e72967e19 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 25 Sep 2024 18:48:06 -0700 Subject: [PATCH 10/14] change transpose impl --- torchao/dtypes/affine_quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 85b5ab3cf3..43899b4802 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -603,7 +603,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.t.default: tensor = args[0] new = tensor.__class__( - tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type + tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor.layout_type ) return return_and_correct_aliasing(func, args, kwargs, new) From 115a5f28cd392e7b8a59656018e84737d74a1b27 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 26 Sep 2024 17:19:13 -0700 Subject: [PATCH 11/14] Skip compiled TP test for torch version < 2.5 --- torchao/testing/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index e61a9fe8de..550378c1d0 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -5,6 +5,8 @@ import torchao import os +from packaging import version + from torch.testing._internal import common_utils from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes import to_affine_quantized_intx @@ -35,6 +37,8 @@ class MyTestCase(TorchAOBasicTestCase): unittest.main() """ +torch_version = version.Version(torch.__version__) + # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 def copy_tests( my_cls, other_cls, suffix, test_failures=None, xfail_prop=None @@ -223,6 +227,8 @@ def test_linear_compile(self, device, dtype): NUM_DEVICES, ) +COMPILED_TENSOR_PARALLEL_REQUIRED_VERSION = version.Version("2.5.0dev") + class TorchAOTensorParallelTestCase(DTensorTestBase): """Basic test case for tensor subclasses """ @@ -319,6 +325,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: y_d = dn_dist(up_dist(input_dtensor)) + if torch_version < COMPILED_TENSOR_PARALLEL_REQUIRED_VERSION: + # Need torch 2.5 to support compiled tensor parallelism + return + up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) From 5370e42b590b18ce00b05d7f19581f3ebe5b6789 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 26 Sep 2024 17:24:53 -0700 Subject: [PATCH 12/14] version util --- torchao/testing/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 550378c1d0..be34e4db01 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -5,13 +5,12 @@ import torchao import os -from packaging import version - from torch.testing._internal import common_utils from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.quant_primitives import MappingType from torchao.quantization import quantize_, int8_weight_only +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 """ How to use: @@ -37,8 +36,6 @@ class MyTestCase(TorchAOBasicTestCase): unittest.main() """ -torch_version = version.Version(torch.__version__) - # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 def copy_tests( my_cls, other_cls, suffix, test_failures=None, xfail_prop=None @@ -325,7 +322,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: y_d = dn_dist(up_dist(input_dtensor)) - if torch_version < COMPILED_TENSOR_PARALLEL_REQUIRED_VERSION: + if TORCH_VERSION_AT_LEAST_2_5: # Need torch 2.5 to support compiled tensor parallelism return From ed9b82e93d047a62488effabc6232b27c760dcb1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 26 Sep 2024 17:28:52 -0700 Subject: [PATCH 13/14] fix --- torchao/testing/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index be34e4db01..c73260eefa 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -224,8 +224,6 @@ def test_linear_compile(self, device, dtype): NUM_DEVICES, ) -COMPILED_TENSOR_PARALLEL_REQUIRED_VERSION = version.Version("2.5.0dev") - class TorchAOTensorParallelTestCase(DTensorTestBase): """Basic test case for tensor subclasses """ From b113edae23ae9e7a53228ede02e699d048555ab5 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 26 Sep 2024 18:17:54 -0700 Subject: [PATCH 14/14] fix version --- torchao/testing/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index c73260eefa..7fa4ba4a63 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -320,7 +320,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: y_d = dn_dist(up_dist(input_dtensor)) - if TORCH_VERSION_AT_LEAST_2_5: + if not TORCH_VERSION_AT_LEAST_2_5: # Need torch 2.5 to support compiled tensor parallelism return