 import copy
 
 import pytest
+import unittest
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -70,7 +71,7 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
         actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale)
         torch.testing.assert_close(actual, expected)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     def test_to_copy_device(self, ebits, mbits):
         from torchao.quantization.quant_primitives import (
@@ -87,12 +88,12 @@ def test_to_copy_device(self, ebits, mbits):
         floatx_tensor_impl = floatx_tensor_impl.cpu()
         assert floatx_tensor_impl.device.type == "cpu"
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
+    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
-    @pytest.mark.skipif(is_fbcode(), reason="broken in fbcode")
+    @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
         device = "cuda"
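
For context on the decorator swap, here is a minimal self-contained sketch (the class and test names are hypothetical, not from this file): unittest.skipIf is honored both by the stdlib unittest runner and by pytest, whereas pytest.mark.skipif marks are only applied when pytest itself collects the tests, so TestCase-based suites that may run under plain unittest need the stdlib decorator.

import unittest

import torch


class ExampleTest(unittest.TestCase):
    # unittest.skipIf(condition, reason) skips when the condition is true,
    # under both `python -m unittest` and `pytest`; a pytest.mark.skipif
    # mark here would be ignored by the plain unittest runner.
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cuda_roundtrip(self):
        x = torch.ones(4, device="cuda")
        self.assertEqual(x.cpu().sum().item(), 4.0)


if __name__ == "__main__":
    unittest.main()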