|
41 | 41 | ) |
42 | 42 | from sagemaker.experiments._run_context import _RunContext |
43 | 43 | from sagemaker.experiments.run import Run |
| 44 | +from sagemaker.image_uris import get_base_python_image_uri |
44 | 45 | from sagemaker.session import get_execution_role, _logs_for_job, Session |
45 | 46 | from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config |
46 | 47 | from sagemaker.s3 import s3_path_join, S3Uploader |
|
60 | 61 | # training channel names |
61 | 62 | RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" |
62 | 63 | REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" |
63 | | -SAGEMAKER_WHL_CHANNEL_NAME = "sagemaker_whl_file" |
64 | 64 |
|
65 | 65 | # run context dictionary keys |
66 | 66 | KEY_EXPERIMENT_NAME = "experiment_name" |
67 | 67 | KEY_RUN_NAME = "run_name" |
68 | 68 |
|
69 | | -SAGEMAKER_SDK_WHL_FILE = ( |
70 | | - "s3://sagemaker-pathways/beta/pysdk/sagemaker-2.132.1.dev0-py2.py3-none-any.whl" |
71 | | -) |
72 | | - |
73 | 69 | JOBS_CONTAINER_ENTRYPOINT = [ |
74 | 70 | "/bin/bash", |
75 | 71 | f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", |
@@ -280,22 +276,18 @@ def _get_default_image(session): |
280 | 276 | ): |
281 | 277 | return os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] |
282 | 278 |
|
283 | | - py_major_version = sys.version_info[0] |
284 | | - py_minor_version = sys.version_info[1] |
| 279 | + py_version = str(sys.version_info[0]) + str(sys.version_info[1]) |
285 | 280 |
|
286 | | - # TODO:Add Support for 3.8 |
287 | | - if py_major_version != 3 or py_minor_version != 10: |
288 | | - raise ValueError("Use supported Python version or provide compatible ImageUri.") |
| 281 | + if py_version not in ["310", "38"]: |
| 282 | + raise ValueError( |
| 283 | + "Default image is supported only for Python versions 3.8 and 3.10. If you " |
| 284 | + "are using any other python version, you must provide a compatible image_uri." |
| 285 | + ) |
289 | 286 |
|
290 | | - # TODO: Support only supported by Studio |
291 | 287 | region = session.boto_region_name |
| 288 | + image_uri = get_base_python_image_uri(region=region, py_version=py_version) |
292 | 289 |
|
293 | | - # TODO: Remove beta image and use public base python |
294 | | - beta_image = ( |
295 | | - f"581474259216.dkr.ecr.{region}.amazonaws.com/" |
296 | | - f"sagemaker-pathways-beta:basepy_3_10_latest" |
297 | | - ) |
298 | | - return beta_image |
| 290 | + return image_uri |
299 | 291 |
|
300 | 292 |
|
301 | 293 | class _Job: |
@@ -394,19 +386,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non |
394 | 386 | ) |
395 | 387 | ) |
396 | 388 |
|
397 | | - # temporary solution for public beta to make sagemaker installer available |
398 | | - # in the images, this should be removed before pathways GA. |
399 | | - input_data_config.append( |
400 | | - dict( |
401 | | - ChannelName=SAGEMAKER_WHL_CHANNEL_NAME, |
402 | | - DataSource={ |
403 | | - "S3DataSource": { |
404 | | - "S3Uri": SAGEMAKER_SDK_WHL_FILE, |
405 | | - "S3DataType": "S3Prefix", |
406 | | - } |
407 | | - }, |
408 | | - ) |
409 | | - ) |
410 | 389 | request_dict["InputDataConfig"] = input_data_config |
411 | 390 |
|
412 | 391 | output_config = {"S3OutputPath": s3_base_uri} |
|
0 commit comments