Skip to content

Commit e261b89

Browse files
committed
incline common params
1 parent aac77b6 commit e261b89

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

test/test_ao_models.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,6 @@
1010
from torchao._models.llama.model import Transformer
1111
from torchao.testing import common_utils
1212

13-
_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
14-
_BATCH_SIZES = [1, 4]
15-
_TRAINING_MODES = [True, False]
16-
17-
# Define test parameters
18-
COMMON_DEVICES = common_utils.parametrize("device", _AVAILABLE_DEVICES)
19-
COMMON_DTYPES = common_utils.parametrize("dtype", [torch.float32, torch.bfloat16])
20-
2113

2214
def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
2315
"""Initialize and return a Transformer model with specified configuration."""
@@ -29,9 +21,11 @@ def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
2921
class TorchAOBasicTestCase(unittest.TestCase):
3022
"""Test suite for basic Transformer inference functionality."""
3123

32-
@COMMON_DEVICES
33-
@common_utils.parametrize("batch_size", _BATCH_SIZES)
34-
@common_utils.parametrize("is_training", _TRAINING_MODES)
24+
@common_utils.parametrize(
25+
"device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
26+
)
27+
@common_utils.parametrize("batch_size", [1, 4])
28+
@common_utils.parametrize("is_training", [True, False])
3529
def test_ao_inference_mode(self, device, batch_size, is_training):
3630
# Initialize model with specified device
3731
random_model = init_model(device=device)

0 commit comments

Comments
 (0)