Empty file added: file.txt
115 changes: 111 additions & 4 deletions test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]
import copy
from torch.distributed.tensor import distribute_tensor, Shard

import torch
import torch.nn as nn
@@ -15,6 +16,7 @@
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests

device_type = torch.device(get_devtype())

@@ -26,7 +28,110 @@ def test_gradient_scaler(self):
{"has_inf": [True, False], "test_2d": [True, False]},
self._test_gradient_scaler,
)

def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
torch.manual_seed(0)

def refactor_twoWay_parallel(n):
# Return the most balanced factor pair (a, b) of n, with a <= b
return max(
((i, n // i) for i in range(1, int(n ** 0.5) + 1) if n % i == 0),
key=min,
)
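# refactor_twoWay_parallel: e.g. world_size=8 -> (2, 4), world_size=12 -> (3, 4); a prime world_size gives (1, n)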

if test_2d:
# Split the world size into the most balanced (dp, tp) factor pair
dp_size, tp_size = refactor_twoWay_parallel(self.world_size)
mesh_2d = init_device_mesh(
device_type.type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

# Keep the model dimension divisible by tp_size so TP sharding is even
model_dim = tp_size # in_features and out_features both equal the TP degree
model = nn.Sequential(MLP(model_dim), MLP(model_dim), MLP(model_dim))

tp_parallelize_plan = {
"0.in_proj": ColwiseParallel(),
"0.out_proj": RowwiseParallel(),
"1.in_proj": ColwiseParallel(),
"1.out_proj": RowwiseParallel(),
"2.in_proj": ColwiseParallel(),
"2.out_proj": RowwiseParallel(),
}

model = parallelize_module(
model,
device_mesh=tp_mesh,
parallelize_plan=tp_parallelize_plan,
)

for module in model:
fully_shard(module, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)
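# The model is now TP-sharded over tp_mesh and FSDP-sharded over dp_mesh, i.e. 2D (dp, tp) parallelism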

# Ensure input shape is (batch_size, feature_dim)
input = torch.randn((2, model_dim), device=device_type)

else:
# Default path: single-dimension FSDP
model = nn.Sequential(
*[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)]
)
for layer in model:
fully_shard(layer)
fully_shard(model)
input = torch.randn([4, 4], device=device_type)

loss = model(input).sum()
scaler = GradScaler(init_scale=2.0, enabled=True, device=device_type.type)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
scaler.scale(loss).backward()
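# Recompute the inverse scale the same way GradScaler does internally when unscaling:
# reciprocal in double precision, then cast back to float.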
inv_scale = scaler._scale.double().reciprocal().float()

# Inject an inf into the local grad shard on the rank whose device index is 1,
# so that unscale_() detects a non-finite gradient.
if (
has_inf
and opt.param_groups[0]["params"][0].grad._local_tensor.device.index == 1
):
opt.param_groups[0]["params"][0].grad._local_tensor[0, 0].fill_(float("inf"))

initial_grad = opt.param_groups[0]["params"][0].grad.to_local().clone()

scaler.unscale_(opt)
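# unscale_() divides the grads by the scale in place and records, per device, whether any grad was non-finite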
for found_inf in scaler._per_optimizer_states[id(opt)][
"found_inf_per_device"
].values():
self.assertEqual(found_inf, has_inf)

self.assertEqual(
scaler._per_optimizer_states[id(opt)]["stage"].value,
OptState.UNSCALED.value,
)

unscaled_grad = opt.param_groups[0]["params"][0].grad.to_local().clone()
self.assertEqual(unscaled_grad, initial_grad * inv_scale)
initial_scale = scaler.get_scale()
initial_state = copy.copy(opt.state)

scaler.step(opt)
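# GradScaler.step skips optimizer.step() when non-finite grads were found, so state should then be unchanged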
stepped_state = copy.copy(opt.state)

if has_inf:
self.assertEqual(stepped_state, initial_state)
else:
self.assertNotEqual(stepped_state.items(), initial_state.items())

scaler.update()
updated_scale = scaler.get_scale()
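# update() multiplies the scale by the backoff factor after a step with non-finite grads;
# otherwise the scale grows only after `growth_interval` consecutive clean steps, so one clean step leaves it unchanged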

if has_inf:
backoff_factor = scaler.get_backoff_factor()
self.assertEqual(updated_scale, initial_scale * backoff_factor)
else:
self.assertEqual(updated_scale, initial_scale)



'''
def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
torch.manual_seed(0)
model = nn.Sequential(
@@ -35,12 +140,14 @@ def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
for layer in model:
fully_shard(layer)
fully_shard(model)
input = torch.randn([4, 4], device=device_type)

if test_2d:
mesh_2d = init_device_mesh(
# device_type.type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
device_type.type, (2, 6), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]
model = nn.Sequential(MLP(2), MLP(2), MLP(2))
tp_parallelize_plan = {
@@ -108,6 +215,6 @@ def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
# scale is not updated
self.assertEqual(updated_scale, initial_scale)


'''
if __name__ == "__main__":
run_tests()