@@ -205,7 +205,9 @@ def __init__(
205205 )
206206
207207 if not _TrialComponent ._trial_component_is_associated_to_trial (
208- self ._trial_component .trial_component_name , self ._trial .trial_name , sagemaker_session
208+ self ._trial_component .trial_component_name ,
209+ self ._trial .trial_name ,
210+ sagemaker_session ,
209211 ):
210212 self ._trial .add_trial_component (self ._trial_component )
211213
@@ -340,7 +342,9 @@ def log_precision_recall(
340342 if positive_label is not None :
341343 kwargs ["pos_label" ] = positive_label
342344
343- precision , recall , _ = precision_recall_curve (y_true , predicted_probabilities , ** kwargs )
345+ precision , recall , _ = precision_recall_curve (
346+ y_true , predicted_probabilities , ** kwargs
347+ )
344348
345349 kwargs ["average" ] = "micro"
346350 ap = average_precision_score (y_true , predicted_probabilities , ** kwargs )
@@ -560,7 +564,9 @@ def _is_input_valid(input_type, field_name, field_value) -> bool:
560564 field_name (str): The name of the field to be checked.
561565 field_value (str or int or float): The value of the field to be checked.
562566 """
563- if isinstance (field_value , Number ) and (isnan (field_value ) or isinf (field_value )):
567+ if isinstance (field_value , Number ) and (
568+ isnan (field_value ) or isinf (field_value )
569+ ):
564570 logger .warning (
565571 "Failed to log %s %s. Received invalid value: %s." ,
566572 input_type ,
@@ -622,10 +628,14 @@ def _verify_trial_component_artifacts_length(self, is_output):
622628 err_msg_template = "Cannot add more than {} {}_artifacts under run"
623629 if is_output :
624630 if len (self ._trial_component .output_artifacts ) >= MAX_RUN_TC_ARTIFACTS_LEN :
625- raise ValueError (err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "output" ))
631+ raise ValueError (
632+ err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "output" )
633+ )
626634 else :
627635 if len (self ._trial_component .input_artifacts ) >= MAX_RUN_TC_ARTIFACTS_LEN :
628- raise ValueError (err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "input" ))
636+ raise ValueError (
637+ err_msg_template .format (MAX_RUN_TC_ARTIFACTS_LEN , "input" )
638+ )
629639
630640 @staticmethod
631641 def _generate_trial_component_name (run_name : str , experiment_name : str ) -> str :
@@ -646,20 +656,28 @@ def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
646656 """
647657 buffer = 1 # leave length buffers for delimiters
648658 max_len = int (MAX_NAME_LEN_IN_BACKEND / 2 ) - buffer
649- err_msg_template = "The {} (length: {}) must have length less than or equal to {}"
659+ err_msg_template = (
660+ "The {} (length: {}) must have length less than or equal to {}"
661+ )
650662 if len (run_name ) > max_len :
651- raise ValueError (err_msg_template .format ("run_name" , len (run_name ), max_len ))
663+ raise ValueError (
664+ err_msg_template .format ("run_name" , len (run_name ), max_len )
665+ )
652666 if len (experiment_name ) > max_len :
653667 raise ValueError (
654- err_msg_template .format ("experiment_name" , len (experiment_name ), max_len )
668+ err_msg_template .format (
669+ "experiment_name" , len (experiment_name ), max_len
670+ )
655671 )
656672 trial_component_name = "{}{}{}" .format (experiment_name , DELIMITER , run_name )
657673 # due to mixed-case concerns on the backend
658674 trial_component_name = trial_component_name .lower ()
659675 return trial_component_name
660676
661677 @staticmethod
662- def _extract_run_name_from_tc_name (trial_component_name : str , experiment_name : str ) -> str :
678+ def _extract_run_name_from_tc_name (
679+ trial_component_name : str , experiment_name : str
680+ ) -> str :
663681 """Extract the user supplied run name from a trial component name.
664682
665683 Args:
@@ -676,7 +694,9 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
676694 )
677695
678696 @staticmethod
679- def _append_run_tc_label_to_tags (tags : Optional [List [Dict [str , str ]]] = None ) -> list :
697+ def _append_run_tc_label_to_tags (
698+ tags : Optional [List [Dict [str , str ]]] = None
699+ ) -> list :
680700 """Append the run trial component label to tags used to create a trial component.
681701
682702 Args:
0 commit comments