1616from sagemaker import image_uris
1717from tests .unit .sagemaker .image_uris import expected_uris
1818
19- COMMON_INSTANCE_TYPES = {"cpu" : " ml.c4.xlarge" , "gpu" : "ml.p4d .24xlarge" }
19+ CONTAINER_VERSIONS = {"ml.p4d.24xlarge" : "cu118" , "ml.p5d .24xlarge" : "cu12 " }
2020
2121
2222@pytest .mark .parametrize ("load_config" , ["pytorch-smp.json" ], indirect = True )
@@ -40,20 +40,21 @@ def test_smp_v2(load_config):
4040 PY_VERSIONS = load_config ["training" ]["versions" ][version ]["py_versions" ]
4141 for py_version in PY_VERSIONS :
4242 for region in ACCOUNTS .keys ():
43- uri = image_uris .get_training_image_uri (
44- region ,
45- framework = "pytorch" ,
46- framework_version = version ,
47- py_version = py_version ,
48- distribution = distribution ,
49- instance_type = COMMON_INSTANCE_TYPES [processor ]
50- )
51- expected = expected_uris .framework_uri (
52- repo = "smdistributed-modelparallel" ,
53- fw_version = version ,
54- py_version = py_version ,
55- processor = processor ,
56- region = region ,
57- account = ACCOUNTS [region ],
58- )
59- assert expected == uri
43+ for instance_type in CONTAINER_VERSIONS .keys ():
44+ uri = image_uris .get_training_image_uri (
45+ region ,
46+ framework = "pytorch" ,
47+ framework_version = version ,
48+ py_version = py_version ,
49+ distribution = distribution ,
50+ instance_type = instance_type
51+ )
52+ expected = expected_uris .framework_uri (
53+ repo = "smdistributed-modelparallel" ,
54+ fw_version = version ,
55+ py_version = f"{ py_version } -{ CONTAINER_VERSIONS [instance_type ]} " ,
56+ processor = processor ,
57+ region = region ,
58+ account = ACCOUNTS [region ],
59+ )
60+ assert expected == uri
0 commit comments