-
Notifications
You must be signed in to change notification settings - Fork 369
Fix Float8Tensor quantize op kernel preference dispatch #2883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,8 @@ | |
| from typing import Tuple | ||
|
|
||
| import torch | ||
| from torch._inductor.utils import run_and_get_code | ||
| from torch.testing import FileCheck | ||
| from torch.testing._internal import common_utils | ||
| from torch.testing._internal.common_utils import ( | ||
| run_tests, | ||
|
|
@@ -85,6 +87,14 @@ def test_fp8_linear_variants( | |
| kernel_preference: KernelPreference, | ||
| sizes: Tuple, | ||
| ): | ||
| if ( | ||
| isinstance(granularity, PerTensor) | ||
| and kernel_preference == KernelPreference.FBGEMM | ||
| ): | ||
| return unittest.skip( | ||
| "per tensor with fbgemm kernel preferece does not work yet" | ||
| ) | ||
|
|
||
| error_message = None | ||
| if isinstance(granularity, PerRow): | ||
| if mode == "dynamic" and dtype != torch.bfloat16: | ||
|
|
@@ -237,7 +247,11 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes): | |
| other_kernel_preferences = [ | ||
| KernelPreference.AUTO, | ||
| ] | ||
| if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(): | ||
| if ( | ||
| _is_fbgemm_genai_gpu_available() | ||
| and is_sm_at_least_90() | ||
| and not isinstance(granularity, PerTensor) | ||
| ): | ||
| other_kernel_preferences.append(KernelPreference.FBGEMM) | ||
|
|
||
| quantized_outputs = {} | ||
|
|
@@ -399,6 +413,32 @@ def test_moe_weight_reshape_ops(self): | |
| config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) | ||
| self._test_moe_weight_reshape_ops(config) | ||
|
|
||
| # TODO: we have some other tests living in https://github.com/pytorch/ao/blob/4ecc89edd7b5cfc12e6f80854c85d04c472a0eb0/test/dtypes/test_affine_quantized_float.py#L743 | ||
| # that should be moved here after v1 config is deprecated: | ||
| # https://github.com/pytorch/ao/issues/2649 | ||
| @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") | ||
| def test_expected_gpu_kernel_fbgemm(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this test should be together with the other tests we have which check the same thing for other settings of this config, currently in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah I think we can put everything here after we deprecate the AQT path in 9 months |
||
| """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels""" | ||
| torch.compiler.reset() | ||
|
|
||
| M, K, N = 128, 256, 512 | ||
| m = torch.nn.Sequential( | ||
| torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) | ||
| ) | ||
| config = Float8DynamicActivationFloat8WeightConfig( | ||
| granularity=PerRow(), | ||
| kernel_preference=KernelPreference.FBGEMM, | ||
| ) | ||
| quantize_(m, config) | ||
| m = torch.compile(m) | ||
| x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) | ||
| out, code = run_and_get_code(m, x) | ||
|
|
||
| # check at least one occurrence of the quantize op and rowwise gemm op | ||
| FileCheck().check_count( | ||
| "torch.ops.triton.quantize_fp8_row.default", 1 | ||
| ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0]) | ||
|
|
||
|
|
||
| common_utils.instantiate_parametrized_tests(TestFloat8Tensor) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: lets Xfail this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we are using unittest, and it seems like we can't do
return unittest.expectedFailure("...")? But let me know if there is an example of applying expectedFailure conditionally instead of skipping the entire test.