We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d51715a commit 2ee9f76Copy full SHA for 2ee9f76
test/float8/test_base.py
@@ -13,6 +13,7 @@
13
import pytest
14
import torch
15
import torch.nn as nn
16
+
17
from torchao.utils import (
18
TORCH_VERSION_AT_LEAST_2_5,
19
is_sm_at_least_89,
@@ -536,7 +537,7 @@ def test_inference_mode(self):
536
537
with torch.inference_mode(mode=True):
538
m(x)
539
- @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
540
+ @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
541
def test_quantize(self):
542
x = torch.randn(32, 32, device="cuda")
543
m = nn.Sequential(nn.Linear(32, 32)).cuda()
0 commit comments