
Commit 62a6838

Merge pull request #336 from DoubleML/s-update-summary
Update DoubleML __str__ method
2 parents: 77b1a6b + bf7e16a

4 files changed: +97 −149 lines

doubleml/did/did_binary.py

Lines changed: 11 additions & 52 deletions
@@ -239,58 +239,17 @@ def __init__(
         self._sensitivity_implemented = True
         self._external_predictions_implemented = True

-    def __str__(self):
-        class_name = self.__class__.__name__
-        header = f"================== {class_name} Object ==================\n"
-        data_summary = self._dml_data._data_summary_str()
-        score_info = (
-            f"Score function: {str(self.score)}\n"
-            f"Treatment group: {str(self.g_value)}\n"
-            f"Pre-treatment period: {str(self.t_value_pre)}\n"
-            f"Evaluation period: {str(self.t_value_eval)}\n"
-            f"Control group: {str(self.control_group)}\n"
-            f"Anticipation periods: {str(self.anticipation_periods)}\n"
-            f"Effective sample size: {str(self.n_obs_subset)}\n"
-        )
-        learner_info = ""
-        for key, value in self.learner.items():
-            learner_info += f"Learner {key}: {str(value)}\n"
-        if self.nuisance_loss is not None:
-            learner_info += "Out-of-sample Performance:\n"
-            is_classifier = [value for value in self._is_classifier.values()]
-            is_regressor = [not value for value in is_classifier]
-            if any(is_regressor):
-                learner_info += "Regression:\n"
-                for learner in [key for key, value in self._is_classifier.items() if value is False]:
-                    learner_info += f"Learner {learner} RMSE: {self.nuisance_loss[learner]}\n"
-            if any(is_classifier):
-                learner_info += "Classification:\n"
-                for learner in [key for key, value in self._is_classifier.items() if value is True]:
-                    learner_info += f"Learner {learner} Log Loss: {self.nuisance_loss[learner]}\n"
-
-        if self._is_cluster_data:
-            resampling_info = (
-                f"No. folds per cluster: {self._n_folds_per_cluster}\n"
-                f"No. folds: {self.n_folds}\n"
-                f"No. repeated sample splits: {self.n_rep}\n"
-            )
-        else:
-            resampling_info = f"No. folds: {self.n_folds}\nNo. repeated sample splits: {self.n_rep}\n"
-        fit_summary = str(self.summary)
-        res = (
-            header
-            + "\n------------------ Data summary ------------------\n"
-            + data_summary
-            + "\n------------------ Score & algorithm ------------------\n"
-            + score_info
-            + "\n------------------ Machine learner ------------------\n"
-            + learner_info
-            + "\n------------------ Resampling ------------------\n"
-            + resampling_info
-            + "\n------------------ Fit summary ------------------\n"
-            + fit_summary
-        )
-        return res
+    def _format_score_info_str(self):
+        lines = [
+            f"Score function: {str(self.score)}",
+            f"Treatment group: {str(self.g_value)}",
+            f"Pre-treatment period: {str(self.t_value_pre)}",
+            f"Evaluation period: {str(self.t_value_eval)}",
+            f"Control group: {str(self.control_group)}",
+            f"Anticipation periods: {str(self.anticipation_periods)}",
+            f"Effective sample size: {str(self.n_obs_subset)}",
+        ]
+        return "\n".join(lines)

     @property
     def g_value(self):

doubleml/did/did_cs_binary.py

Lines changed: 13 additions & 52 deletions
@@ -156,58 +156,19 @@ def __init__(
         self._sensitivity_implemented = True
         self._external_predictions_implemented = True

-    def __str__(self):
-        class_name = self.__class__.__name__
-        header = f"================== {class_name} Object ==================\n"
-        data_summary = self._dml_data._data_summary_str()
-        score_info = (
-            f"Score function: {str(self.score)}\n"
-            f"Treatment group: {str(self.g_value)}\n"
-            f"Pre-treatment period: {str(self.t_value_pre)}\n"
-            f"Evaluation period: {str(self.t_value_eval)}\n"
-            f"Control group: {str(self.control_group)}\n"
-            f"Anticipation periods: {str(self.anticipation_periods)}\n"
-            f"Effective sample size: {str(self.n_obs_subset)}\n"
-        )
-        learner_info = ""
-        for key, value in self.learner.items():
-            learner_info += f"Learner {key}: {str(value)}\n"
-        if self.nuisance_loss is not None:
-            learner_info += "Out-of-sample Performance:\n"
-            is_classifier = [value for value in self._is_classifier.values()]
-            is_regressor = [not value for value in is_classifier]
-            if any(is_regressor):
-                learner_info += "Regression:\n"
-                for learner in [key for key, value in self._is_classifier.items() if value is False]:
-                    learner_info += f"Learner {learner} RMSE: {self.nuisance_loss[learner]}\n"
-            if any(is_classifier):
-                learner_info += "Classification:\n"
-                for learner in [key for key, value in self._is_classifier.items() if value is True]:
-                    learner_info += f"Learner {learner} Log Loss: {self.nuisance_loss[learner]}\n"
-
-        if self._is_cluster_data:
-            resampling_info = (
-                f"No. folds per cluster: {self._n_folds_per_cluster}\n"
-                f"No. folds: {self.n_folds}\n"
-                f"No. repeated sample splits: {self.n_rep}\n"
-            )
-        else:
-            resampling_info = f"No. folds: {self.n_folds}\nNo. repeated sample splits: {self.n_rep}\n"
-        fit_summary = str(self.summary)
-        res = (
-            header
-            + "\n------------------ Data summary ------------------\n"
-            + data_summary
-            + "\n------------------ Score & algorithm ------------------\n"
-            + score_info
-            + "\n------------------ Machine learner ------------------\n"
-            + learner_info
-            + "\n------------------ Resampling ------------------\n"
-            + resampling_info
-            + "\n------------------ Fit summary ------------------\n"
-            + fit_summary
-        )
-        return res
+    def _format_score_info_str(self):
+        lines = [
+            f"Score function: {str(self.score)}",
+            f"Treatment group: {str(self.g_value)}",
+            f"Pre-treatment period: {str(self.t_value_pre)}",
+            f"Evaluation period: {str(self.t_value_eval)}",
+            f"Control group: {str(self.control_group)}",
+            f"Anticipation periods: {str(self.anticipation_periods)}",
+            f"Effective sample size: {str(self.n_obs_subset)}",
+        ]
+        return "\n".join(lines)
+
+    # _format_learner_info_str method is inherited from DoubleML base class.

     @property
     def g_value(self):

doubleml/double_ml.py

Lines changed: 70 additions & 33 deletions
@@ -110,50 +110,87 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting):
         self._i_rep = None
         self._i_treat = None

-    def __str__(self):
+    def _format_header_str(self):
         class_name = self.__class__.__name__
-        header = f"================== {class_name} Object ==================\n"
-        data_summary = self._dml_data._data_summary_str()
-        score_info = f"Score function: {str(self.score)}\n"
+        return f"================== {class_name} Object =================="
+
+    def _format_score_info_str(self):
+        return f"Score function: {str(self.score)}"
+
+    def _format_learner_info_str(self):
         learner_info = ""
-        for key, value in self.learner.items():
-            learner_info += f"Learner {key}: {str(value)}\n"
+        if self.learner is not None:
+            for key, value in self.learner.items():
+                learner_info += f"Learner {key}: {str(value)}\n"
         if self.nuisance_loss is not None:
             learner_info += "Out-of-sample Performance:\n"
-            is_classifier = [value for value in self._is_classifier.values()]
-            is_regressor = [not value for value in is_classifier]
-            if any(is_regressor):
-                learner_info += "Regression:\n"
-                for learner in [key for key, value in self._is_classifier.items() if value is False]:
-                    learner_info += f"Learner {learner} RMSE: {self.nuisance_loss[learner]}\n"
-            if any(is_classifier):
-                learner_info += "Classification:\n"
-                for learner in [key for key, value in self._is_classifier.items() if value is True]:
-                    learner_info += f"Learner {learner} Log Loss: {self.nuisance_loss[learner]}\n"
+            # Check if _is_classifier is populated, otherwise, it might be called before fit
+            if self._is_classifier:
+                is_classifier_any = any(self._is_classifier.values())
+                is_regressor_any = any(not v for v in self._is_classifier.values())
+
+                if is_regressor_any:
+                    learner_info += "Regression:\n"
+                    for learner_name in self.params_names:  # Iterate through known learners
+                        if not self._is_classifier.get(learner_name, True):  # Default to not regressor if not found
+                            loss_val = self.nuisance_loss.get(learner_name, "N/A")
+                            learner_info += f"Learner {learner_name} RMSE: {loss_val}\n"
+                if is_classifier_any:
+                    learner_info += "Classification:\n"
+                    for learner_name in self.params_names:  # Iterate through known learners
+                        if self._is_classifier.get(learner_name, False):  # Default to not classifier if not found
+                            loss_val = self.nuisance_loss.get(learner_name, "N/A")
+                            learner_info += f"Learner {learner_name} Log Loss: {loss_val}\n"
+            else:
+                learner_info += " (Run .fit() to see out-of-sample performance)\n"
+        return learner_info.strip()

+    def _format_resampling_info_str(self):
         if self._is_cluster_data:
-            resampling_info = (
+            return (
                 f"No. folds per cluster: {self._n_folds_per_cluster}\n"
                 f"No. folds: {self.n_folds}\n"
-                f"No. repeated sample splits: {self.n_rep}\n"
+                f"No. repeated sample splits: {self.n_rep}"
             )
         else:
-            resampling_info = f"No. folds: {self.n_folds}\nNo. repeated sample splits: {self.n_rep}\n"
-        fit_summary = str(self.summary)
-        res = (
-            header
-            + "\n------------------ Data summary ------------------\n"
-            + data_summary
-            + "\n------------------ Score & algorithm ------------------\n"
-            + score_info
-            + "\n------------------ Machine learner ------------------\n"
-            + learner_info
-            + "\n------------------ Resampling ------------------\n"
-            + resampling_info
-            + "\n------------------ Fit summary ------------------\n"
-            + fit_summary
+            return f"No. folds: {self.n_folds}\nNo. repeated sample splits: {self.n_rep}"
+
+    def _format_additional_info_str(self):
+        """
+        Hook for subclasses to add additional information to the string representation.
+        Returns an empty string by default.
+        Subclasses should override this method to provide content.
+        The content should not include the 'Additional Information' header itself.
+        """
+        return ""
+
+    def __str__(self):
+        header = self._format_header_str()
+        # Assumes self._dml_data._data_summary_str() exists and is well-formed
+        data_summary = self._dml_data._data_summary_str()
+        score_info = self._format_score_info_str()
+        learner_info = self._format_learner_info_str()
+        resampling_info = self._format_resampling_info_str()
+        fit_summary = str(self.summary)  # Assumes self.summary is well-formed
+
+        representation = (
+            f"{header}\n"
+            f"\n------------------ Data Summary ------------------\n"
+            f"{data_summary}\n"
+            f"\n------------------ Score & Algorithm ------------------\n"
+            f"{score_info}\n"
+            f"\n------------------ Machine Learner ------------------\n"
+            f"{learner_info}\n"
+            f"\n------------------ Resampling ------------------\n"
+            f"{resampling_info}\n"
+            f"\n------------------ Fit Summary ------------------\n"
+            f"{fit_summary}"
         )
-        return res
+
+        additional_info = self._format_additional_info_str()
+        if additional_info:
+            representation += f"\n\n------------------ Additional Information ------------------\n" f"{additional_info}"
+        return representation

     @property
     def n_folds(self):
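
The change above turns __str__ into a template method: the base class assembles the section layout, and subclasses customize it by overriding small _format_* hooks instead of rebuilding the whole string. A minimal standalone sketch of that pattern follows; the class names and the confidence-set value are illustrative only, not the DoubleML code itself.

# Standalone sketch of the template-method pattern introduced in double_ml.py.
# _Base and _WithConfset are illustrative stand-ins, not DoubleML classes.

class _Base:
    def _format_header_str(self):
        return f"===== {self.__class__.__name__} ====="

    def _format_additional_info_str(self):
        # Hook: subclasses return extra content; "" suppresses the section.
        return ""

    def __str__(self):
        representation = self._format_header_str()
        additional_info = self._format_additional_info_str()
        if additional_info:
            representation += "\n--- Additional Information ---\n" + additional_info
        return representation


class _WithConfset(_Base):
    def _format_additional_info_str(self):
        # Mirrors the DoubleMLIIVM override below: report extra results only.
        return "Robust Confidence Set: [0.1234, 0.5678]"


print(_Base())          # header only
print(_WithConfset())   # header plus an Additional Information section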

doubleml/irm/iivm.py

Lines changed: 3 additions & 12 deletions
@@ -197,22 +197,13 @@ def __init__(
         self.subgroups = subgroups
         self._external_predictions_implemented = True

-    def __str__(self):
-        parent_str = super().__str__()
-
-        # add robust confset
+    def _format_additional_info_str(self):
         if self.framework is None:
-            confset_str = ""
+            return ""
         else:
             confset = self.robust_confset()
             formatted_confset = ", ".join([f"[{lower:.4f}, {upper:.4f}]" for lower, upper in confset])
-            confset_str = (
-                "\n\n--------------- Additional Information ----------------\n"
-                + f"Robust Confidence Set: {formatted_confset}\n"
-            )
-
-        res = parent_str + confset_str
-        return res
+            return f"Robust Confidence Set: {formatted_confset}"

     @property
     def normalize_ipw(self):
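
With this hook, the robust confidence set appears in the printed summary only once a framework exists, i.e. after fitting. A rough usage sketch, assuming the usual DoubleML helpers (make_iivm_data and the DoubleMLIIVM(data, ml_g, ml_m, ml_r) signature); exact helper names and arguments may differ.

# Hedged usage sketch: assumes doubleml's make_iivm_data helper and the
# DoubleMLIIVM(obj_dml_data, ml_g, ml_m, ml_r) signature; adjust to the actual API.
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from doubleml import DoubleMLIIVM
from doubleml.datasets import make_iivm_data

dml_data = make_iivm_data(n_obs=500, dim_x=20)
dml_iivm = DoubleMLIIVM(
    dml_data,
    ml_g=RandomForestRegressor(),
    ml_m=RandomForestClassifier(),
    ml_r=RandomForestClassifier(),
)

dml_iivm.fit()
# The printed summary now ends with an Additional Information section
# listing the robust confidence set from robust_confset().
print(dml_iivm)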
