Skip to content

Commit 6197ea9

Browse files
committed
update config save
1 parent 3c3a78c commit 6197ea9

File tree

1 file changed

+12
-32
lines changed

1 file changed

+12
-32
lines changed

monte-cover/src/montecover/base.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ def run_simulation(self, n_jobs=None):
107107

108108
rep_end_time = time.time()
109109
rep_duration = rep_end_time - rep_start_time
110-
self.logger.info(
111-
f"Repetition {i_rep+1} completed in {rep_duration:.2f}s"
112-
)
110+
self.logger.info(f"Repetition {i_rep+1} completed in {rep_duration:.2f}s")
113111

114112
else:
115113
self.logger.info(f"Starting parallel execution with n_jobs={n_jobs}")
@@ -140,9 +138,7 @@ def save_results(self, output_path: str = "results", file_prefix: str = ""):
140138
"Script": [self.__class__.__name__],
141139
"Date": [datetime.now().strftime("%Y-%m-%d %H:%M")],
142140
"Total Runtime (minutes)": [self.total_runtime / 60],
143-
"Python Version": [
144-
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
145-
],
141+
"Python Version": [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"],
146142
"Config File": [self.config_file],
147143
}
148144
)
@@ -165,7 +161,7 @@ def save_config(self, output_path: str):
165161
self.logger.warning(f"Adding .yaml extension to output path: {output_path}")
166162

167163
with open(output_path, "w") as file:
168-
yaml.dump(self.config, file)
164+
yaml.dump(self.config, file, sort_keys=False, default_flow_style=False, indent=2, allow_unicode=True)
169165

170166
self.logger.info(f"Configuration saved to {output_path}")
171167

@@ -178,9 +174,7 @@ def _load_config(self, config_path: str) -> Dict[str, Any]:
178174
with open(config_path, "r") as file:
179175
config = yaml.safe_load(file)
180176
else:
181-
raise ValueError(
182-
f"Unsupported config file format: {config_path}. Use .yaml or .yml"
183-
)
177+
raise ValueError(f"Unsupported config file format: {config_path}. Use .yaml or .yml")
184178

185179
return config
186180

@@ -204,9 +198,7 @@ def _setup_logging(self, log_level: str, log_file: Optional[str]):
204198
# Console handler
205199
ch = logging.StreamHandler()
206200
ch.setLevel(level)
207-
formatter = logging.Formatter(
208-
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
209-
)
201+
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
210202
ch.setFormatter(formatter)
211203
self.logger.addHandler(ch)
212204

@@ -264,9 +256,7 @@ def _process_repetition(self, i_rep):
264256
dml_params = dict(zip(self.dml_parameters.keys(), dml_param_values))
265257
i_param_comb += 1
266258

267-
comb_results = self._process_parameter_combination(
268-
i_rep, i_param_comb, dgp_params, dml_params, dml_data
269-
)
259+
comb_results = self._process_parameter_combination(i_rep, i_param_comb, dgp_params, dml_params, dml_data)
270260

271261
# Merge results
272262
for result_name, result_list in comb_results.items():
@@ -276,14 +266,11 @@ def _process_repetition(self, i_rep):
276266

277267
return rep_results
278268

279-
def _process_parameter_combination(
280-
self, i_rep, i_param_comb, dgp_params, dml_params, dml_data
281-
):
269+
def _process_parameter_combination(self, i_rep, i_param_comb, dgp_params, dml_params, dml_data):
282270
"""Process a single parameter combination."""
283271
# Log parameter combination
284272
self.logger.debug(
285-
f"Rep {i_rep+1}, Combo {i_param_comb}/{self.total_combinations}: "
286-
f"DGPs {dgp_params}, DML {dml_params}"
273+
f"Rep {i_rep+1}, Combo {i_param_comb}/{self.total_combinations}: " f"DGPs {dgp_params}, DML {dml_params}"
287274
)
288275
param_start_time = time.time()
289276

@@ -292,9 +279,7 @@ def _process_parameter_combination(
292279

293280
# Log timing
294281
param_duration = time.time() - param_start_time
295-
self.logger.debug(
296-
f"Parameter combination completed in {param_duration:.2f}s"
297-
)
282+
self.logger.debug(f"Parameter combination completed in {param_duration:.2f}s")
298283

299284
# Process results
300285
if repetition_results is None:
@@ -313,8 +298,7 @@ def _process_parameter_combination(
313298

314299
except Exception as e:
315300
self.logger.error(
316-
f"Error: repetition {i_rep+1}, DGP parameters {dgp_params}, "
317-
f"DML parameters {dml_params}: {str(e)}"
301+
f"Error: repetition {i_rep+1}, DGP parameters {dgp_params}, " f"DML parameters {dml_params}: {str(e)}"
318302
)
319303
self.logger.exception("Exception details:")
320304
return {}
@@ -349,13 +333,9 @@ def _compute_coverage(thetas, oracle_thetas, confint, joint_confint=None):
349333
if joint_confint is not None:
350334
joint_lower_bound = joint_confint.iloc[:, 0]
351335
joint_upper_bound = joint_confint.iloc[:, 1]
352-
joint_coverage_mask = (joint_lower_bound < oracle_thetas) & (
353-
oracle_thetas < joint_upper_bound
354-
)
336+
joint_coverage_mask = (joint_lower_bound < oracle_thetas) & (oracle_thetas < joint_upper_bound)
355337

356338
result_dict["Uniform Coverage"] = np.all(joint_coverage_mask)
357-
result_dict["Uniform CI Length"] = np.mean(
358-
joint_upper_bound - joint_lower_bound
359-
)
339+
result_dict["Uniform CI Length"] = np.mean(joint_upper_bound - joint_lower_bound)
360340

361341
return result_dict

0 commit comments

Comments
 (0)