1010from torchao ._models .llama .model import Transformer
1111from torchao .testing import common_utils
1212
13- _AVAILABLE_DEVICES = ["cpu" ] + (["cuda" ] if torch .cuda .is_available () else [])
14- _BATCH_SIZES = [1 , 4 ]
15- _TRAINING_MODES = [True , False ]
16-
17- # Define test parameters
18- COMMON_DEVICES = common_utils .parametrize ("device" , _AVAILABLE_DEVICES )
19- COMMON_DTYPES = common_utils .parametrize ("dtype" , [torch .float32 , torch .bfloat16 ])
20-
2113
2214def init_model (name = "stories15M" , device = "cpu" , precision = torch .bfloat16 ):
2315 """Initialize and return a Transformer model with specified configuration."""
@@ -29,9 +21,11 @@ def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
2921class TorchAOBasicTestCase (unittest .TestCase ):
3022 """Test suite for basic Transformer inference functionality."""
3123
32- @COMMON_DEVICES
33- @common_utils .parametrize ("batch_size" , _BATCH_SIZES )
34- @common_utils .parametrize ("is_training" , _TRAINING_MODES )
24+ @common_utils .parametrize (
25+ "device" , ["cpu" , "cuda" ] if torch .cuda .is_available () else ["cpu" ]
26+ )
27+ @common_utils .parametrize ("batch_size" , [1 , 4 ])
28+ @common_utils .parametrize ("is_training" , [True , False ])
3529 def test_ao_inference_mode (self , device , batch_size , is_training ):
3630 # Initialize model with specified device
3731 random_model = init_model (device = device )
0 commit comments