Skip to content

Commit bf49f41

Browse files
committed
[float8] fuse abs/max with torch.linalg.vector_norm
1 parent 53b6b78 commit bf49f41

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/float8/float8_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def amax_history_to_scale_stack(
9999

100100
@torch.no_grad()
101101
def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
102-
amax = torch.max(torch.abs(x))
102+
amax = torch.linalg.vector_norm(x, ord=float("inf"))
103103

104104
# If the user asked for distributed reduction, do it.
105105
# If the user did not ask for it, assume that it will

0 commit comments

Comments
 (0)