2323from sagemaker .sklearn import SKLearn
2424from sagemaker .utils import retry_with_backoff
2525from tests .integ .sagemaker .experiments .helpers import name , cleanup_exp_resources
26- from sagemaker .experiments .run import Run , RUN_NAME_BASE
26+ from sagemaker .experiments .run import (
27+ Run ,
28+ RUN_NAME_BASE ,
29+ DELIMITER ,
30+ RUN_TC_TAG_KEY ,
31+ RUN_TC_TAG_VALUE ,
32+ )
2733from sagemaker .experiments ._helper import _DEFAULT_ARTIFACT_PREFIX
2834
2935
@@ -60,24 +66,20 @@ def lineage_artifact_path(tempdir):
6066 return file_path
6167
6268
63- def test_local_run (
69+ file_artifact_name = f"file-artifact-{ name ()} "
70+ lineage_artifact_name = f"lineage-file-artifact-{ name ()} "
71+ metric_name = "test-x-step"
72+
73+
74+ def test_local_run_with_load_specifying_names (
6475 sagemaker_session , artifact_file_path , artifact_directory , lineage_artifact_path
6576):
6677 exp_name = f"my-local-exp-{ name ()} "
67- exp_name2 = f"{ exp_name } -2"
68- file_artifact_name = "file-artifact"
69- lineage_artifact_name = "lineage-file-artifact"
70- table_artifact_name = "TestTableTitle"
71- metric_name = "test-x-step"
72-
73- with cleanup_exp_resources (
74- exp_names = [exp_name , exp_name2 ], sagemaker_session = sagemaker_session
75- ):
78+ with cleanup_exp_resources (exp_names = [exp_name ], sagemaker_session = sagemaker_session ):
7679 # Run name is not provided, will create a new TC
7780 with Run .init (experiment_name = exp_name , sagemaker_session = sagemaker_session ) as run1 :
7881 run1_name = run1 .run_name
79- run1_exp_name = run1 .experiment_name
80- run1_trial_name = run1 ._trial .trial_name
82+ assert RUN_NAME_BASE in run1_name
8183
8284 run1 .log_parameter ("p1" , 1.0 )
8385 run1 .log_parameter ("p2" , "p2-value" )
@@ -86,73 +88,134 @@ def test_local_run(
8688 run1 .log_artifact_file (file_path = artifact_file_path , name = file_artifact_name )
8789 run1 .log_artifact_directory (directory = artifact_directory , is_output = False )
8890 run1 .log_lineage_artifact (file_path = lineage_artifact_path , name = lineage_artifact_name )
89- run1 .log_table (
90- title = table_artifact_name , values = {"x" : [1 , 2 , 3 ], "y" : [4 , 5 , 6 ]}, is_output = False
91- )
9291
9392 for i in range (_MetricsManager ._BATCH_SIZE ):
9493 run1 .log_metric (name = metric_name , value = i , step = i )
9594
96- assert RUN_NAME_BASE in run1_name
97-
98- def validate_tc_artifact_association (is_output , expected_artifact_name ):
99- if is_output :
100- # It's an output association from the tc
101- response = sagemaker_session .sagemaker_client .list_associations (
102- SourceArn = tc .trial_component_arn
103- )
104- else :
105- # It's an input association to the tc
106- response = sagemaker_session .sagemaker_client .list_associations (
107- DestinationArn = tc .trial_component_arn
108- )
109- associations = response ["AssociationSummaries" ]
110-
111- assert len (associations ) == 1
112- summary = associations [0 ]
113- if is_output :
114- assert summary ["SourceArn" ] == tc .trial_component_arn
115- assert summary ["DestinationName" ] == expected_artifact_name
116- else :
117- assert summary ["DestinationArn" ] == tc .trial_component_arn
118- assert summary ["SourceName" ] == expected_artifact_name
119-
120- # Run name is passed from the name of an existing TC.
121- # Meanwhile, the experiment_name is changed.
122- # Should load TC from backend.
123- with Run .init (
124- experiment_name = exp_name2 ,
95+ with Run .load (
96+ experiment_name = exp_name ,
12597 run_name = run1_name ,
12698 sagemaker_session = sagemaker_session ,
12799 ) as run2 :
128- assert run1_exp_name != run2 .experiment_name
129- assert run1_trial_name != run2 ._trial .trial_name
130- assert run1_name == run2 .run_name
131-
132- tc = run2 ._trial_component
133- assert tc .parameters == {"p1" : 1.0 , "p2" : "p2-value" , "p3" : 2.0 , "p4" : "p4-value" }
134-
135- s3_prefix = f"s3://{ sagemaker_session .default_bucket ()} /{ _DEFAULT_ARTIFACT_PREFIX } "
136- assert s3_prefix in tc .output_artifacts [file_artifact_name ].value
137- assert "text/plain" == tc .output_artifacts [file_artifact_name ].media_type
138- assert s3_prefix in tc .input_artifacts ["artifact_file1" ].value
139- assert "text/plain" == tc .input_artifacts ["artifact_file1" ].media_type
140- assert s3_prefix in tc .input_artifacts ["artifact_file2" ].value
141- assert "text/plain" == tc .input_artifacts ["artifact_file2" ].media_type
142-
143- assert len (tc .metrics ) == 1
144- metric_summary = tc .metrics [0 ]
145- assert metric_summary .metric_name == metric_name
146- assert metric_summary .max == 9.0
147- assert metric_summary .min == 0.0
100+ assert run2 .run_name == run1_name
101+ assert run2 ._trial_component .trial_component_name == f"{ exp_name } { DELIMITER } { run1_name } "
102+ _check_run_from_local_end_result (
103+ sagemaker_session = sagemaker_session , tc = run2 ._trial_component
104+ )
105+
148106
149- validate_tc_artifact_association (
150- is_output = True ,
151- expected_artifact_name = lineage_artifact_name ,
107+ def _check_run_from_local_end_result (sagemaker_session , tc ):
108+ def validate_tc_artifact_association (is_output , expected_artifact_name ):
109+ if is_output :
110+ # It's an output association from the tc
111+ response = sagemaker_session .sagemaker_client .list_associations (
112+ SourceArn = tc .trial_component_arn
113+ )
114+ else :
115+ # It's an input association to the tc
116+ response = sagemaker_session .sagemaker_client .list_associations (
117+ DestinationArn = tc .trial_component_arn
152118 )
153- validate_tc_artifact_association (
154- is_output = False ,
155- expected_artifact_name = table_artifact_name ,
119+ associations = response ["AssociationSummaries" ]
120+
121+ assert len (associations ) == 1
122+ summary = associations [0 ]
123+ if is_output :
124+ assert summary ["SourceArn" ] == tc .trial_component_arn
125+ assert summary ["DestinationName" ] == expected_artifact_name
126+ else :
127+ assert summary ["DestinationArn" ] == tc .trial_component_arn
128+ assert summary ["SourceName" ] == expected_artifact_name
129+
130+ assert tc .parameters == {"p1" : 1.0 , "p2" : "p2-value" , "p3" : 2.0 , "p4" : "p4-value" }
131+
132+ s3_prefix = f"s3://{ sagemaker_session .default_bucket ()} /{ _DEFAULT_ARTIFACT_PREFIX } "
133+ assert s3_prefix in tc .output_artifacts [file_artifact_name ].value
134+ assert "text/plain" == tc .output_artifacts [file_artifact_name ].media_type
135+ assert s3_prefix in tc .input_artifacts ["artifact_file1" ].value
136+ assert "text/plain" == tc .input_artifacts ["artifact_file1" ].media_type
137+ assert s3_prefix in tc .input_artifacts ["artifact_file2" ].value
138+ assert "text/plain" == tc .input_artifacts ["artifact_file2" ].media_type
139+
140+ assert len (tc .metrics ) == 1
141+ metric_summary = tc .metrics [0 ]
142+ assert metric_summary .metric_name == metric_name
143+ assert metric_summary .max == 9.0
144+ assert metric_summary .min == 0.0
145+
146+ validate_tc_artifact_association (
147+ is_output = True ,
148+ expected_artifact_name = lineage_artifact_name ,
149+ )
150+
151+
152+ def test_two_local_run_init_with_same_run_name_and_different_exp_names (sagemaker_session ):
153+ exp_name1 = f"my-two-local-exp1-{ name ()} "
154+ exp_name2 = f"my-two-local-exp2-{ name ()} "
155+ run_name = "test-run"
156+ with cleanup_exp_resources (
157+ exp_names = [exp_name1 , exp_name2 ], sagemaker_session = sagemaker_session
158+ ):
159+ # Run name is not provided, will create a new TC
160+ with Run .init (
161+ experiment_name = exp_name1 , run_name = run_name , sagemaker_session = sagemaker_session
162+ ) as run1 :
163+ pass
164+ with Run .init (
165+ experiment_name = exp_name2 , run_name = run_name , sagemaker_session = sagemaker_session
166+ ) as run2 :
167+ pass
168+
169+ assert run1 .experiment_name != run2 .experiment_name
170+ assert run1 .run_name == run2 .run_name
171+ assert (
172+ run1 ._trial_component .trial_component_name != run2 ._trial_component .trial_component_name
173+ )
174+ assert run1 ._trial_component .trial_component_name == f"{ exp_name1 } { DELIMITER } { run_name } "
175+ assert run2 ._trial_component .trial_component_name == f"{ exp_name2 } { DELIMITER } { run_name } "
176+
177+
178+ @pytest .mark .parametrize (
179+ "input_names" ,
180+ [
181+ (f"my-local-exp-{ name ()} " , "test-run" , None ), # both have delimiter -
182+ ("my-test-1" , "my-test-1" , None ), # exp_name equals run_name
183+ ("my-test-3" , "my-test-3-run" , None ), # <exp_name><delimiter> is subset of run_name
184+ ("x" * 59 , "test-run" , None ), # long exp_name
185+ ("test-exp" , "y" * 59 , None ), # long run_name
186+ ("x" * 59 , "y" * 59 , None ), # long exp_name and run_name
187+ ("my-test4" , "test-run" , "run-display-name-test" ), # with supplied display name
188+ ],
189+ )
190+ def test_run_name_vs_trial_component_name_edge_cases (
191+ sagemaker_session , artifact_file_path , artifact_directory , lineage_artifact_path , input_names
192+ ):
193+ exp_name , run_name , run_display_name = input_names
194+ with cleanup_exp_resources (exp_names = [exp_name ], sagemaker_session = sagemaker_session ):
195+ with Run .init (
196+ experiment_name = exp_name ,
197+ sagemaker_session = sagemaker_session ,
198+ run_name = run_name ,
199+ run_display_name = run_display_name ,
200+ ) as run1 :
201+ assert not run1 ._experiment .tags
202+ assert not run1 ._trial .tags
203+ tags = run1 ._trial_component .tags
204+ assert len (tags ) == 1
205+ assert tags [0 ]["Key" ] == RUN_TC_TAG_KEY
206+ assert tags [0 ]["Value" ] == RUN_TC_TAG_VALUE
207+
208+ with Run .load (
209+ experiment_name = exp_name ,
210+ run_name = run_name ,
211+ sagemaker_session = sagemaker_session ,
212+ ) as run2 :
213+ assert run2 .experiment_name == exp_name
214+ assert run2 .run_name == run_name
215+ assert run2 ._trial_component .trial_component_name == f"{ exp_name } { DELIMITER } { run_name } "
216+ assert run2 ._trial_component .display_name in (
217+ run_display_name ,
218+ run2 ._trial_component .trial_component_name ,
156219 )
157220
158221
@@ -201,7 +264,7 @@ def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, j
201264 )
202265 estimator .fit (
203266 job_name = f"train-job-{ name ()} " ,
204- experiment_config = run .experiment_config ,
267+ experiment_config = run ._experiment_config ,
205268 wait = True , # wait the training job to finish
206269 logs = "None" , # set to "All" to display logs fetched from the training job
207270 )
@@ -248,7 +311,7 @@ def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, j
248311 )
249312 estimator .fit (
250313 job_name = f"train-job-{ name ()} " ,
251- experiment_config = run .experiment_config ,
314+ experiment_config = run ._experiment_config ,
252315 wait = True , # wait the training job to finish
253316 logs = "None" , # set to "All" to display logs fetched from the training job
254317 )
0 commit comments