@@ -110,50 +110,87 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting):
110110 self ._i_rep = None
111111 self ._i_treat = None
112112
113- def __str__ (self ):
def _format_header_str(self):
    """Return the banner line that opens the string representation.

    The banner embeds the name of the instance's concrete class, so a
    subclass is reported under its own name. No trailing newline is
    included.
    """
    return f"================== {type(self).__name__} Object =================="
116+
def _format_score_info_str(self):
    """Return the one-line description of the score function.

    ``self.score`` is converted with ``str``. No trailing newline is
    included.
    """
    score_text = str(self.score)
    return "Score function: " + score_text
119+
def _format_learner_info_str(self):
    """Return the "Machine Learner" section of the string representation.

    The section lists every entry of ``self.learner`` (when it is not
    ``None``). When ``self.nuisance_loss`` is not ``None`` it also lists the
    out-of-sample losses, grouped into regressors (reported as RMSE) and
    classifiers (reported as Log Loss) according to ``self._is_classifier``.

    Returns
    -------
    str
        The section text with surrounding whitespace stripped; an empty
        string when there is nothing to report.
    """
    lines = []

    if self.learner is not None:
        lines.extend(f"Learner {name}: {learner}" for name, learner in self.learner.items())

    if self.nuisance_loss is not None:
        lines.append("Out-of-sample Performance:")
        if self._is_classifier:
            # Group directly from ``_is_classifier`` so that the decision to
            # print a group header and the entries listed below it always
            # come from the same mapping; a header can never appear without
            # its learners.
            regressors = [name for name, is_clf in self._is_classifier.items() if not is_clf]
            classifiers = [name for name, is_clf in self._is_classifier.items() if is_clf]
            groups = (
                ("Regression:", "RMSE", regressors),
                ("Classification:", "Log Loss", classifiers),
            )
            for title, metric, names in groups:
                if not names:
                    continue
                lines.append(title)
                # A learner without a recorded loss is shown as "N/A"
                # instead of raising a KeyError.
                lines.extend(
                    f"Learner {name} {metric}: {self.nuisance_loss.get(name, 'N/A')}" for name in names
                )
        else:
            # No learner type information is available yet.
            lines.append(" (Run .fit() to see out-of-sample performance)")

    return "\n".join(lines).strip()
133147
def _format_resampling_info_str(self):
    """Return the "Resampling" section of the string representation.

    The section always reports ``self.n_folds`` and ``self.n_rep``. When
    ``self._is_cluster_data`` is truthy it is preceded by a line reporting
    ``self._n_folds_per_cluster``. No trailing newline is included.
    """
    rows = []
    if self._is_cluster_data:
        rows.append(f"No. folds per cluster: {self._n_folds_per_cluster}")
    rows.append(f"No. folds: {self.n_folds}")
    rows.append(f"No. repeated sample splits: {self.n_rep}")
    return "\n".join(rows)
157+
def _format_additional_info_str(self):
    """Return extra text to append to the string representation.

    This is an extension point: the base implementation contributes
    nothing and returns an empty string. A subclass that wants an extra
    section overrides this method and returns only the section body; it
    should not include the 'Additional Information' header itself.
    """
    return ""
166+
def __str__(self):
    """Return the multi-section, human-readable summary of the object.

    The text starts with the banner from ``_format_header_str`` and is
    followed by the sections "Data Summary", "Score & Algorithm",
    "Machine Learner", "Resampling" and "Fit Summary", each introduced by
    a dashed title line and separated from the previous one by a blank
    line. When ``_format_additional_info_str`` returns a non-empty value,
    an "Additional Information" section is appended in the same layout.
    """
    banner = self._format_header_str()
    # The data summary is produced by the data object held in ``_dml_data``;
    # its implementation is not part of this block.
    sections = [
        ("Data Summary", self._dml_data._data_summary_str()),
        ("Score & Algorithm", self._format_score_info_str()),
        ("Machine Learner", self._format_learner_info_str()),
        ("Resampling", self._format_resampling_info_str()),
        ("Fit Summary", str(self.summary)),
    ]

    extra = self._format_additional_info_str()
    if extra:
        sections.append(("Additional Information", extra))

    rule = "------------------"
    # Each rendered section starts with a newline; joining with another
    # newline yields the blank line between consecutive sections.
    rendered = [f"\n{rule} {title} {rule}\n{body}" for title, body in sections]
    return "\n".join([f"{banner}", *rendered])
157194
158195 @property
159196 def n_folds (self ):
0 commit comments