|
4 | 4 | import numpy as np |
5 | 5 | import pandas as pd |
6 | 6 | from doubleml.did.datasets import make_did_CS2021 |
7 | | -from lightgbm import LGBMClassifier, LGBMRegressor |
8 | | -from sklearn.linear_model import LinearRegression, LogisticRegression |
9 | 7 |
|
10 | 8 | from montecover.base import BaseSimulation |
| 9 | +from montecover.utils import create_learner_from_config |
11 | 10 |
|
12 | 11 |
|
13 | 12 | class DIDMultiCoverageSimulation(BaseSimulation): |
@@ -36,39 +35,13 @@ def __init__( |
36 | 35 | def _process_config_parameters(self): |
37 | 36 | """Process simulation-specific parameters from config""" |
38 | 37 | # Process ML models in parameter grid |
 | 38 | + # Validate that the config specifies learners and that each entry names the required ones |
| 39 | + assert "learners" in self.dml_parameters, "No learners specified in the config file" |
39 | 40 |
|
40 | | - assert ( |
41 | | - "learners" in self.dml_parameters |
42 | | - ), "No learners specified in the config file" |
| 41 | + required_learners = ["ml_g", "ml_m"] |
43 | 42 | for learner in self.dml_parameters["learners"]: |
44 | | - assert "ml_g" in learner, "No ml_g specified in the config file" |
45 | | - assert "ml_m" in learner, "No ml_m specified in the config file" |
46 | | - |
47 | | - # Convert ml_g strings to actual objects |
48 | | - if learner["ml_g"][0] == "Linear": |
49 | | - learner["ml_g"] = ("Linear", LinearRegression()) |
50 | | - elif learner["ml_g"][0] == "LGBM": |
51 | | - learner["ml_g"] = ( |
52 | | - "LGBM", |
53 | | - LGBMRegressor( |
54 | | - n_estimators=500, learning_rate=0.02, verbose=-1, n_jobs=1 |
55 | | - ), |
56 | | - ) |
57 | | - else: |
58 | | - raise ValueError(f"Unknown learner type: {learner['ml_g']}") |
59 | | - |
60 | | - # Convert ml_m strings to actual objects |
61 | | - if learner["ml_m"][0] == "Linear": |
62 | | - learner["ml_m"] = ("Linear", LogisticRegression()) |
63 | | - elif learner["ml_m"][0] == "LGBM": |
64 | | - learner["ml_m"] = ( |
65 | | - "LGBM", |
66 | | - LGBMClassifier( |
67 | | - n_estimators=500, learning_rate=0.02, verbose=-1, n_jobs=1 |
68 | | - ), |
69 | | - ) |
70 | | - else: |
71 | | - raise ValueError(f"Unknown learner type: {learner['ml_m']}") |
| 43 | + for ml in required_learners: |
| 44 | + assert ml in learner, f"No {ml} specified in the config file" |
72 | 45 |
|
73 | 46 | def _calculate_oracle_values(self): |
74 | 47 | """Calculate oracle values for the simulation.""" |
@@ -102,8 +75,9 @@ def _calculate_oracle_values(self): |
102 | 75 | def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]: |
103 | 76 | """Run a single repetition with the given parameters.""" |
104 | 77 | # Extract parameters |
105 | | - learner_g_name, ml_g = dml_params["learners"]["ml_g"] |
106 | | - learner_m_name, ml_m = dml_params["learners"]["ml_m"] |
| 78 | + learner_config = dml_params["learners"] |
| 79 | + learner_g_name, ml_g = create_learner_from_config(learner_config["ml_g"]) |
| 80 | + learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"]) |
107 | 81 | score = dml_params["score"] |
108 | 82 | in_sample_normalization = dml_params["in_sample_normalization"] |
109 | 83 |
|
|
0 commit comments