From ccf8bc704fbd75c33a9caa95f1a925c9ea9e7585 Mon Sep 17 00:00:00 2001 From: Shuqi Yang Date: Tue, 24 Sep 2024 23:48:40 -0700 Subject: [PATCH] Remove two if statements in fp8 padding (#935) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/935 Reviewed By: vkuzo Differential Revision: D63051205 --- torchao/float8/float8_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 54613e5b05..362329b5de 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -196,9 +196,7 @@ def _get_min_alignment(size: int, alignment_value: int) -> int: 16 ``` """ - if size % alignment_value == 0: - return size - return (1 + (size // alignment_value)) * alignment_value + return (1 + ((size - 1) // alignment_value)) * alignment_value def pad_tensor_for_matmul( @@ -234,10 +232,6 @@ def pad_tensor_for_matmul( dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1 dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2 - # Check if padding is needed for either dimension - if dim1 == dim1_aligned and dim2 == dim2_aligned: - return tensor - # Calculate padding values for both dimensions pad_dim1 = dim1_aligned - dim1 pad_dim2 = dim2_aligned - dim2