Commit 26d84b5

Float8 tensor parallel for aqt_dynamic_act_weight
1 parent 6314d88 commit 26d84b5

File tree

2 files changed: +8, -10 lines


torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 0 deletions
@@ -1653,6 +1653,10 @@ def _linear_fp8_act_fp8_weight_impl(
 
     # Preprocess data
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
+
+    print(f"out_shape: {out_shape}")
+    print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
+    print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")
 
 
     print(f"out_shape: {out_shape}")

torchao/testing/utils.py

Lines changed: 4 additions & 10 deletions
@@ -250,7 +250,6 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         m.linear.weight = torch.nn.Parameter(
             dtensor, requires_grad=False
         )
-        print('colwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)
         return m
 
     @staticmethod
@@ -265,15 +264,11 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         rank = mesh.get_local_rank()
         local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
         # Construct DTensor from local shard
-        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
-        print(f'dtensor shape: {dtensor.shape}')
-        print(f'Other dtensor values: {local_shard.original_weight_tensor.tensor_impl.float8_data.shape}, {mesh}, {[Shard(1)]}')
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
         # Replace parameter in module
         m.linear.weight = torch.nn.Parameter(
             dtensor, requires_grad=False
         )
-        print('rowwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)
-
         return m
 
     def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
@@ -306,15 +301,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         proj_up = M(1024, 2048).to(device).to(dtype)
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+        print('Run y')
         y = proj_dn(proj_up(example_input))
-        print('Run before y')
+
         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
-        print('Run before y_q')
         y_q = dn_quant(up_quant(example_input))
-        print('Executed y_q')
-
+
         mesh = self.build_device_mesh()
         mesh.device_type = "cuda"
 
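
The functional change in this file is the DTensor.from_local call in rowwise_shard: the explicit run_check=True (which asks from_local to run extra cross-rank validation of the shard metadata) is dropped along with the debug prints. A minimal sketch of the same rowwise-sharding pattern, runnable under torchrun --nproc_per_node=2; the tensor sizes and the gloo/CPU backend are illustrative assumptions, not taken from the test:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")                  # "nccl" for GPU runs
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

orig_weight = torch.randn(1024, 2048)            # full (out, in) weight, identical on all ranks
n_local_cols = orig_weight.size(1) // mesh.size()
rank = mesh.get_local_rank()
local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]

# Shard(1): each rank owns one column slice of dim 1 (in_features),
# matching the rowwise layout rowwise_shard builds above
dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
assert dtensor.shape == orig_weight.shape        # logical shape is the full weight

weight = torch.nn.Parameter(dtensor, requires_grad=False)
dist.destroy_process_group()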
