5555 TEST_RUN_DISPLAY_NAME ,
5656 TEST_ARTIFACT_BUCKET ,
5757 TEST_ARTIFACT_PREFIX ,
58+ TEST_TAGS
5859)
5960
6061
@@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
155156
156157
157158@pytest .mark .parametrize (
158- ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
159+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" , "expected_tags" ),
159160 [
160- ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
161+ ({}, None , _DEFAULT_ARTIFACT_PREFIX , None ),
161162 (
162163 {
163164 "artifact_bucket" : TEST_ARTIFACT_BUCKET ,
164165 "artifact_prefix" : TEST_ARTIFACT_PREFIX ,
166+ "tags" : TEST_TAGS
165167 },
166168 TEST_ARTIFACT_BUCKET ,
167169 TEST_ARTIFACT_PREFIX ,
170+ TEST_TAGS
168171 ),
169172 ],
170173)
171174@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
172- @patch (
173- "sagemaker.experiments.run.Experiment._load_or_create" ,
174- MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME )),
175- )
176175@patch (
177176 "sagemaker.experiments.run._Trial._load_or_create" ,
178177 MagicMock (side_effect = mock_trial_load_or_create_func ),
@@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
189188 kwargs ,
190189 expected_artifact_bucket ,
191190 expected_artifact_prefix ,
191+ expected_tags
192192):
193193 client = sagemaker_session .sagemaker_client
194194 job_name = "my-train-job"
@@ -220,19 +220,22 @@ def test_run_load_no_run_name_and_in_train_job(
220220 }
221221 ]
222222 }
223- with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
224- assert run_obj ._in_load
225- assert not run_obj ._inside_init_context
226- assert run_obj ._inside_load_context
227- assert run_obj .run_name == TEST_RUN_NAME
228- assert run_obj ._trial_component .trial_component_name == expected_tc_name
229- assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
230- assert run_obj ._trial
231- assert run_obj .experiment_name == TEST_EXP_NAME
232- assert run_obj ._experiment
233- assert run_obj .experiment_config == exp_config
234- assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
235- assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
223+ expmock = MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME ,tags = expected_tags ))
224+ with patch ("sagemaker.experiments.run.Experiment._load_or_create" , expmock ):
225+ with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
226+ assert run_obj ._in_load
227+ assert not run_obj ._inside_init_context
228+ assert run_obj ._inside_load_context
229+ assert run_obj .run_name == TEST_RUN_NAME
230+ assert run_obj ._trial_component .trial_component_name == expected_tc_name
231+ assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
232+ assert run_obj ._trial
233+ assert run_obj .experiment_name == TEST_EXP_NAME
234+ assert run_obj ._experiment
235+ assert run_obj .experiment_config == exp_config
236+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
237+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
238+ assert run_obj ._experiment .tags == expected_tags
236239
237240 client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
238241 run_obj ._trial .add_trial_component .assert_not_called ()
0 commit comments