|
11 | 11 | from .double_ml_framework import DoubleMLFramework |
12 | 12 | from .utils._checks import _check_external_predictions, _check_sample_splitting |
13 | 13 | from .utils._estimation import _aggregate_coefs_and_ses, _rmse, _set_external_predictions, _var_est |
| 14 | +from .utils._sensitivity import _compute_sensitivity_bias |
14 | 15 | from .utils.gain_statistics import gain_statistics |
15 | 16 | from .utils.resampling import DoubleMLClusterResampling, DoubleMLResampling |
16 | 17 |
|
@@ -525,6 +526,9 @@ def fit(self, n_jobs_cv=None, store_predictions=True, external_predictions=None, |
525 | 526 | # aggregated parameter estimates and standard errors from repeated cross-fitting |
526 | 527 | self.coef, self.se = _aggregate_coefs_and_ses(self._all_coef, self._all_se, self._var_scaling_factors) |
527 | 528 |
|
| 529 | + # validate sensitivity elements (e.g., re-estimate nu2 if negative) |
| 530 | + self._validate_sensitivity_elements() |
| 531 | + |
528 | 532 | # construct framework for inference |
529 | 533 | self._framework = self.construct_framework() |
530 | 534 |
|
@@ -553,18 +557,27 @@ def construct_framework(self): |
553 | 557 | "var_scaling_factors": self._var_scaling_factors, |
554 | 558 | "scaled_psi": scaled_psi_reshape, |
555 | 559 | "is_cluster_data": self._is_cluster_data, |
| 560 | + "treatment_names": self._dml_data.d_cols, |
556 | 561 | } |
557 | 562 |
|
558 | 563 | if self._sensitivity_implemented: |
559 | | - # reshape sensitivity elements to (n_obs, n_coefs, n_rep) |
| 564 | + # reshape sensitivity elements to (1 or n_obs, n_coefs, n_rep) |
| 565 | + sensitivity_dict = { |
| 566 | + "sigma2": np.transpose(self.sensitivity_elements["sigma2"], (0, 2, 1)), |
| 567 | + "nu2": np.transpose(self.sensitivity_elements["nu2"], (0, 2, 1)), |
| 568 | + "psi_sigma2": np.transpose(self.sensitivity_elements["psi_sigma2"], (0, 2, 1)), |
| 569 | + "psi_nu2": np.transpose(self.sensitivity_elements["psi_nu2"], (0, 2, 1)), |
| 570 | + } |
| 571 | + |
| 572 | + max_bias, psi_max_bias = _compute_sensitivity_bias(**sensitivity_dict) |
| 573 | + |
560 | 574 | doubleml_dict.update( |
561 | 575 | { |
562 | 576 | "sensitivity_elements": { |
563 | | - "sigma2": np.transpose(self.sensitivity_elements["sigma2"], (0, 2, 1)), |
564 | | - "nu2": np.transpose(self.sensitivity_elements["nu2"], (0, 2, 1)), |
565 | | - "psi_sigma2": np.transpose(self.sensitivity_elements["psi_sigma2"], (0, 2, 1)), |
566 | | - "psi_nu2": np.transpose(self.sensitivity_elements["psi_nu2"], (0, 2, 1)), |
567 | | - "riesz_rep": np.transpose(self.sensitivity_elements["riesz_rep"], (0, 2, 1)), |
| 577 | + "max_bias": max_bias, |
| 578 | + "psi_max_bias": psi_max_bias, |
| 579 | + "sigma2": sensitivity_dict["sigma2"], |
| 580 | + "nu2": sensitivity_dict["nu2"], |
568 | 581 | } |
569 | 582 | } |
570 | 583 | ) |
@@ -1423,6 +1436,27 @@ def _initialize_sensitivity_elements(self, score_dim): |
1423 | 1436 | } |
1424 | 1437 | return sensitivity_elements |
1425 | 1438 |
|
| 1439 | + def _validate_sensitivity_elements(self): |
| 1440 | + if self._sensitivity_implemented: |
| 1441 | + for i_treat in range(self._dml_data.n_treat): |
| 1442 | + nu2 = self.sensitivity_elements["nu2"][:, :, i_treat] |
| 1443 | + riesz_rep = self.sensitivity_elements["riesz_rep"][:, :, i_treat] |
| 1444 | + |
| 1445 | + if np.any(nu2 <= 0): |
| 1446 | + treatment_name = self._dml_data.d_cols[i_treat] |
| 1447 | + msg = ( |
| 1448 | + f"The estimated nu2 for {treatment_name} is not positive. " |
| 1449 | + "Re-estimation based on riesz representer (non-orthogonal)." |
| 1450 | + ) |
| 1451 | + warnings.warn(msg, UserWarning) |
| 1452 | + psi_nu2 = np.power(riesz_rep, 2) |
| 1453 | + nu2 = np.mean(psi_nu2, axis=0, keepdims=True) |
| 1454 | + |
| 1455 | + self.sensitivity_elements["nu2"][:, :, i_treat] = nu2 |
| 1456 | + self.sensitivity_elements["psi_nu2"][:, :, i_treat] = psi_nu2 |
| 1457 | + |
| 1458 | + return |
| 1459 | + |
1426 | 1460 | def _get_sensitivity_elements(self, i_rep, i_treat): |
1427 | 1461 | sensitivity_elements = {key: value[:, i_rep, i_treat] for key, value in self.sensitivity_elements.items()} |
1428 | 1462 | return sensitivity_elements |
|
0 commit comments