|
27 | 27 | from sagemaker.model import NEO_IMAGE_ACCOUNT |
28 | 28 | from sagemaker.session import s3_input |
29 | 29 | from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix |
| 30 | +from sagemaker.xgboost.defaults import XGBOOST_VERSION_1, XGBOOST_SUPPORTED_VERSIONS |
30 | 31 | from sagemaker.xgboost.estimator import get_xgboost_image_uri |
31 | | -from sagemaker.xgboost.defaults import XGBOOST_LATEST_VERSION |
32 | 32 |
|
33 | 33 | logger = logging.getLogger(__name__) |
34 | 34 |
|
@@ -559,13 +559,23 @@ def get_image_uri(region_name, repo_name, repo_version=1): |
559 | 559 | """ |
560 | 560 | if repo_name == "xgboost": |
561 | 561 | if repo_version in ["0.90", "0.90-1", "0.90-1-cpu-py3"]: |
562 | | - return get_xgboost_image_uri(region_name, XGBOOST_LATEST_VERSION) |
| 562 | + return get_xgboost_image_uri(region_name, XGBOOST_VERSION_1) |
| 563 | + |
| 564 | + supported_version = [ |
| 565 | + version |
| 566 | + for version in XGBOOST_SUPPORTED_VERSIONS |
| 567 | + if repo_version in (version, version + "-cpu-py3") |
| 568 | + ] |
| 569 | + if supported_version: |
| 570 | + return get_xgboost_image_uri(region_name, supported_version[0]) |
| 571 | + |
563 | 572 | logging.warning( |
564 | | - "There is a more up to date SageMaker XGBoost image." |
| 573 | + "There is a more up to date SageMaker XGBoost image. " |
565 | 574 | "To use the newer image, please set 'repo_version'=" |
566 | | - "'0.90-1. For example:\n" |
| 575 | + "'%s'. For example:\n" |
567 | 576 | "\tget_image_uri(region, 'xgboost', '%s').", |
568 | | - XGBOOST_LATEST_VERSION, |
| 577 | + XGBOOST_VERSION_1, |
| 578 | + XGBOOST_VERSION_1, |
569 | 579 | ) |
570 | 580 | repo = "{}:{}".format(repo_name, repo_version) |
571 | 581 | return "{}/{}".format(registry(region_name, repo_name), repo) |
0 commit comments