-
Notifications
You must be signed in to change notification settings - Fork 99
LPLR model #365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
jer2ig
wants to merge
21
commits into
main
Choose a base branch
from
jh-logistic-model
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,308
−14
Draft
LPLR model #365
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
df03887
Logistic regression implementation WIP
jer2ig f5521f1
First WIP of implementation
jer2ig bfa756c
Working implementation. Started on test set-up.
jer2ig d729d0a
Changed data type of arrays
jer2ig 8fe7ca6
Fix variable name
jer2ig 18bac23
Moved into plm folder, started testing setup
jer2ig c6e600d
Fixed bug in score computation
jer2ig 6f556e0
Reverted from ensure_all_finite to force_all_finite
jer2ig 3a332bf
Fixes to instrument score
jer2ig b41a773
Added option for exception on convergence failure
jer2ig c434667
Added unbalanced dataset option, bug fixes
jer2ig 443d82d
Added binary treatment dataset, fixed bug for model check
jer2ig 774c74d
Adjusted dataset balancing
jer2ig 9695820
Renamed Logistic to LPLR
jer2ig dbfea73
Clean-up of branch
jer2ig 29114ce
Ruff checks and formatting
jer2ig 5d2d1ed
Unit tests work and bug fix in lplr
jer2ig 2c626a0
Cleanup
jer2ig 9819436
Tests updated
jer2ig 5a7e279
Pre-commit checks
jer2ig fc03cc6
Pre-commit checks on all files
jer2ig File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| import numpy as np | ||
| import pandas as pd | ||
| from scipy.special import expit | ||
|
|
||
| from doubleml.data import DoubleMLData | ||
| from doubleml.utils._aliases import _get_array_alias, _get_data_frame_alias, _get_dml_data_alias | ||
|
|
||
| _array_alias = _get_array_alias() | ||
| _data_frame_alias = _get_data_frame_alias() | ||
| _dml_data_alias = _get_dml_data_alias() | ||
|
|
||
|
|
||
| def make_lplr_LZZ2020( | ||
| n_obs=500, dim_x=20, alpha=0.5, return_type="DoubleMLData", balanced_r0=True, treatment="continuous", **kwargs | ||
| ): | ||
| r""" | ||
| Generates synthetic data for a logistic partially linear regression model, as in Liu et al. (2021), | ||
| designed for use in double/debiased machine learning applications. | ||
|
|
||
| The data generating process is defined as follows: | ||
|
|
||
| - Covariates :math:`x_i \sim \mathcal{N}(0, \Sigma)`, where :math:`\Sigma_{kj} = 0.7^{|j-k|}`. | ||
| - Treatment :math:`d_i = a_0(x_i)`. | ||
| - Propensity score :math:`p_i = \sigma(\alpha d_i + r_0(x_i))`, where :math:`\sigma(\cdot)` is the logistic function. | ||
| - Outcome :math:`y_i \sim \text{Bernoulli}(p_i)`. | ||
|
|
||
| The nuisance functions are defined as: | ||
|
|
||
| .. math:: | ||
| \begin{aligned} | ||
| a_0(x_i) &= \frac{2}{1 + \exp(x_{i,1})} - \frac{2}{1 + \exp(x_{i,2})} + \sin(x_{i,3}) + \cos(x_{i,4}) \\ | ||
| &\quad + 0.5 \cdot \mathbb{1}(x_{i,5} > 0) - 0.5 \cdot \mathbb{1}(x_{i,6} > 0) + 0.2\, x_{i,7} x_{i,8} | ||
| - 0.2\, x_{i,9} x_{i,10} \\ | ||
| r_0(x_i) &= 0.1\, x_{i,1} x_{i,2} x_{i,3} + 0.1\, x_{i,4} x_{i,5} + 0.1\, x_{i,6}^3 - 0.5 \sin^2(x_{i,7}) \\ | ||
| &\quad + 0.5 \cos(x_{i,8}) + \frac{1}{1 + x_{i,9}^2} - \frac{1}{1 + \exp(x_{i,10})} \\ | ||
| &\quad + 0.25 \cdot \mathbb{1}(x_{i,11} > 0) - 0.25 \cdot \mathbb{1}(x_{i,13} > 0) | ||
| \end{aligned} | ||
|
|
||
| Parameters | ||
| ---------- | ||
| n_obs : int | ||
| Number of observations to simulate. | ||
| dim_x : int | ||
| Number of covariates. | ||
| alpha : float | ||
| Value of the causal parameter. | ||
| return_type : str | ||
| Determines the return format. One of: | ||
|
|
||
| - 'DoubleMLData' or DoubleMLData: returns a ``DoubleMLData`` object. | ||
| - 'DataFrame', 'pd.DataFrame' or pd.DataFrame: returns a ``pandas.DataFrame``. | ||
| - 'array', 'np.ndarray', 'np.array' or np.ndarray: returns tuple of numpy arrays (x, y, d, p). | ||
| balanced_r0 : bool, default True | ||
| If True, uses the "balanced" r_0 specification (smaller magnitude / more balanced | ||
| heterogeneity). If False, uses an "unbalanced" r_0 specification with larger | ||
| share of Y=0. | ||
| treatment : {'continuous', 'binary', 'binary_unbalanced'}, default 'continuous' | ||
| Determines how the treatment d is generated from a_0(x): | ||
| - 'continuous': d = a_0(x) (continuous treatment). | ||
| - 'binary': d ~ Bernoulli( sigmoid(a_0(x) - mean(a_0(x))) ) . | ||
| - 'binary_unbalanced': d ~ Bernoulli( sigmoid(a_0(x)) ). | ||
|
|
||
| **kwargs | ||
| Optional keyword arguments (currently unused in this implementation). | ||
|
|
||
| Returns | ||
| ------- | ||
| Union[DoubleMLData, pd.DataFrame, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]] | ||
| The generated data in the specified format. | ||
|
|
||
| References | ||
| ---------- | ||
| Liu, Molei, Yi Zhang, and Doudou Zhou. 2021. | ||
| "Double/Debiased Machine Learning for Logistic Partially Linear Model." | ||
| The Econometrics Journal 24 (3): 559–88. https://doi.org/10.1093/ectj/utab019. | ||
|
|
||
| """ | ||
|
|
||
| if balanced_r0: | ||
|
|
||
| def r_0(X): | ||
| return ( | ||
| 0.1 * X[:, 0] * X[:, 1] * X[:, 2] | ||
| + 0.1 * X[:, 3] * X[:, 4] | ||
| + 0.1 * X[:, 5] ** 3 | ||
| + -0.5 * np.sin(X[:, 6]) ** 2 | ||
| + 0.5 * np.cos(X[:, 7]) | ||
| + 1 / (1 + X[:, 8] ** 2) | ||
| + -1 / (1 + np.exp(X[:, 9])) | ||
| + 0.25 * np.where(X[:, 10] > 0, 1, 0) | ||
| + -0.25 * np.where(X[:, 12] > 0, 1, 0) | ||
| ) | ||
|
|
||
| else: | ||
|
|
||
| def r_0(X): | ||
| return ( | ||
| 0.1 * X[:, 0] * X[:, 1] * X[:, 2] | ||
| + 0.1 * X[:, 3] * X[:, 4] | ||
| + 0.1 * X[:, 5] ** 3 | ||
| + -0.5 * np.sin(X[:, 6]) ** 2 | ||
| + 0.5 * np.cos(X[:, 7]) | ||
| + 4 / (1 + X[:, 8] ** 2) | ||
| + -1 / (1 + np.exp(X[:, 9])) | ||
| + 1.5 * np.where(X[:, 10] > 0, 1, 0) | ||
| + -0.25 * np.where(X[:, 12] > 0, 1, 0) | ||
| ) | ||
|
|
||
| def a_0(X): | ||
| return ( | ||
| 2 / (1 + np.exp(X[:, 0])) | ||
| + -2 / (1 + np.exp(X[:, 1])) | ||
| + 1 * np.sin(X[:, 2]) | ||
| + 1 * np.cos(X[:, 3]) | ||
| + 0.5 * np.where(X[:, 4] > 0, 1, 0) | ||
| + -0.5 * np.where(X[:, 5] > 0, 1, 0) | ||
| + 0.2 * X[:, 6] * X[:, 7] | ||
| + -0.2 * X[:, 8] * X[:, 9] | ||
| ) | ||
|
|
||
| sigma = np.full((dim_x, dim_x), 0.2) | ||
| np.fill_diagonal(sigma, 1) | ||
|
|
||
| x = np.random.multivariate_normal(np.zeros(dim_x), sigma, size=n_obs) | ||
| np.clip(x, -2, 2, out=x) | ||
|
|
||
| if treatment == "continuous": | ||
| d = a_0(x) | ||
| elif treatment == "binary": | ||
| d_cont = a_0(x) | ||
| d = np.random.binomial(1, expit(d_cont - d_cont.mean())) | ||
| elif treatment == "binary_unbalanced": | ||
| d_cont = a_0(x) | ||
| d = np.random.binomial(1, expit(d_cont)) | ||
| else: | ||
| raise ValueError("Invalid treatment type.") | ||
|
|
||
| p = expit(alpha * d[:] + r_0(x)) | ||
|
|
||
| y = np.random.binomial(1, p) | ||
|
|
||
| if return_type in _array_alias: | ||
| return x, y, d, p | ||
| elif return_type in _data_frame_alias + _dml_data_alias: | ||
| x_cols = [f"X{i + 1}" for i in np.arange(dim_x)] | ||
| data = pd.DataFrame(np.column_stack((x, y, d, p)), columns=x_cols + ["y", "d", "p"]) | ||
| if return_type in _data_frame_alias: | ||
| return data | ||
| else: | ||
| return DoubleMLData(data, "y", "d", x_cols) | ||
| else: | ||
| raise ValueError("Invalid return_type.") | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.