Skip to content

Commit c652e64

Browse files
committed
change: allow ModelMonitor and Processor to take IAM role names (in addition to ARNs)
This change also removes some hardcoded AWS account IDs and regions from some integration tests.
1 parent ee2c345 commit c652e64

File tree

6 files changed

+38
-47
lines changed

6 files changed

+38
-47
lines changed

src/sagemaker/automl/automl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""A class for SageMaker AutoML Job."""
13+
"""A class for SageMaker AutoML Jobs."""
1414
from __future__ import absolute_import
1515

1616
from six import string_types

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,17 @@
2222
import logging
2323
import uuid
2424

25-
from six.moves.urllib.parse import urlparse
2625
from six import string_types
27-
26+
from six.moves.urllib.parse import urlparse
2827
from botocore.exceptions import ClientError
2928

29+
from sagemaker.exceptions import UnexpectedStatusException
30+
from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations, Statistics
3031
from sagemaker.network import NetworkConfig
32+
from sagemaker.processing import Processor, ProcessingInput, ProcessingJob, ProcessingOutput
3133
from sagemaker.s3 import S3Uploader
32-
33-
from sagemaker.utils import name_from_base
3434
from sagemaker.session import Session
35-
from sagemaker.processing import Processor
36-
from sagemaker.processing import ProcessingJob
37-
from sagemaker.processing import ProcessingInput
38-
from sagemaker.processing import ProcessingOutput
39-
from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations
40-
from sagemaker.model_monitor.monitoring_files import Statistics
41-
from sagemaker.exceptions import UnexpectedStatusException
42-
from sagemaker.utils import retries
35+
from sagemaker.utils import name_from_base, retries
4336

4437
_DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS = (
4538
"{}.dkr.ecr.{}.amazonaws.com/sagemaker-model-monitor-analyzer"
@@ -390,7 +383,7 @@ def update_monitoring_schedule(
390383
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
391384
object that configures network isolation, encryption of
392385
inter-container traffic, security group IDs, and subnets.
393-
role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
386+
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
394387
image_uri (str): The uri of the image to use for the jobs started by
395388
the Monitor.
396389
@@ -452,7 +445,7 @@ def update_monitoring_schedule(
452445
self.network_config = network_config
453446

454447
if role is not None:
455-
self.role = role
448+
self.role = self.sagemaker_session.expand_role(role)
456449

457450
if image_uri is not None:
458451
self.image_uri = image_uri
@@ -988,7 +981,7 @@ def __init__(
988981
creating Amazon SageMaker Monitoring Schedules to monitor SageMaker endpoints.
989982
990983
Args:
991-
role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
984+
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
992985
instance_count (int): The number of instances to run the jobs with.
993986
instance_type (str): Type of EC2 instance to use for the job, for example,
994987
'ml.m5.xlarge'.
@@ -1355,7 +1348,7 @@ def update_monitoring_schedule(
13551348
inter-container traffic, security group IDs, and subnets.
13561349
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
13571350
the baselining or monitoring jobs.
1358-
role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
1351+
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
13591352
13601353
"""
13611354
monitoring_inputs = None
@@ -1431,7 +1424,7 @@ def update_monitoring_schedule(
14311424
network_config_dict = self.network_config._to_request_dict()
14321425

14331426
if role is not None:
1434-
self.role = role
1427+
self.role = self.sagemaker_session.expand_role(role)
14351428

14361429
self.sagemaker_session.update_monitoring_schedule(
14371430
monitoring_schedule_name=self.monitoring_schedule_name,

src/sagemaker/processing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
SageMaker processing tasks.
5353
5454
Args:
55-
role (str): An AWS IAM role. The Amazon SageMaker training jobs
55+
role (str): An AWS IAM role name or ARN. The Amazon SageMaker training jobs
5656
and APIs that create Amazon SageMaker endpoints use this role
5757
to access training data and model artifacts. After the endpoint
5858
is created, the inference code might use the IAM role, if it
@@ -281,7 +281,7 @@ def __init__(
281281
handles Amazon SageMaker processing tasks for jobs using script mode.
282282
283283
Args:
284-
role (str): An AWS IAM role. The Amazon SageMaker training jobs
284+
role (str): An AWS IAM role name or ARN. The Amazon SageMaker training jobs
285285
and APIs that create Amazon SageMaker endpoints use this role
286286
to access training data and model artifacts. After the endpoint
287287
is created, the inference code might use the IAM role, if it
@@ -538,7 +538,7 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
538538
else:
539539
process_request_args["network_config"] = None
540540

541-
process_request_args["role_arn"] = processor.role
541+
process_request_args["role_arn"] = processor.sagemaker_session.expand_role(processor.role)
542542

543543
process_request_args["tags"] = processor.tags
544544

tests/integ/test_auto_ml.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from tests.integ import DATA_DIR, AUTO_ML_DEFAULT_TIMEMOUT_MINUTES
2424
from tests.integ.timeout import timeout
2525

26-
DEV_ACCOUNT = 142577830533
2726
ROLE = "SageMakerRole"
2827
PREFIX = "sagemaker/beta-automl-xgboost"
2928
HOSTING_INSTANCE_TYPE = "ml.c4.xlarge"
@@ -40,26 +39,10 @@
4039
# use a succeeded AutoML job to test describe and list candidates method, otherwise tests will run too long
4140
AUTO_ML_JOB_NAME = "sagemaker-auto-gamma-ml-test"
4241

43-
EXPECTED_DEFAULT_INPUT_CONFIG = [
44-
{
45-
"DataSource": {
46-
"S3DataSource": {
47-
"S3DataType": "S3Prefix",
48-
"S3Uri": "s3://sagemaker-us-east-2-{}/{}/input/iris_training.csv".format(
49-
DEV_ACCOUNT, PREFIX
50-
),
51-
}
52-
},
53-
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
54-
}
55-
]
5642
EXPECTED_DEFAULT_JOB_CONFIG = {
5743
"CompletionCriteria": {"MaxCandidates": 3},
5844
"SecurityConfig": {"EnableInterContainerTrafficEncryption": False},
5945
}
60-
EXPECTED_DEFAULT_OUTPUT_CONFIG = {
61-
"S3OutputPath": "s3://sagemaker-us-east-2-{}/".format(DEV_ACCOUNT)
62-
}
6346

6447

6548
def test_auto_ml_fit(sagemaker_session):
@@ -102,7 +85,7 @@ def test_auto_ml_input_object_fit(sagemaker_session):
10285

10386

10487
def test_auto_ml_fit_optional_args(sagemaker_session):
105-
output_path = "s3://sagemaker-us-east-2-{}/{}".format(DEV_ACCOUNT, "specified_ouput_path")
88+
output_path = "s3://{}/{}".format(sagemaker_session.default_bucket(), "specified_ouput_path")
10689
problem_type = "MulticlassClassification"
10790
job_objective = {"MetricName": "Accuracy"}
10891
auto_ml = AutoML(
@@ -138,6 +121,23 @@ def test_auto_ml_invalid_target_attribute(sagemaker_session):
138121

139122

140123
def test_auto_ml_describe_auto_ml_job(sagemaker_session):
124+
expected_default_input_config = [
125+
{
126+
"DataSource": {
127+
"S3DataSource": {
128+
"S3DataType": "S3Prefix",
129+
"S3Uri": "s3://{}/{}/input/iris_training.csv".format(
130+
sagemaker_session.default_bucket(), PREFIX
131+
),
132+
}
133+
},
134+
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
135+
}
136+
]
137+
expected_default_output_config = {
138+
"S3OutputPath": "s3://{}/".format(sagemaker_session.default_bucket())
139+
}
140+
141141
auto_ml = AutoML(
142142
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
143143
)
@@ -146,9 +146,9 @@ def test_auto_ml_describe_auto_ml_job(sagemaker_session):
146146
assert desc["AutoMLJobName"] == AUTO_ML_JOB_NAME
147147
assert desc["AutoMLJobStatus"] == "Completed"
148148
assert isinstance(desc["BestCandidate"], dict)
149-
assert desc["InputDataConfig"] == EXPECTED_DEFAULT_INPUT_CONFIG
149+
assert desc["InputDataConfig"] == expected_default_input_config
150150
assert desc["AutoMLJobConfig"] == EXPECTED_DEFAULT_JOB_CONFIG
151-
assert desc["OutputDataConfig"] == EXPECTED_DEFAULT_OUTPUT_CONFIG
151+
assert desc["OutputDataConfig"] == expected_default_output_config
152152

153153

154154
def test_list_candidates(sagemaker_session):

tests/integ/test_model_monitor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from tests.integ.kms_utils import get_or_create_kms_key
4343
from tests.integ.retry import retries
4444

45-
ROLE = "arn:aws:iam::142577830533:role/SageMakerRole"
45+
ROLE = "SageMakerRole"
4646
INSTANCE_COUNT = 1
4747
INSTANCE_TYPE = "ml.m5.xlarge"
4848
VOLUME_SIZE_IN_GB = 40
@@ -63,7 +63,7 @@
6363
DEFAULT_EXECUTION_MAX_RUNTIME_IN_SECONDS = 3600
6464
DEFAULT_IMAGE_SUFFIX = ".com/sagemaker-model-monitor-analyzer"
6565

66-
UPDATED_ROLE = "arn:aws:iam::142577830533:role/SageMakerRole"
66+
UPDATED_ROLE = "SageMakerRole"
6767
UPDATED_INSTANCE_COUNT = 2
6868
UPDATED_INSTANCE_TYPE = "ml.m5.2xlarge"
6969
UPDATED_VOLUME_SIZE_IN_GB = 50
@@ -99,7 +99,7 @@ def predictor(sagemaker_session, tf_full_version):
9999
):
100100
model = Model(
101101
model_data=model_data,
102-
role="SageMakerRole",
102+
role=ROLE,
103103
framework_version=tf_full_version,
104104
sagemaker_session=sagemaker_session,
105105
)

tests/integ/test_processing.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tests.integ import DATA_DIR
2424
from tests.integ.kms_utils import get_or_create_kms_key
2525

26-
ROLE = "arn:aws:iam::142577830533:role/SageMakerRole"
26+
ROLE = "SageMakerRole"
2727

2828

2929
@pytest.fixture(scope="module")
@@ -57,8 +57,6 @@ def output_kms_key(sagemaker_session):
5757

5858

5959
def test_sklearn(sagemaker_session, sklearn_full_version, cpu_instance_type):
60-
logging.getLogger().setLevel(logging.DEBUG) # TODO-reinvent-2019: REMOVE
61-
6260
script_path = os.path.join(DATA_DIR, "dummy_script.py")
6361
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
6462

0 commit comments

Comments
 (0)