Skip to content

Commit ae15abf

Browse files
authored
Merge pull request #295 from DoubleML/s-update-sensitivity
Update Sensitivity Operations & APO Model
2 parents d72bcfe + 1382b89 commit ae15abf

37 files changed

+927
-383
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ MANIFEST
3030
*.idea
3131
*.vscode
3232
.flake8
33+
.coverage

doubleml/did/did.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from sklearn.utils import check_X_y
55
from sklearn.utils.multiclass import type_of_target
66

7-
from ..double_ml import DoubleML
8-
from ..double_ml_data import DoubleMLData
9-
from ..double_ml_score_mixins import LinearScoreMixin
10-
from ..utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
11-
from ..utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls, _trimm
7+
from doubleml.double_ml import DoubleML
8+
from doubleml.double_ml_data import DoubleMLData
9+
from doubleml.double_ml_score_mixins import LinearScoreMixin
10+
from doubleml.utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
11+
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls
12+
from doubleml.utils._propensity_score import _trimm
1213

1314

1415
class DoubleMLDID(LinearScoreMixin, DoubleML):

doubleml/did/did_cs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from sklearn.utils import check_X_y
55
from sklearn.utils.multiclass import type_of_target
66

7-
from ..double_ml import DoubleML
8-
from ..double_ml_data import DoubleMLData
9-
from ..double_ml_score_mixins import LinearScoreMixin
10-
from ..utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
11-
from ..utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls_2d, _trimm
7+
from doubleml.double_ml import DoubleML
8+
from doubleml.double_ml_data import DoubleMLData
9+
from doubleml.double_ml_score_mixins import LinearScoreMixin
10+
from doubleml.utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
11+
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls_2d
12+
from doubleml.utils._propensity_score import _trimm
1213

1314

1415
class DoubleMLDIDCS(LinearScoreMixin, DoubleML):

doubleml/double_ml.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .double_ml_framework import DoubleMLFramework
1212
from .utils._checks import _check_external_predictions, _check_sample_splitting
1313
from .utils._estimation import _aggregate_coefs_and_ses, _rmse, _set_external_predictions, _var_est
14+
from .utils._sensitivity import _compute_sensitivity_bias
1415
from .utils.gain_statistics import gain_statistics
1516
from .utils.resampling import DoubleMLClusterResampling, DoubleMLResampling
1617

@@ -525,6 +526,9 @@ def fit(self, n_jobs_cv=None, store_predictions=True, external_predictions=None,
525526
# aggregated parameter estimates and standard errors from repeated cross-fitting
526527
self.coef, self.se = _aggregate_coefs_and_ses(self._all_coef, self._all_se, self._var_scaling_factors)
527528

529+
# validate sensitivity elements (e.g., re-estimate nu2 if negative)
530+
self._validate_sensitivity_elements()
531+
528532
# construct framework for inference
529533
self._framework = self.construct_framework()
530534

@@ -553,18 +557,27 @@ def construct_framework(self):
553557
"var_scaling_factors": self._var_scaling_factors,
554558
"scaled_psi": scaled_psi_reshape,
555559
"is_cluster_data": self._is_cluster_data,
560+
"treatment_names": self._dml_data.d_cols,
556561
}
557562

558563
if self._sensitivity_implemented:
559-
# reshape sensitivity elements to (n_obs, n_coefs, n_rep)
564+
# reshape sensitivity elements to (1 or n_obs, n_coefs, n_rep)
565+
sensitivity_dict = {
566+
"sigma2": np.transpose(self.sensitivity_elements["sigma2"], (0, 2, 1)),
567+
"nu2": np.transpose(self.sensitivity_elements["nu2"], (0, 2, 1)),
568+
"psi_sigma2": np.transpose(self.sensitivity_elements["psi_sigma2"], (0, 2, 1)),
569+
"psi_nu2": np.transpose(self.sensitivity_elements["psi_nu2"], (0, 2, 1)),
570+
}
571+
572+
max_bias, psi_max_bias = _compute_sensitivity_bias(**sensitivity_dict)
573+
560574
doubleml_dict.update(
561575
{
562576
"sensitivity_elements": {
563-
"sigma2": np.transpose(self.sensitivity_elements["sigma2"], (0, 2, 1)),
564-
"nu2": np.transpose(self.sensitivity_elements["nu2"], (0, 2, 1)),
565-
"psi_sigma2": np.transpose(self.sensitivity_elements["psi_sigma2"], (0, 2, 1)),
566-
"psi_nu2": np.transpose(self.sensitivity_elements["psi_nu2"], (0, 2, 1)),
567-
"riesz_rep": np.transpose(self.sensitivity_elements["riesz_rep"], (0, 2, 1)),
577+
"max_bias": max_bias,
578+
"psi_max_bias": psi_max_bias,
579+
"sigma2": sensitivity_dict["sigma2"],
580+
"nu2": sensitivity_dict["nu2"],
568581
}
569582
}
570583
)
@@ -1423,6 +1436,27 @@ def _initialize_sensitivity_elements(self, score_dim):
14231436
}
14241437
return sensitivity_elements
14251438

1439+
def _validate_sensitivity_elements(self):
1440+
if self._sensitivity_implemented:
1441+
for i_treat in range(self._dml_data.n_treat):
1442+
nu2 = self.sensitivity_elements["nu2"][:, :, i_treat]
1443+
riesz_rep = self.sensitivity_elements["riesz_rep"][:, :, i_treat]
1444+
1445+
if np.any(nu2 <= 0):
1446+
treatment_name = self._dml_data.d_cols[i_treat]
1447+
msg = (
1448+
f"The estimated nu2 for {treatment_name} is not positive. "
1449+
"Re-estimation based on riesz representer (non-orthogonal)."
1450+
)
1451+
warnings.warn(msg, UserWarning)
1452+
psi_nu2 = np.power(riesz_rep, 2)
1453+
nu2 = np.mean(psi_nu2, axis=0, keepdims=True)
1454+
1455+
self.sensitivity_elements["nu2"][:, :, i_treat] = nu2
1456+
self.sensitivity_elements["psi_nu2"][:, :, i_treat] = psi_nu2
1457+
1458+
return
1459+
14261460
def _get_sensitivity_elements(self, i_rep, i_treat):
14271461
sensitivity_elements = {key: value[:, i_rep, i_treat] for key, value in self.sensitivity_elements.items()}
14281462
return sensitivity_elements

0 commit comments

Comments
 (0)