@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
8989 _test_graviton_framework_uris ("pytorch" , graviton_pytorch_version )
9090
9191
92- def test_graviton_xgboost (graviton_xgboost_versions ):
92+ def test_graviton_xgboost_instance_type_specified (graviton_xgboost_versions ):
9393 for xgboost_version in graviton_xgboost_versions :
9494 for instance_type in GRAVITON_INSTANCE_TYPES :
9595 uri = image_uris .retrieve (
@@ -102,6 +102,19 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102102 assert expected == uri
103103
104104
105+ def test_graviton_xgboost_image_scope_specified (graviton_xgboost_versions ):
106+ for xgboost_version in graviton_xgboost_versions :
107+ for instance_type in GRAVITON_INSTANCE_TYPES :
108+ uri = image_uris .retrieve (
109+ "xgboost" , "us-west-2" , version = xgboost_version , image_scope = "inference_graviton"
110+ )
111+ expected = (
112+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113+ f"{ xgboost_version } -arm64"
114+ )
115+ assert expected == uri
116+
117+
105118def test_graviton_xgboost_unsupported_version (graviton_xgboost_unsupported_versions ):
106119 for xgboost_version in graviton_xgboost_unsupported_versions :
107120 for instance_type in GRAVITON_INSTANCE_TYPES :
@@ -112,7 +125,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112125 assert f"Unsupported xgboost version: { xgboost_version } ." in str (error )
113126
114127
115- def test_graviton_sklearn (graviton_sklearn_versions ):
128+ def test_graviton_sklearn_instance_type_specified (graviton_sklearn_versions ):
116129 for sklearn_version in graviton_sklearn_versions :
117130 for instance_type in GRAVITON_INSTANCE_TYPES :
118131 uri = image_uris .retrieve (
@@ -125,6 +138,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125138 assert expected == uri
126139
127140
141+ def test_graviton_sklearn_image_scope_specified (graviton_sklearn_versions ):
142+ for sklearn_version in graviton_sklearn_versions :
143+ for instance_type in GRAVITON_INSTANCE_TYPES :
144+ uri = image_uris .retrieve (
145+ "sklearn" , "us-west-2" , version = sklearn_version , image_scope = "inference_graviton"
146+ )
147+ expected = (
148+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
149+ f"{ sklearn_version } -arm64-cpu-py3"
150+ )
151+ assert expected == uri
152+
153+
128154def test_graviton_sklearn_unsupported_version (graviton_sklearn_unsupported_versions ):
129155 for sklearn_version in graviton_sklearn_unsupported_versions :
130156 for instance_type in GRAVITON_INSTANCE_TYPES :
0 commit comments