Skip to content

Commit d892708

Browse files
committed
device
1 parent c8a1b38 commit d892708

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

torchao/testing/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,9 @@ def test_linear_compile(self, device, dtype):
165165
NUM_DEVICES,
166166
)
167167

168-
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
169168
class TorchAOTensorParallelTestCase(DTensorTestBase):
170169
"""Basic test case for tensor subclasses
171170
"""
172-
COMMON_DEVICES = (["cuda"] if torch.cuda.is_available() else [])
173171
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
174172

175173
TENSOR_SUBCLASS = AffineQuantizedTensor
@@ -221,10 +219,11 @@ def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
221219
quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
222220
return m
223221

224-
@common_utils.parametrize("device", COMMON_DEVICES)
225222
@common_utils.parametrize("dtype", COMMON_DTYPES)
226223
@with_comms
227-
def test_tp(self, device, dtype):
224+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
225+
def test_tp(self, dtype):
226+
device = "cuda"
228227
# To make sure different ranks create the same module
229228
torch.manual_seed(5)
230229

0 commit comments

Comments
 (0)