Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
ea1e2da
remove nu2, sigma2 and scores from input checks
SvenKlaassen Jan 30, 2025
5c184af
add bias and score bias to sensitivity elements
SvenKlaassen Feb 3, 2025
0decafc
Merge remote-tracking branch 'origin/main' into s-update-sensitivity
SvenKlaassen Feb 3, 2025
e21f17e
add bias and bias score to framework
SvenKlaassen Feb 3, 2025
4d870b9
add bias and bias score to framework
SvenKlaassen Feb 3, 2025
b7e7c4c
add error handling for negative nu2
SvenKlaassen Feb 3, 2025
10e9648
rename to max_bias and psi_max_bias
SvenKlaassen Feb 3, 2025
badf543
update sensitivity_calculation to bias formula
SvenKlaassen Feb 3, 2025
b33b4c0
fix warning
SvenKlaassen Feb 3, 2025
635e6b4
fix bias score scaling
SvenKlaassen Feb 3, 2025
9e4e6e5
update framework sensitivity calculations
SvenKlaassen Feb 3, 2025
287355f
fix tests
SvenKlaassen Feb 3, 2025
7179a56
remove max_bias and psi_max_bias from doubleml class
SvenKlaassen Feb 3, 2025
33cb705
fix test
SvenKlaassen Feb 3, 2025
b3ec350
remove nu2, sigma2 and scores from framework
SvenKlaassen Jan 30, 2025
c413e0f
fix sensitivity input test
SvenKlaassen Feb 4, 2025
7ed5e8c
remove sigma2, nu2 and scores from the input
SvenKlaassen Feb 4, 2025
a614728
remove nu2, sigma2 and scores from the framework operations
SvenKlaassen Feb 4, 2025
1bcd451
fix sensitivity shape tests
SvenKlaassen Feb 4, 2025
f72fc69
update gain statistics
SvenKlaassen Feb 4, 2025
382ad70
add benchmark variables to framework
SvenKlaassen Feb 5, 2025
3a828c5
add shape and input tests for sigma2 and nu2
SvenKlaassen Feb 5, 2025
dbb4045
update nu2 reestimation
SvenKlaassen Feb 5, 2025
b1d1fee
fix sensitivity validation
SvenKlaassen Feb 5, 2025
f62c345
add treatment names to construct_framework
SvenKlaassen Feb 5, 2025
a35ae96
fix warning message
SvenKlaassen Feb 5, 2025
6f05b11
remove unnecessary assignment
SvenKlaassen Feb 5, 2025
dc298f1
formatting
SvenKlaassen Feb 5, 2025
a118bf2
fix difference framework bounds
SvenKlaassen Feb 5, 2025
7353db8
add test comparing apo to irm
SvenKlaassen Feb 5, 2025
2a0aee0
add ci to irm apo comparison
SvenKlaassen Feb 5, 2025
38eb96d
update psi_a in irm and apo scores
SvenKlaassen Feb 5, 2025
bb05d74
update comparison test weighted irm and apo
SvenKlaassen Feb 5, 2025
8ea4a38
move _compute_sensitivity_bias to utils
SvenKlaassen Feb 6, 2025
54f59b1
update apo model
SvenKlaassen Feb 7, 2025
de76905
update apo tests
SvenKlaassen Feb 7, 2025
3d7d315
update apos model
SvenKlaassen Feb 7, 2025
117d816
update apos tests
SvenKlaassen Feb 7, 2025
82991de
formatting
SvenKlaassen Feb 7, 2025
ef85a9f
update skip_index in causal constrast
SvenKlaassen Feb 7, 2025
ef45e55
update causal contrast inner loop
SvenKlaassen Feb 7, 2025
bcbad76
add sensitivity update for apos
SvenKlaassen Feb 7, 2025
5da072e
add sensitivity comparison to unit tests
SvenKlaassen Feb 7, 2025
ee74514
update irm_vs_apo sensitivity test
SvenKlaassen Feb 7, 2025
55b7e45
update atte test irm vs apo
SvenKlaassen Feb 7, 2025
62ce1dd
Merge pull request #297 from DoubleML/s-update-apo
SvenKlaassen Feb 7, 2025
7467267
fix return type test
SvenKlaassen Feb 8, 2025
1184ef5
Merge branch 'main' into s-update-sensitivity
SvenKlaassen Feb 10, 2025
0152e75
move location from _trimm and _normalize_ipw
SvenKlaassen Feb 10, 2025
0ae0aed
use np.clip
SvenKlaassen Feb 10, 2025
d922d7b
update propensity score adjustment in irm
SvenKlaassen Feb 10, 2025
48ec289
update prop score adjustments apo
SvenKlaassen Feb 10, 2025
3e95ed4
fix _trimm
SvenKlaassen Feb 10, 2025
915b638
add tests for sensitivity operations
SvenKlaassen Feb 13, 2025
be5b39c
add pytest-cov to dev requirements
SvenKlaassen Feb 21, 2025
4b49ecd
add .coverage to gitignore
SvenKlaassen Feb 21, 2025
1382b89
update IRM ATT estimation
SvenKlaassen Mar 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ MANIFEST
*.idea
*.vscode
.flake8
.coverage
11 changes: 6 additions & 5 deletions doubleml/did/did.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import type_of_target

from ..double_ml import DoubleML
from ..double_ml_data import DoubleMLData
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
from ..utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls, _trimm
from doubleml.double_ml import DoubleML
from doubleml.double_ml_data import DoubleMLData
from doubleml.double_ml_score_mixins import LinearScoreMixin
from doubleml.utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls
from doubleml.utils._propensity_score import _trimm


class DoubleMLDID(LinearScoreMixin, DoubleML):
Expand Down
11 changes: 6 additions & 5 deletions doubleml/did/did_cs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import type_of_target

from ..double_ml import DoubleML
from ..double_ml_data import DoubleMLData
from ..double_ml_score_mixins import LinearScoreMixin
from ..utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
from ..utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls_2d, _trimm
from doubleml.double_ml import DoubleML
from doubleml.double_ml_data import DoubleMLData
from doubleml.double_ml_score_mixins import LinearScoreMixin
from doubleml.utils._checks import _check_finite_predictions, _check_is_propensity, _check_score, _check_trimming
from doubleml.utils._estimation import _dml_cv_predict, _dml_tune, _get_cond_smpls_2d
from doubleml.utils._propensity_score import _trimm


class DoubleMLDIDCS(LinearScoreMixin, DoubleML):
Expand Down
46 changes: 40 additions & 6 deletions doubleml/double_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .double_ml_framework import DoubleMLFramework
from .utils._checks import _check_external_predictions, _check_sample_splitting
from .utils._estimation import _aggregate_coefs_and_ses, _rmse, _set_external_predictions, _var_est
from .utils._sensitivity import _compute_sensitivity_bias
from .utils.gain_statistics import gain_statistics
from .utils.resampling import DoubleMLClusterResampling, DoubleMLResampling

Expand Down Expand Up @@ -525,6 +526,9 @@ def fit(self, n_jobs_cv=None, store_predictions=True, external_predictions=None,
# aggregated parameter estimates and standard errors from repeated cross-fitting
self.coef, self.se = _aggregate_coefs_and_ses(self._all_coef, self._all_se, self._var_scaling_factors)

# validate sensitivity elements (e.g., re-estimate nu2 if negative)
self._validate_sensitivity_elements()

# construct framework for inference
self._framework = self.construct_framework()

Expand Down Expand Up @@ -553,18 +557,27 @@ def construct_framework(self):
"var_scaling_factors": self._var_scaling_factors,
"scaled_psi": scaled_psi_reshape,
"is_cluster_data": self._is_cluster_data,
"treatment_names": self._dml_data.d_cols,
}

if self._sensitivity_implemented:
# reshape sensitivity elements to (n_obs, n_coefs, n_rep)
# reshape sensitivity elements to (1 or n_obs, n_coefs, n_rep)
sensitivity_dict = {
"sigma2": np.transpose(self.sensitivity_elements["sigma2"], (0, 2, 1)),
"nu2": np.transpose(self.sensitivity_elements["nu2"], (0, 2, 1)),
"psi_sigma2": np.transpose(self.sensitivity_elements["psi_sigma2"], (0, 2, 1)),
"psi_nu2": np.transpose(self.sensitivity_elements["psi_nu2"], (0, 2, 1)),
}

max_bias, psi_max_bias = _compute_sensitivity_bias(**sensitivity_dict)

doubleml_dict.update(
{
"sensitivity_elements": {
"sigma2": np.transpose(self.sensitivity_elements["sigma2"], (0, 2, 1)),
"nu2": np.transpose(self.sensitivity_elements["nu2"], (0, 2, 1)),
"psi_sigma2": np.transpose(self.sensitivity_elements["psi_sigma2"], (0, 2, 1)),
"psi_nu2": np.transpose(self.sensitivity_elements["psi_nu2"], (0, 2, 1)),
"riesz_rep": np.transpose(self.sensitivity_elements["riesz_rep"], (0, 2, 1)),
"max_bias": max_bias,
"psi_max_bias": psi_max_bias,
"sigma2": sensitivity_dict["sigma2"],
"nu2": sensitivity_dict["nu2"],
}
}
)
Expand Down Expand Up @@ -1423,6 +1436,27 @@ def _initialize_sensitivity_elements(self, score_dim):
}
return sensitivity_elements

def _validate_sensitivity_elements(self):
if self._sensitivity_implemented:
for i_treat in range(self._dml_data.n_treat):
nu2 = self.sensitivity_elements["nu2"][:, :, i_treat]
riesz_rep = self.sensitivity_elements["riesz_rep"][:, :, i_treat]

if np.any(nu2 <= 0):
treatment_name = self._dml_data.d_cols[i_treat]
msg = (
f"The estimated nu2 for {treatment_name} is not positive. "
"Re-estimation based on riesz representer (non-orthogonal)."
)
warnings.warn(msg, UserWarning)
psi_nu2 = np.power(riesz_rep, 2)
nu2 = np.mean(psi_nu2, axis=0, keepdims=True)

self.sensitivity_elements["nu2"][:, :, i_treat] = nu2
self.sensitivity_elements["psi_nu2"][:, :, i_treat] = psi_nu2

return

def _get_sensitivity_elements(self, i_rep, i_treat):
sensitivity_elements = {key: value[:, i_rep, i_treat] for key, value in self.sensitivity_elements.items()}
return sensitivity_elements
Expand Down
Loading