Skip to content

Commit e93b3c1

Browse files
gkatkovhaoxinwaJennaZhaoJenna ZhaoGregory Katkov
authored andcommitted
Queue and QueuedJob implementation for AWS Batch Service (aws#1511)
* feat: Queue and QueuedJob implementation for AWS Batch Service (aws#991) * fix: Update Batch endpoint; Function renaming (aws#1230) Co-authored-by: Jenna Zhao <[email protected]> * chore: minor formatting --------- Co-authored-by: haoxinwa <[email protected]> Co-authored-by: JennaZhao <[email protected]> Co-authored-by: Jenna Zhao <[email protected]> Co-authored-by: Gregory Katkov <[email protected]>
1 parent a8c157d commit e93b3c1

File tree

18 files changed

+1208
-8
lines changed

18 files changed

+1208
-8
lines changed

src/sagemaker/batch_queueing/__init__.py

Whitespace-only changes.
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""The module provides helper function for Batch Submit/Describe/Terminal job APIs."""
2+
3+
from __future__ import absolute_import
4+
import json
5+
from typing import List, Dict, Optional
6+
from .constants import SAGEMAKER_TRAINING, DEFAULT_TIMEOUT
7+
from .boto_client import get_batch_boto_client
8+
9+
10+
def submit_service_job(
11+
training_payload: Dict,
12+
job_name: str,
13+
job_queue: str,
14+
retry_attempts: Optional[int] = None,
15+
scheduling_priority: Optional[int] = None,
16+
timeout: Optional[Dict] = None,
17+
share_identifier: Optional[str] = None,
18+
tags: Optional[List] = None,
19+
) -> Dict:
20+
"""Batch submit_service_job API helper function.
21+
22+
Args:
23+
training_payload: a dict containing a dict of arguments for Training job.
24+
job_name: Batch job name.
25+
job_queue: Batch job queue ARN.
26+
retry_attempts: The number of retry expected.
27+
scheduling_priority: An integer representing scheduling priority.
28+
timeout: Set with value of timeout if specified, else default to 1 day.
29+
share_identifier: value of shareIdentifier if specified.
30+
tags: A list of string representing Batch tags.
31+
32+
Returns:
33+
A dict containing jobArn, jobName and jobId.
34+
"""
35+
if timeout is None:
36+
timeout = DEFAULT_TIMEOUT
37+
client = get_batch_boto_client()
38+
payload = {
39+
"jobName": job_name,
40+
"jobQueue": job_queue,
41+
"serviceJobType": SAGEMAKER_TRAINING,
42+
"serviceRequestPayload": json.dumps(training_payload),
43+
"timeoutConfig": timeout,
44+
}
45+
if retry_attempts:
46+
payload["retryStrategy"] = {"attempts": retry_attempts}
47+
if scheduling_priority:
48+
payload["schedulingPriority"] = scheduling_priority
49+
if share_identifier:
50+
payload["shareIdentifier"] = share_identifier
51+
if tags:
52+
payload["tags"] = tags
53+
return client.submit_service_job(**payload)
54+
55+
56+
def describe_service_job(job_id: str) -> Dict:
57+
"""Batch describe_service_job API helper function.
58+
59+
Args:
60+
job_id: Job ID used.
61+
62+
Returns: a dict. See the sample below
63+
{
64+
'attempts': [
65+
{
66+
'serviceResourceId': {
67+
'name': 'string',
68+
'value': 'string'
69+
},
70+
'startedAt': 123,
71+
'stoppedAt': 123,
72+
'statusReason': 'string'
73+
},
74+
],
75+
'createdAt': 123,
76+
'isTerminated': True|False,
77+
'jobArn': 'string',
78+
'jobId': 'string',
79+
'jobName': 'string',
80+
'jobQueue': 'string',
81+
'retryStrategy': {
82+
'attempts': 123
83+
},
84+
'schedulingPriority': 123,
85+
'serviceRequestPayload': 'string',
86+
'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING',
87+
'shareIdentifier': 'string',
88+
'startedAt': 123,
89+
'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED',
90+
'statusReason': 'string',
91+
'stoppedAt': 123,
92+
'tags': {
93+
'string': 'string'
94+
},
95+
'timeout': {
96+
'attemptDurationSeconds': 123
97+
}
98+
}
99+
"""
100+
client = get_batch_boto_client()
101+
return client.describe_service_job(jobId=job_id)
102+
103+
104+
def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict:
105+
"""Batch terminate_service_job API helper function.
106+
107+
Args:
108+
job_id: Job ID
109+
reason: A string representing terminate reason.
110+
111+
Returns: an empty dict
112+
"""
113+
client = get_batch_boto_client()
114+
return client.terminate_service_job(jobId=job_id, reason=reason)
115+
116+
117+
def list_service_job(
118+
job_queue: str,
119+
job_status: Optional[str] = None,
120+
filters: Optional[List] = None,
121+
next_token: Optional[str] = None,
122+
) -> Dict:
123+
"""Batch list_service_job API helper function.
124+
125+
Args:
126+
job_queue: Batch job queue ARN.
127+
job_status: Batch job status.
128+
filters: A list of Dict. Each contains a filter.
129+
next_token: Used to retrieve data in next page.
130+
131+
Returns: A generator containing list results.
132+
133+
"""
134+
client = get_batch_boto_client()
135+
payload = {"jobQueue": job_queue}
136+
if filters:
137+
payload["filters"] = filters
138+
if next_token:
139+
payload["nextToken"] = next_token
140+
if job_status:
141+
payload["jobStatus"] = job_status
142+
part_of_jobs = client.list_service_jobs(**payload)
143+
next_token = part_of_jobs.get("nextToken")
144+
yield part_of_jobs
145+
if next_token:
146+
yield from list_service_job(job_queue, job_status, filters, next_token)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""The file provides helper function for getting Batch boto client."""
2+
3+
from __future__ import absolute_import
4+
from typing import Optional
5+
import boto3
6+
7+
8+
def get_batch_boto_client(
9+
region: Optional[str] = None,
10+
endpoint: Optional[str] = None,
11+
) -> boto3.session.Session.client:
12+
"""Helper function for getting Batch boto3 client.
13+
14+
Args:
15+
region: Region specified
16+
endpoint: Batch API endpoint.
17+
18+
Returns: Batch boto3 client.
19+
20+
"""
21+
return boto3.client("sm_batch", region_name=region, endpoint_url=endpoint)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""The file defines constants used for Batch API helper functions."""
2+
3+
from __future__ import absolute_import
4+
5+
SAGEMAKER_TRAINING = "SAGEMAKER_TRAINING"
6+
DEFAULT_ATTEMPT_DURATION_IN_SECONDS = 86400 # 1 day in seconds.
7+
DEFAULT_TIMEOUT = {"attemptDurationSeconds": DEFAULT_ATTEMPT_DURATION_IN_SECONDS}
8+
JOB_STATUS_RUNNING = "RUNNING"
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""The file Defines customized exception for Batch queueing"""
2+
3+
from __future__ import absolute_import
4+
5+
6+
class NoTrainingJob(Exception):
7+
"""Define NoTrainingJob Exception.
8+
9+
It means no Training job has been created by AWS Batch service.
10+
"""
11+
12+
def __init__(self, value):
13+
super().__init__(value)
14+
self.value = value
15+
16+
def __str__(self):
17+
"""Convert Exception to string.
18+
19+
Returns: a String containing exception error messages.
20+
21+
"""
22+
return repr(self.value)
23+
24+
25+
class MissingRequiredArgument(Exception):
26+
"""Define MissingRequiredArgument exception.
27+
28+
It means some required arguments are missing.
29+
"""
30+
31+
def __init__(self, value):
32+
super().__init__(value)
33+
self.value = value
34+
35+
def __str__(self):
36+
"""Convert Exception to string.
37+
38+
Returns: a String containing exception error messages.
39+
40+
"""
41+
return repr(self.value)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Define Queue class for AWS Batch service"""
2+
3+
from __future__ import absolute_import
4+
from typing import Dict, Optional, List
5+
import logging
6+
from sagemaker.estimator import Estimator, _TrainingJob
7+
from .queue_job import QueuedJob
8+
from .batch_api_helper import submit_service_job, list_service_job
9+
from .exception import MissingRequiredArgument
10+
from .constants import DEFAULT_TIMEOUT, JOB_STATUS_RUNNING
11+
12+
13+
class Queue:
14+
"""Queue class for AWS Batch service
15+
16+
With this class, customers are able to create a new queue and submit jobs to AWS Batch Service.
17+
"""
18+
19+
def __init__(self, queue_name: str):
20+
self.queue_name = queue_name
21+
22+
def submit(
23+
self,
24+
estimator: Estimator,
25+
inputs,
26+
job_name: Optional[str] = None,
27+
max_retry_attempts: Optional[int] = None,
28+
priority: Optional[int] = None,
29+
share_identifier: Optional[str] = None,
30+
timeout: Optional[Dict] = None,
31+
tags: Optional[Dict] = None,
32+
experiment_config: Optional[Dict] = None,
33+
) -> QueuedJob:
34+
"""Submit a queued job and return a QueuedJob object.
35+
36+
Args:
37+
estimator: Training job estimator object.
38+
inputs: Training job inputs.
39+
job_name: Batch job name.
40+
max_retry_attempts: Max retry attempts for Batch job.
41+
priority: Scheduling priority for Batch job.
42+
share_identifier: Share identifier for Batch job.
43+
timeout: Timeout configuration for Batch job.
44+
tags: Tags apply to Batch job. These tags are for Batch job only.
45+
experiment_config: Experiment management configuration.
46+
Optionally, the dict can contain four keys:
47+
'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
48+
49+
Returns: a QueuedJob object with Batch job ARN and job name.
50+
51+
"""
52+
if experiment_config is None:
53+
experiment_config = {}
54+
estimator.prepare_workflow_for_training(job_name)
55+
training_args = _TrainingJob.get_train_args(estimator, inputs, experiment_config)
56+
training_payload = estimator.sagemaker_session.get_train_request(**training_args)
57+
training_payload = _convert_first_letter_to_lowercase_for_all_keys(training_payload)
58+
59+
if timeout is None:
60+
timeout = DEFAULT_TIMEOUT
61+
if job_name is None:
62+
job_name = training_payload["trainingJobName"]
63+
64+
resp = submit_service_job(
65+
training_payload,
66+
job_name,
67+
self.queue_name,
68+
max_retry_attempts,
69+
priority,
70+
timeout,
71+
share_identifier,
72+
tags,
73+
)
74+
if "jobArn" not in resp or "jobName" not in resp:
75+
raise MissingRequiredArgument(
76+
"jobArn or jobName is missing in response from Batch submit_service_job API"
77+
)
78+
return QueuedJob(resp["jobArn"], resp["jobName"])
79+
80+
def list_jobs(
81+
self, job_name: Optional[str] = None, status: Optional[str] = JOB_STATUS_RUNNING
82+
) -> List[QueuedJob]:
83+
"""List Batch jobs according to job_name or status.
84+
85+
Args:
86+
job_name: Batch job name.
87+
status: Batch job status.
88+
89+
Returns: A list of QueuedJob.
90+
91+
"""
92+
filters = None
93+
if job_name:
94+
filters = [{"name": "JOB_NAME", "values": [job_name]}]
95+
status = None # job_status is ignored when job_name is specified.
96+
jobs_to_return = []
97+
next_token = None
98+
for job_result_dict in list_service_job(self.queue_name, status, filters, next_token):
99+
for job_result in job_result_dict.get("jobSummaryList", []):
100+
if "jobArn" in job_result and "jobName" in job_result:
101+
jobs_to_return.append(QueuedJob(job_result["jobArn"], job_result["jobName"]))
102+
else:
103+
logging.warning("Missing JobArn or JobName in Batch ListJobs API")
104+
continue
105+
return jobs_to_return
106+
107+
108+
def _first_letter_to_lowercase(key):
109+
"""Convert first letter in string to lowercase.
110+
111+
Args:
112+
key: a string.
113+
114+
Returns: a string with first letter being lowercase.
115+
116+
"""
117+
return str.lower(key[0]) + key[1:]
118+
119+
120+
def _convert_first_letter_to_lowercase_for_all_keys(payload):
121+
"""Convert first letter of all keys to lowercase in a Dict.
122+
123+
Args:
124+
payload: a Dict or a List.
125+
126+
Returns: converted Dict or List
127+
128+
"""
129+
if isinstance(payload, dict):
130+
return {
131+
_first_letter_to_lowercase(key): _convert_first_letter_to_lowercase_for_all_keys(value)
132+
for key, value in payload.items()
133+
}
134+
if isinstance(payload, list):
135+
return [_convert_first_letter_to_lowercase_for_all_keys(item) for item in payload]
136+
return payload

0 commit comments

Comments
 (0)