|
12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | from __future__ import absolute_import |
14 | 14 | import unittest |
| 15 | +from unittest.mock import Mock |
15 | 16 |
|
16 | 17 |
|
17 | 18 | from mock.mock import patch |
| 19 | +import pytest |
18 | 20 |
|
19 | 21 | import copy |
20 | 22 | from sagemaker.jumpstart import artifacts |
|
28 | 30 | BASE_SPEC, |
29 | 31 | ) |
30 | 32 |
|
| 33 | +from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn |
| 34 | +from sagemaker.jumpstart.enums import JumpStartScriptScope |
31 | 35 |
|
32 | | -from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec |
| 36 | +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec |
| 37 | +from tests.unit.sagemaker.workflow.conftest import mock_client |
33 | 38 |
|
34 | 39 |
|
35 | 40 | class ModelArtifactVariantsTest(unittest.TestCase): |
@@ -319,3 +324,109 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs): |
319 | 324 | ) |
320 | 325 |
|
321 | 326 | assert kwargs == {"some-estimator-fit-key": "some-estimator-fit-value"} |
| 327 | + |
| 328 | + |
| 329 | +class RetrieveModelPackageArnTest(unittest.TestCase): |
| 330 | + |
| 331 | + mock_session = Mock(s3_client=mock_client) |
| 332 | + |
| 333 | + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") |
| 334 | + def test_retrieve_model_package_arn(self, patched_get_model_specs): |
| 335 | + patched_get_model_specs.side_effect = get_special_model_spec |
| 336 | + |
| 337 | + model_id = "variant-model" |
| 338 | + region = "us-west-2" |
| 339 | + |
| 340 | + assert ( |
| 341 | + _retrieve_model_package_arn( |
| 342 | + region=region, |
| 343 | + model_id=model_id, |
| 344 | + scope=JumpStartScriptScope.INFERENCE, |
| 345 | + model_version="*", |
| 346 | + sagemaker_session=self.mock_session, |
| 347 | + instance_type="ml.p2.48xlarge", |
| 348 | + ) |
| 349 | + == "us-west-2/blah/blah/blah/gpu" |
| 350 | + ) |
| 351 | + |
| 352 | + assert ( |
| 353 | + _retrieve_model_package_arn( |
| 354 | + region=region, |
| 355 | + model_id=model_id, |
| 356 | + scope=JumpStartScriptScope.INFERENCE, |
| 357 | + model_version="*", |
| 358 | + sagemaker_session=self.mock_session, |
| 359 | + instance_type="ml.p4.2xlarge", |
| 360 | + ) |
| 361 | + == "us-west-2/blah/blah/blah/gpu" |
| 362 | + ) |
| 363 | + |
| 364 | + assert ( |
| 365 | + _retrieve_model_package_arn( |
| 366 | + region=region, |
| 367 | + model_id=model_id, |
| 368 | + scope=JumpStartScriptScope.INFERENCE, |
| 369 | + model_version="*", |
| 370 | + sagemaker_session=self.mock_session, |
| 371 | + instance_type="ml.inf1.2xlarge", |
| 372 | + ) |
| 373 | + == "us-west-2/blah/blah/blah/inf" |
| 374 | + ) |
| 375 | + |
| 376 | + assert ( |
| 377 | + _retrieve_model_package_arn( |
| 378 | + region=region, |
| 379 | + model_id=model_id, |
| 380 | + scope=JumpStartScriptScope.INFERENCE, |
| 381 | + model_version="*", |
| 382 | + sagemaker_session=self.mock_session, |
| 383 | + instance_type="ml.inf2.12xlarge", |
| 384 | + ) |
| 385 | + == "us-west-2/blah/blah/blah/inf" |
| 386 | + ) |
| 387 | + |
| 388 | + assert ( |
| 389 | + _retrieve_model_package_arn( |
| 390 | + region=region, |
| 391 | + model_id=model_id, |
| 392 | + scope=JumpStartScriptScope.INFERENCE, |
| 393 | + model_version="*", |
| 394 | + sagemaker_session=self.mock_session, |
| 395 | + instance_type="ml.afasfasf.12xlarge", |
| 396 | + ) |
| 397 | + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" |
| 398 | + ) |
| 399 | + |
| 400 | + assert ( |
| 401 | + _retrieve_model_package_arn( |
| 402 | + region=region, |
| 403 | + model_id=model_id, |
| 404 | + scope=JumpStartScriptScope.INFERENCE, |
| 405 | + model_version="*", |
| 406 | + sagemaker_session=self.mock_session, |
| 407 | + instance_type="ml.m2.12xlarge", |
| 408 | + ) |
| 409 | + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" |
| 410 | + ) |
| 411 | + |
| 412 | + assert ( |
| 413 | + _retrieve_model_package_arn( |
| 414 | + region=region, |
| 415 | + model_id=model_id, |
| 416 | + scope=JumpStartScriptScope.INFERENCE, |
| 417 | + model_version="*", |
| 418 | + sagemaker_session=self.mock_session, |
| 419 | + instance_type="nobodycares", |
| 420 | + ) |
| 421 | + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" |
| 422 | + ) |
| 423 | + |
| 424 | + with pytest.raises(ValueError): |
| 425 | + _retrieve_model_package_arn( |
| 426 | + region="cn-north-1", |
| 427 | + model_id=model_id, |
| 428 | + scope=JumpStartScriptScope.INFERENCE, |
| 429 | + model_version="*", |
| 430 | + sagemaker_session=self.mock_session, |
| 431 | + instance_type="ml.p2.12xlarge", |
| 432 | + ) |
0 commit comments