@@ -150,9 +150,8 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
150150 sagemaker_session = sagemaker_session ,
151151 )
152152
153- assert (
154- f"The run_name (length: { MAX_NAME_LEN_IN_BACKEND } ) must have length less than"
155- in str (err )
153+ assert f"The run_name (length: { MAX_NAME_LEN_IN_BACKEND } ) must have length less than" in str (
154+ err
156155 )
157156
158157
@@ -224,9 +223,7 @@ def test_run_load_no_run_name_and_in_train_job(
224223 }
225224 ]
226225 }
227- expmock = MagicMock (
228- return_value = Experiment (experiment_name = TEST_EXP_NAME , tags = expected_tags )
229- )
226+ expmock = MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME , tags = expected_tags ))
230227 with patch ("sagemaker.experiments.run.Experiment._load_or_create" , expmock ):
231228 with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
232229 assert run_obj ._in_load
@@ -239,12 +236,8 @@ def test_run_load_no_run_name_and_in_train_job(
239236 assert run_obj .experiment_name == TEST_EXP_NAME
240237 assert run_obj ._experiment
241238 assert run_obj .experiment_config == exp_config
242- assert (
243- run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
244- )
245- assert (
246- run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
247- )
239+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
240+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
248241 assert run_obj ._experiment .tags == expected_tags
249242
250243 client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
@@ -269,9 +262,7 @@ def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg(
269262 with load_run (sagemaker_session = sagemaker_session ):
270263 pass
271264
272- assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str (
273- err
274- )
265+ assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str (err )
275266
276267
277268def test_run_load_no_run_name_and_not_in_train_job (run_obj , sagemaker_session ):
@@ -291,9 +282,7 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(
291282
292283 # experiment_name is given but is not supplied along with the run_name so it's ignored.
293284 with pytest .raises (RuntimeError ) as err :
294- with load_run (
295- experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session
296- ):
285+ with load_run (experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session ):
297286 pass
298287
299288 assert "Failed to load a Run object" in str (err )
@@ -621,17 +610,12 @@ def test_log_output_artifact(run_obj):
621610 with run_obj :
622611 run_obj .log_file ("foo.txt" , "name" , "whizz/bang" )
623612 run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
624- assert (
625- "whizz/bang" == run_obj ._trial_component .output_artifacts ["name" ].media_type
626- )
613+ assert "whizz/bang" == run_obj ._trial_component .output_artifacts ["name" ].media_type
627614
628615 run_obj .log_file ("foo.txt" )
629616 run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
630617 assert "foo.txt" in run_obj ._trial_component .output_artifacts
631- assert (
632- "text/plain"
633- == run_obj ._trial_component .output_artifacts ["foo.txt" ].media_type
634- )
618+ assert "text/plain" == run_obj ._trial_component .output_artifacts ["foo.txt" ].media_type
635619
636620
637621def test_log_input_artifact_outside_run_context (run_obj ):
@@ -648,51 +632,36 @@ def test_log_input_artifact(run_obj):
648632 with run_obj :
649633 run_obj .log_file ("foo.txt" , "name" , "whizz/bang" , is_output = False )
650634 run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
651- assert (
652- "whizz/bang" == run_obj ._trial_component .input_artifacts ["name" ].media_type
653- )
635+ assert "whizz/bang" == run_obj ._trial_component .input_artifacts ["name" ].media_type
654636
655637 run_obj .log_file ("foo.txt" , is_output = False )
656638 run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" )
657639 assert "foo.txt" in run_obj ._trial_component .input_artifacts
658- assert (
659- "text/plain"
660- == run_obj ._trial_component .input_artifacts ["foo.txt" ].media_type
661- )
640+ assert "text/plain" == run_obj ._trial_component .input_artifacts ["foo.txt" ].media_type
662641
663642
664643def test_log_multiple_inputs (run_obj ):
665644 with run_obj :
666645 for index in range (0 , MAX_RUN_TC_ARTIFACTS_LEN ):
667646 file_path = "foo" + str (index ) + ".txt"
668647 run_obj ._trial_component .input_artifacts [file_path ] = {
669- "foo" : TrialComponentArtifact (
670- value = "baz" + str (index ), media_type = "text/text"
671- )
648+ "foo" : TrialComponentArtifact (value = "baz" + str (index ), media_type = "text/text" )
672649 }
673650 with pytest .raises (ValueError ) as error :
674651 run_obj .log_artifact ("foo.txt" , "name" , "whizz/bang" , False )
675- assert (
676- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts"
677- in str (error )
678- )
652+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts" in str (error )
679653
680654
681655def test_log_multiple_outputs (run_obj ):
682656 with run_obj :
683657 for index in range (0 , MAX_RUN_TC_ARTIFACTS_LEN ):
684658 file_path = "foo" + str (index ) + ".txt"
685659 run_obj ._trial_component .output_artifacts [file_path ] = {
686- "foo" : TrialComponentArtifact (
687- value = "baz" + str (index ), media_type = "text/text"
688- )
660+ "foo" : TrialComponentArtifact (value = "baz" + str (index ), media_type = "text/text" )
689661 }
690662 with pytest .raises (ValueError ) as error :
691663 run_obj .log_artifact ("foo.txt" , "name" , "whizz/bang" )
692- assert (
693- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts"
694- in str (error )
695- )
664+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts" in str (error )
696665
697666
698667def test_log_multiple_input_artifacts (run_obj ):
@@ -722,10 +691,7 @@ def test_log_multiple_input_artifacts(run_obj):
722691 # log an extra input artifact, should raise exception
723692 with pytest .raises (ValueError ) as error :
724693 run_obj .log_file ("foo.txt" , "name" , "whizz/bang" , is_output = False )
725- assert (
726- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts"
727- in str (error )
728- )
694+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } input_artifacts" in str (error )
729695
730696
731697def test_log_multiple_output_artifacts (run_obj ):
@@ -750,10 +716,7 @@ def test_log_multiple_output_artifacts(run_obj):
750716 # log an extra output artifact, should raise exception
751717 with pytest .raises (ValueError ) as error :
752718 run_obj .log_file ("foo.txt" , "name" , "whizz/bang" )
753- assert (
754- f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts"
755- in str (error )
756- )
719+ assert f"Cannot add more than { MAX_RUN_TC_ARTIFACTS_LEN } output_artifacts" in str (error )
757720
758721
759722def test_log_precision_recall_outside_run_context (run_obj ):
@@ -820,10 +783,7 @@ def test_log_precision_recall_invalid_input(run_obj):
820783 no_skill = no_skill ,
821784 is_output = False ,
822785 )
823- assert (
824- "Lengths mismatch between true labels and predicted probabilities"
825- in str (error )
826- )
786+ assert "Lengths mismatch between true labels and predicted probabilities" in str (error )
827787
828788
829789def test_log_confusion_matrix_outside_run_context (run_obj ):
@@ -921,9 +881,7 @@ def test_log_roc_curve_invalid_input(run_obj):
921881
922882 with run_obj :
923883 with pytest .raises (ValueError ) as error :
924- run_obj .log_roc_curve (
925- y_true , y_scores , title = "TestROCCurve" , is_output = False
926- )
884+ run_obj .log_roc_curve (y_true , y_scores , title = "TestROCCurve" , is_output = False )
927885 assert "Lengths mismatch between true labels and predicted scores" in str (error )
928886
929887
@@ -940,18 +898,10 @@ def test_log_roc_curve_invalid_input(run_obj):
940898@patch ("sagemaker.experiments.run._TrialComponent.list" )
941899@patch ("sagemaker.experiments.run._TrialComponent.search" )
942900def test_list (mock_tc_search , mock_tc_list , mock_tc_load , run_obj , sagemaker_session ):
943- start_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (
944- hours = 1
945- )
946- end_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (
947- hours = 2
948- )
949- creation_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (
950- hours = 3
951- )
952- last_modified_time = datetime .datetime .now (
953- datetime .timezone .utc
954- ) + datetime .timedelta (hours = 4 )
901+ start_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 1 )
902+ end_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 2 )
903+ creation_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 3 )
904+ last_modified_time = datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (hours = 4 )
955905 tc_list_len = 20
956906 tc_list_len_half = int (tc_list_len / 2 )
957907 mock_tc_search .side_effect = [
@@ -1039,9 +989,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
1039989 assert run ._experiment
1040990 assert run ._trial
1041991 assert isinstance (run ._trial_component , _TrialComponent )
1042- assert (
1043- run ._trial_component .trial_component_name
1044- == Run ._generate_trial_component_name ("a" + str (i ), TEST_EXP_NAME )
992+ assert run ._trial_component .trial_component_name == Run ._generate_trial_component_name (
993+ "a" + str (i ), TEST_EXP_NAME
1045994 )
1046995 assert run ._in_load is False
1047996 assert run ._inside_load_context is False
@@ -1054,9 +1003,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
10541003@patch ("sagemaker.experiments.run._TrialComponent.list" )
10551004def test_list_empty (mock_tc_list , sagemaker_session ):
10561005 mock_tc_list .return_value = []
1057- assert [] == list_runs (
1058- experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session
1059- )
1006+ assert [] == list_runs (experiment_name = TEST_EXP_NAME , sagemaker_session = sagemaker_session )
10601007
10611008
10621009@patch (
@@ -1122,10 +1069,7 @@ def test_exit_fail(sagemaker_session, run_obj):
11221069 except ValueError :
11231070 pass
11241071
1125- assert (
1126- run_obj ._trial_component .status .primary_status
1127- == _TrialComponentStatusType .Failed .value
1128- )
1072+ assert run_obj ._trial_component .status .primary_status == _TrialComponentStatusType .Failed .value
11291073 assert run_obj ._trial_component .status .message
11301074 assert isinstance (run_obj ._trial_component .end_time , datetime .datetime )
11311075
@@ -1182,9 +1126,7 @@ def _verify_tc_status_before_enter_init(trial_component):
11821126 assert not trial_component .status
11831127
11841128
1185- def _verify_tc_status_when_entering (
1186- trial_component , init_start_time = None , has_completed = False
1187- ):
1129+ def _verify_tc_status_when_entering (trial_component , init_start_time = None , has_completed = False ):
11881130 if not init_start_time :
11891131 assert isinstance (trial_component .start_time , datetime .datetime )
11901132 now = datetime .datetime .now (dateutil .tz .tzlocal ())
@@ -1194,17 +1136,11 @@ def _verify_tc_status_when_entering(
11941136
11951137 if not has_completed :
11961138 assert not trial_component .end_time
1197- assert (
1198- trial_component .status .primary_status
1199- == _TrialComponentStatusType .InProgress .value
1200- )
1139+ assert trial_component .status .primary_status == _TrialComponentStatusType .InProgress .value
12011140
12021141
12031142def _verify_tc_status_when_successfully_exit (trial_component , old_end_time = None ):
1204- assert (
1205- trial_component .status .primary_status
1206- == _TrialComponentStatusType .Completed .value
1207- )
1143+ assert trial_component .status .primary_status == _TrialComponentStatusType .Completed .value
12081144 assert isinstance (trial_component .start_time , datetime .datetime )
12091145 assert isinstance (trial_component .end_time , datetime .datetime )
12101146 if old_end_time :
0 commit comments