diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 158d003a36..2265be31ef 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -135,5 +135,6 @@ def test_print_quantized_module(self, apply_quant): common_utils.instantiate_parametrized_tests(TestAffineQuantized) + if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py new file mode 100644 index 0000000000..8c1301226f --- /dev/null +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -0,0 +1,12 @@ +from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase +from torch.testing._internal.common_utils import run_tests +from torchao.quantization import int8_weight_only + +class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase): + pass + + +copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp") + +if __name__ == "__main__": + run_tests() diff --git a/torchao/__init__.py b/torchao/__init__.py index dce378411c..0d1230ba93 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -33,11 +33,13 @@ quantize_, ) from . import dtypes +from . 
import testing __all__ = [ "dtypes", "autoquant", "quantize_", + "testing", ] # test-pytorchbot diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e00576263f..43899b4802 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -38,7 +38,8 @@ find_multiple, TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, - _is_float8_type + _is_float8_type, + fill_defaults, ) import logging @@ -599,13 +600,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) - if func is aten.t.default: + elif func is aten.t.default: tensor = args[0] new = tensor.__class__( - tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type + tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor.layout_type ) return return_and_correct_aliasing(func, args, kwargs, new) + elif func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + ) + elif dim == 1: + assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" + return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) + else: + raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" ) @@ -1595,6 +1608,39 @@ def _(func, types, args, kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new) +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, 
None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + shape = list(self.shape) + shape[dim] = end - start + block_size = self.block_size + assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" + # with slice, some shape dimension might be smaller than block_size dimension, so + # we need to make sure there is no overflow + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return return_and_correct_aliasing(func, args, kwargs, new) + +# this is needed for DTensor.from_local() and for flattening tensor +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, shape = args + + if tuple(self.shape) == tuple(shape): + return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + + if len(shape) == 1 and shape[0] == -1: + assert len(self.block_size) == 2 and self.block_size[0] == 1 + block_size = (self.block_size[1],) + return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + + raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") + + to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 48a171a75f..7fa4ba4a63 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -3,11 +3,14 @@ 
import copy import torch import torchao +import os from torch.testing._internal import common_utils from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.quant_primitives import MappingType +from torchao.quantization import quantize_, int8_weight_only +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 """ How to use: @@ -213,10 +216,122 @@ def test_linear_compile(self, device, dtype): lp_res = torch.compile(l)(hp_act_tensor) self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + NUM_DEVICES, +) + +class TorchAOTensorParallelTestCase(DTensorTestBase): + """Basic test case for tensor subclasses + """ + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + TENSOR_SUBCLASS = AffineQuantizedTensor + QUANT_METHOD_FN = staticmethod(int8_weight_only) + QUANT_METHOD_KWARGS = {} + @staticmethod + def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in column-wise fashion + """ + # Column-wise is wrt to A^T, so for A it is row-wise. 
+ # Number of rows per rank + orig_weight = m.linear.weight + n_local_rows = orig_weight.size(0) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + @staticmethod + def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: + """ + Shard linear layer of the model in row-wise fashion + """ + # Row-wise is wrt to A^T, so for A it is column-wise. + # Number of columns per rank + orig_weight = m.linear.weight + n_local_cols = orig_weight.size(1) // mesh.size() + rank = mesh.get_local_rank() + local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols] + # Construct DTensor from local shard + dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) + # Replace parameter in module + m.linear.weight = torch.nn.Parameter( + dtensor, requires_grad=False + ) + return m + + def quantize(self, m: torch.nn.Module) -> torch.nn.Module: + """ + Quantize the model + """ + quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS)) + return m + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + device = "cuda" + # To make sure different ranks create the same module + torch.manual_seed(5) + + class M(torch.nn.Module): + def __init__(self, in_features, out_features, **kwargs) -> None: + super().__init__(**kwargs) + self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Get rank and device + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") + + # Original model + proj_up = M(1024, 2048).to(device).to(dtype) 
+ proj_dn = M(2048, 1024).to(device).to(dtype) + example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) + y = proj_dn(proj_up(example_input)) + + # Quantize the model + up_quant = self.quantize(proj_up) + dn_quant = self.quantize(proj_dn) + y_q = dn_quant(up_quant(example_input)) + + mesh = self.build_device_mesh() + # Shard the models + up_dist = self.colwise_shard(up_quant, mesh) + dn_dist = self.rowwise_shard(dn_quant, mesh) + + # We need to turn inputs into DTensor form as well -- just a format change + input_dtensor = DTensor.from_local( + example_input, mesh, [Replicate()] + ) + + y_d = dn_dist(up_dist(input_dtensor)) + + if not TORCH_VERSION_AT_LEAST_2_5: + # Need torch 2.5 to support compiled tensor parallelism + return + + up_compiled = torch.compile(up_dist) + y_up = up_compiled(input_dtensor) + dn_compiled = torch.compile(dn_dist) + y_dn = dn_compiled(y_up) common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) +common_utils.instantiate_parametrized_tests(TorchAOTensorParallelTestCase) if __name__ == "__main__": unittest.main() diff --git a/torchao/utils.py b/torchao/utils.py index 1f4f66e1f4..f1248f67a6 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -493,6 +493,30 @@ def _get_to_kwargs(self, *args, **kwargs): } return kwargs +def fill_defaults(args, n, defaults_tail): + """ + __torch_dispatch__ doesn't guarantee the number of arguments you are + passed (e.g., defaulted arguments are not passed); but usually it is + convenient to pad out the arguments list with defaults. This function + helps you do that. 
+ Args: + args: the list of positional arguments passed to __torch_dispatch__ + n: the number of arguments you are expecting to get + defaults_tail: default values for the arguments, starting from the + end of the list + Example: + >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) + [1, 2, 3, 4, 5] + >>> fill_defaults([1, 2, 3], 5, [None, None, None]) + [1, 2, 3, None, None] + """ + if n - len(defaults_tail) > len(args): + raise RuntimeError("not enough defaults to fill arguments") + r = list(args) + for i in range(len(args), n): + r.append(defaults_tail[i - n + len(defaults_tail)]) + return r + ## Deprecated, will be deleted in the future def _torch_version_at_least(min_version): diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index 03b0d31590..bc85d26f5d 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -25,36 +25,13 @@ LayoutType, PlainLayoutType, ) -from torchao.utils import TorchAOBaseTensor +from torchao.utils import ( + TorchAOBaseTensor, + fill_defaults, +) aten = torch.ops.aten -# TODO: move to torchao/utils.py -def fill_defaults(args, n, defaults_tail): - """ - __torch_dispatch__ doesn't guarantee the number of arguments you are - passed (e.g., defaulted arguments are not passed); but usually it is - convenient to pad out the arguments list with defaults. This function - helps you do that. 
- Args: - args: the list of positional arguments passed to __torch_dispatch__ - n: the number of arguments you are expecting to get - defaults_tail: default values for the arguments, starting from the - end of the list - Example: - >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) - [1, 2, 3, 4, 5] - >>> fill_defaults([1, 2, 3], 5, [None, None, None]) - [1, 2, 3, None, None]] - """ - if n - len(defaults_tail) > len(args): - raise RuntimeError("not enough defaults to fill arguments") - r = list(args) - for i in range(len(args), n): - r.append(defaults_tail[i - n + len(defaults_tail)]) - return r - - ############################### # Base Layout Tensor Subclass # ############################### @@ -327,7 +304,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) ) elif dim == 1: - return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type) + return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) else: raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") elif func is aten.t.default: diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py index db610a71fa..0ed3bc9a29 100644 --- a/tutorials/developer_api_guide/tensor_parallel.py +++ b/tutorials/developer_api_guide/tensor_parallel.py @@ -5,7 +5,8 @@ from torch.distributed import DeviceMesh from torch.distributed.tensor import DTensor, Replicate, Shard, Placement from torch.utils._python_dispatch import return_and_correct_aliasing -from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults +from my_dtype_tensor_subclass import MyDTypeTensor +from torchao.utils import fill_defaults # a tensor subclass that supports tensor 
parallelism with DTensor class MyDTypeTensorTP(MyDTypeTensor):