@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
8989 _test_graviton_framework_uris ("pytorch" , graviton_pytorch_version )
9090
9191
92- def test_graviton_xgboost (graviton_xgboost_versions ):
92+ def test_graviton_xgboost_instance_type_specified (graviton_xgboost_versions ):
9393 for xgboost_version in graviton_xgboost_versions :
9494 for instance_type in GRAVITON_INSTANCE_TYPES :
9595 uri = image_uris .retrieve (
@@ -102,6 +102,33 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102102 assert expected == uri
103103
104104
105+ def test_graviton_xgboost_image_scope_specified (graviton_xgboost_versions ):
106+ for xgboost_version in graviton_xgboost_versions :
107+ for instance_type in GRAVITON_INSTANCE_TYPES :
108+ uri = image_uris .retrieve (
109+ "xgboost" , "us-west-2" , version = xgboost_version , image_scope = "inference_graviton"
110+ )
111+ expected = (
112+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113+ f"{ xgboost_version } -arm64"
114+ )
115+ assert expected == uri
116+
117+
118+ def test_graviton_xgboost_image_scope_specified_x86_instance (graviton_xgboost_versions ):
119+ for xgboost_version in graviton_xgboost_versions :
120+ for instance_type in GRAVITON_INSTANCE_TYPES :
121+ with pytest .raises (ValueError ) as error :
122+ image_uris .retrieve (
123+ "xgboost" ,
124+ "us-west-2" ,
125+ version = xgboost_version ,
126+ image_scope = "inference_graviton" ,
127+ instance_type = "ml.m5.xlarge" ,
128+ )
129+ assert "Unsupported instance type: m5." in str (error )
130+
131+
105132def test_graviton_xgboost_unsupported_version (graviton_xgboost_unsupported_versions ):
106133 for xgboost_version in graviton_xgboost_unsupported_versions :
107134 for instance_type in GRAVITON_INSTANCE_TYPES :
@@ -112,7 +139,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112139 assert f"Unsupported xgboost version: { xgboost_version } ." in str (error )
113140
114141
115- def test_graviton_sklearn (graviton_sklearn_versions ):
142+ def test_graviton_sklearn_instance_type_specified (graviton_sklearn_versions ):
116143 for sklearn_version in graviton_sklearn_versions :
117144 for instance_type in GRAVITON_INSTANCE_TYPES :
118145 uri = image_uris .retrieve (
@@ -125,6 +152,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125152 assert expected == uri
126153
127154
155+ def test_graviton_sklearn_image_scope_specified (graviton_sklearn_versions ):
156+ for sklearn_version in graviton_sklearn_versions :
157+ for instance_type in GRAVITON_INSTANCE_TYPES :
158+ uri = image_uris .retrieve (
159+ "sklearn" , "us-west-2" , version = sklearn_version , image_scope = "inference_graviton"
160+ )
161+ expected = (
162+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
163+ f"{ sklearn_version } -arm64-cpu-py3"
164+ )
165+ assert expected == uri
166+
167+
128168def test_graviton_sklearn_unsupported_version (graviton_sklearn_unsupported_versions ):
129169 for sklearn_version in graviton_sklearn_unsupported_versions :
130170 for instance_type in GRAVITON_INSTANCE_TYPES :
@@ -138,6 +178,20 @@ def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versi
138178 assert expected == uri
139179
140180
181+ def test_graviton_sklearn_image_scope_specified_x86_instance (graviton_sklearn_unsupported_versions ):
182+ for sklearn_version in graviton_sklearn_unsupported_versions :
183+ for instance_type in GRAVITON_INSTANCE_TYPES :
184+ with pytest .raises (ValueError ) as error :
185+ image_uris .retrieve (
186+ "sklearn" ,
187+ "us-west-2" ,
188+ version = sklearn_version ,
189+ image_scope = "inference_graviton" ,
190+ instance_type = "ml.m5.xlarge" ,
191+ )
192+ assert "Unsupported instance type: m5." in str (error )
193+
194+
141195def _expected_graviton_framework_uri (framework , version , region ):
142196 return expected_uris .graviton_framework_uri (
143197 "{}-inference-graviton" .format (framework ),
0 commit comments