From 4367e4960daaa33b96c9d2a671e9352a16b38289 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 12:03:23 -0700 Subject: [PATCH 1/9] [float8] all-reduce amax on dp mesh instead of global pg Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_scaling_utils.py | 3 ++- torchao/float8/float8_utils.py | 12 ++++++++---- torchao/float8/fsdp_utils.py | 1 + 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index d2ae896320..fd3502562a 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + device_mesh: "DeviceMesh" = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -52,7 +53,7 @@ def hp_tensor_to_float8_dynamic( """ if tensor_already_casted_to_fp8(hp_tensor): return hp_tensor - scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) + scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax, device_mesh) return hp_tensor_and_scale_to_float8( hp_tensor, scale, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 54613e5b05..a72c7764b6 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -98,23 +98,27 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: +def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False, device_mesh: "DeviceMesh" = None) -> torch.Tensor: amax = torch.max(torch.abs(x)) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will # happen elsewhere. 
if reduce_amax and dist.is_initialized(): - dist.all_reduce(amax, op=dist.ReduceOp.MAX) + pg = device_mesh.get_group() if device_mesh is not None else None + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg) return amax @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + device_mesh: "DeviceMesh" = None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 7ec60c795b..5939f721f8 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -216,6 +216,7 @@ def fsdp_pre_all_gather(self, mesh): self._linear_mm_config, reduce_amax=True, gemm_input_role=GemmInputRole.WEIGHT, + device_mesh=mesh, ) return (float8_tensor._data,), (float8_tensor._scale,) From f1b9c1dde4fd228addac8c5b573aadafcb8cccff Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 13:42:55 -0700 Subject: [PATCH 2/9] liner Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_fsdp2/test_fsdp2.py | 33 +++++++++++++++++++++++++++- torchao/float8/float8_utils.py | 4 ++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index e2e7097f6b..8738c9a187 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -17,10 +17,12 @@ import torch.nn as nn from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed._tensor import DTensor +from torch.distributed._tensor import DTensor, init_device_mesh +from torchao.float8.float8_tensor import GemmInputRole from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -293,6 +295,35 @@ def _get_curr_active_memory_mb(self) -> int: return round(mem_stats["active_bytes.all.current"] / 1e6) +class Test2DParallelMultiThread(FSDPTestMultiThread, TestFloat8Common): + @property + def world_size(self) -> int: + return 4 + + def test_amax_allreduce_device_mesh(self): + dp_size = 2 + pp_size = self.world_size // dp_size + global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp")) + dp_mesh = global_mesh["dp"] + pp_mesh = global_mesh["pp"] + + torch.manual_seed(42 + self.rank) + hp_tensor = torch.randn(768, 32, device="cuda") + + if self.rank in [0, 1]: + # rank 0 and 1 are the 1st stage in the pipeline + # rank 2 and 4 are doing thing but waiting for the 1st stage + float8_tensor = hp_tensor_to_float8_dynamic( + hp_tensor, + torch.float8_e4m3fn, + Float8LinearConfig( + cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC), + ), + gemm_input_role=GemmInputRole.WEIGHT, + reduce_amax=True, + device_mesh=dp_mesh + ) + class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common): @property def world_size(self) 
-> int: diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index a72c7764b6..930cf2af97 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -98,7 +98,7 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False, device_mesh: "DeviceMesh" = None) -> torch.Tensor: +def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False, device_mesh: dist.DeviceMesh = None) -> torch.Tensor: amax = torch.max(torch.abs(x)) # If the user asked for distributed reduction, do it. @@ -116,7 +116,7 @@ def tensor_to_scale( x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False, - device_mesh: "DeviceMesh" = None, + device_mesh: dist.DeviceMesh = None, ) -> torch.Tensor: amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh) return amax_to_scale(amax, float8_dtype, x.dtype) From 0e501ff1c9c430701128a05554a1b323912ed8e5 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 13:44:06 -0700 Subject: [PATCH 3/9] improve comments Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_fsdp2/test_fsdp2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 8738c9a187..f4cc0079fe 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -312,7 +312,7 @@ def test_amax_allreduce_device_mesh(self): if self.rank in [0, 1]: # rank 0 and 1 are the 1st stage in the pipeline - # rank 2 and 4 are doing thing but waiting for the 1st stage + # rank 2 and 4 are doing nothing but waiting for the 1st stage float8_tensor = hp_tensor_to_float8_dynamic( hp_tensor, torch.float8_e4m3fn, From e1979e489a819d9270bb01271158f5c480bcd2fc Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 13:45:11 -0700 Subject: [PATCH 4/9] move hp tensor inside if Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_fsdp2/test_fsdp2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index f4cc0079fe..1e955ed233 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -307,12 +307,11 @@ def test_amax_allreduce_device_mesh(self): dp_mesh = global_mesh["dp"] pp_mesh = global_mesh["pp"] - torch.manual_seed(42 + self.rank) - hp_tensor = torch.randn(768, 32, device="cuda") - if self.rank in [0, 1]: # rank 0 and 1 are the 1st stage in the pipeline # rank 2 and 4 are doing nothing but waiting for the 1st stage + torch.manual_seed(42 + self.rank) + hp_tensor = torch.randn(768, 32, device="cuda") float8_tensor = hp_tensor_to_float8_dynamic( hp_tensor, torch.float8_e4m3fn, From 00b36bcc5de21a08c2e1cc20da55702adf323cda Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 13:51:52 -0700 Subject: [PATCH 5/9] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_scaling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index fd3502562a..1297e5d95e 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -36,7 +36,7 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, - device_mesh: "DeviceMesh" = None, + 
device_mesh: "torch.distributed.DeviceMesh" = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, From 6d8e4a22f11b069c335019b3b52898cded9633b6 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 13:53:37 -0700 Subject: [PATCH 6/9] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 930cf2af97..fa3c436472 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -98,7 +98,9 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False, device_mesh: dist.DeviceMesh = None) -> torch.Tensor: +def tensor_to_amax( + x: torch.Tensor, reduce_amax: bool = False, device_mesh: dist.DeviceMesh = None +) -> torch.Tensor: amax = torch.max(torch.abs(x)) # If the user asked for distributed reduction, do it. From 839e8aa724687784c9a455929a1601a0d1d4449a Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 14:34:43 -0700 Subject: [PATCH 7/9] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_scaling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 1297e5d95e..e9e1951763 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -36,7 +36,7 @@ def hp_tensor_to_float8_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, - device_mesh: "torch.distributed.DeviceMesh" = None, + device_mesh = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, From 6cfbe7d70076d7ca06d339e2ce51b612768f9ed7 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 15:33:34 -0700 Subject: [PATCH 8/9] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index fa3c436472..06b9dd54ad 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -99,7 +99,7 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax( - x: torch.Tensor, reduce_amax: bool = False, device_mesh: dist.DeviceMesh = None + x: torch.Tensor, reduce_amax: bool = False, device_mesh = None ) -> torch.Tensor: amax = torch.max(torch.abs(x)) @@ -118,7 +118,7 @@ def tensor_to_scale( x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False, - device_mesh: dist.DeviceMesh = None, + device_mesh = None, ) -> torch.Tensor: amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh) return amax_to_scale(amax, float8_dtype, x.dtype) From 736534142cc54592f67c15a3c69234ebb2d2ac16 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 24 Sep 2024 15:41:38 -0700 Subject: [PATCH 9/9] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 06b9dd54ad..09f42a946a 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -99,7 +99,7 @@ def amax_history_to_scale_stack( @torch.no_grad() def tensor_to_amax( - x: torch.Tensor, reduce_amax: bool = False, device_mesh 
= None
+    x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
 ) -> torch.Tensor:
     amax = torch.max(torch.abs(x))
 
@@ -118,7 +118,7 @@ def tensor_to_scale(
     x: torch.Tensor,
     float8_dtype: torch.dtype,
     reduce_amax: bool = False,
-    device_mesh = None,
+    device_mesh=None,
 ) -> torch.Tensor:
     amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh)
     return amax_to_scale(amax, float8_dtype, x.dtype)
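
Reviewer note (not part of the patch series): the net effect of patches 1-9 is that the amax all-reduce used for dynamic float8 scaling can be scoped to a caller-supplied DeviceMesh (e.g. the FSDP data-parallel mesh) instead of always running on the default, global process group. The standalone sketch below mirrors the patched tensor_to_amax path and the 2D (pp=2, dp=2) layout from the new test, but it is written against public torch.distributed APIs only; the script name, the tensor_to_amax_sketch/main helpers, and the torchrun invocation are illustrative assumptions, not torchao code.

    # sketch_amax_allreduce.py -- illustrative sketch only, not the torchao implementation.
    # Example launch (assumes 4 GPUs): torchrun --nproc_per_node=4 sketch_amax_allreduce.py
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh


    def tensor_to_amax_sketch(x, reduce_amax=False, device_mesh=None):
        # Local max(|x|), optionally all-reduced -- same shape as the patched
        # torchao.float8.float8_utils.tensor_to_amax.
        amax = torch.max(torch.abs(x))
        if reduce_amax and dist.is_initialized():
            # Scope the MAX all-reduce to the mesh's process group when one is
            # given; group=None falls back to the default (global) group.
            pg = device_mesh.get_group() if device_mesh is not None else None
            dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
        return amax


    def main():
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank % torch.cuda.device_count())

        # 2 pipeline stages x 2 data-parallel ranks, as in the new FSDP2 test.
        global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
        dp_mesh = global_mesh["dp"]  # 1-D submesh, so get_group() is unambiguous

        torch.manual_seed(42 + rank)
        x = torch.randn(768, 32, device="cuda")

        # Only the first pipeline stage (ranks 0 and 1) computes a shared scale.
        # Because the collective runs on the dp group, ranks 2 and 3 never have
        # to enter it.
        if rank in (0, 1):
            amax = tensor_to_amax_sketch(x, reduce_amax=True, device_mesh=dp_mesh)
            print(f"rank {rank}: dp-wide amax = {amax.item():.4f}")

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

Design note: with the previous behavior (all-reduce on the global process group), every rank in the job would have had to join the amax collective, so in a 2D FSDP + pipeline-parallel setup the pipeline stages that never cast the weight would stall the cast; passing the dp mesh through hp_tensor_to_float8_dynamic -> tensor_to_scale -> tensor_to_amax keeps the reduction confined to the ranks that actually shard the tensor.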