jsr-p/did-imputation

did-imputation: Python implementation of the BJS imputation estimator

The did_imp package is a Python implementation of the imputation estimator for difference-in-differences designs by Borusyak, Jaravel, and Spiess (BJS).

See:

Installation

uv pip install git+https://github.com/jsr-p/did-imputation

Example

The following example is taken from Kyle Butts’s R implementation (see here).

First we load the data and plot the average outcome by group and year.

import did_imp
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
import plotnine as pn
import pyfixest as pf
df = pl.read_csv(did_imp.utils.proj_folder() / "data/dfhetfull.csv")
print(df.select("dep_var", "year", "g", "unit").head())
shape: (5, 4)
┌───────────┬──────┬─────┬──────┐
│ dep_var   ┆ year ┆ g   ┆ unit │
│ ---       ┆ ---  ┆ --- ┆ ---  │
│ f64       ┆ i64  ┆ i64 ┆ i64  │
╞═══════════╪══════╪═════╪══════╡
│ 0.221443  ┆ 1990 ┆ 0   ┆ 1    │
│ 0.796955  ┆ 1991 ┆ 0   ┆ 1    │
│ 0.604479  ┆ 1992 ┆ 0   ┆ 1    │
│ -0.935092 ┆ 1993 ┆ 0   ┆ 1    │
│ 2.806541  ┆ 1994 ┆ 0   ┆ 1    │
└───────────┴──────┴─────┴──────┘
gp = (
    df.group_by("g", "year")
    .agg(pl.col("dep_var").mean())
    .with_columns(
        pl.col("g").replace_strict(
            {
                0: "Group 3",
                2000: "Group 1",
                2010: "Group 2",
            }
        )
    )
)

fig, ax = plt.subplots(figsize=(10, 6))
ax = sns.lineplot(
    data=gp.to_pandas(),
    x="year",
    y="dep_var",
    palette={
        "Group 1": "red",
        "Group 2": "blue",
        "Group 3": "purple",
    },
    hue="g",
    ax=ax,
)
ax.axvline(2000, color="black", linestyle="--")
ax.axvline(2010, color="black", linestyle="--")
ax.set(ylabel="Outcome", xlabel="Year")
ax.legend(
    title="Treatment Cohort",
    loc="upper center",
    bbox_to_anchor=(0.5, -0.1),
    ncol=3,
    frameon=False,
)
sns.despine()
_ = fig.savefig(
    did_imp.utils.proj_folder() / "figs/example_avg.png", dpi=300, bbox_inches="tight"
)
plt.close(fig)  # quarto

[Figure: average outcome by year for each treatment cohort (figs/example_avg.png)]

Let’s estimate a static DID:

result_static = did_imp.estimate(
    data=df,
    outcome="dep_var",
    time="year",
    group="g",
    unit="unit",
    fes="unit + year",
    horizons="static",
)
print(result_static.estimates)
shape: (1, 7)
┌───────┬──────────┬──────────┬───────────┬──────┬──────────┬─────────┐
│ term  ┆ estimate ┆ se       ┆ tstat     ┆ pval ┆ lower    ┆ upper   │
│ ---   ┆ ---      ┆ ---      ┆ ---       ┆ ---  ┆ ---      ┆ ---     │
│ str   ┆ f64      ┆ f64      ┆ f64       ┆ f64  ┆ f64      ┆ f64     │
╞═══════╪══════════╪══════════╪═══════════╪══════╪══════════╪═════════╡
│ treat ┆ 2.262952 ┆ 0.031397 ┆ 72.075777 ┆ 0.0  ┆ 2.201414 ┆ 2.32449 │
└───────┴──────────┴──────────┴───────────┴──────┴──────────┴─────────┘
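As a quick sanity check on how the columns relate, the printed row can be reproduced by hand. This is only a sketch based on the rounded numbers shown above, and it assumes a two-sided 95% interval from the normal approximation; that is an observation about the printed values, not a statement about how did_imp computes intervals internally.

```python
from statistics import NormalDist

# Values copied from the printed result_static.estimates row above.
estimate, se = 2.262952, 0.031397

# Assumption: two-sided 95% interval using the normal approximation.
z = NormalDist().inv_cdf(0.975)

tstat = estimate / se          # ratio of estimate to standard error
lower = estimate - z * se      # lower confidence bound
upper = estimate + z * se      # upper confidence bound

print(f"tstat={tstat:.3f}, lower={lower:.5f}, upper={upper:.5f}")
```

Small differences from the table are expected because the inputs here are already rounded to six decimals.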

Next, let’s estimate an event study DID.

result = did_imp.estimate(
    data=df,
    outcome="dep_var",
    time="year",
    group="g",
    unit="unit",
    fes="unit + year",
    horizons="event",
    pretrends=list(range(-5, 0)),
)
print(result.estimates)
shape: (26, 7)
┌──────┬───────────┬──────────┬───────────┬──────────┬───────────┬──────────┐
│ term ┆ estimate  ┆ se       ┆ tstat     ┆ pval     ┆ lower     ┆ upper    │
│ ---  ┆ ---       ┆ ---      ┆ ---       ┆ ---      ┆ ---       ┆ ---      │
│ str  ┆ f64       ┆ f64      ┆ f64       ┆ f64      ┆ f64       ┆ f64      │
╞══════╪═══════════╪══════════╪═══════════╪══════════╪═══════════╪══════════╡
│ -5   ┆ -0.064121 ┆ 0.07666  ┆ -0.836434 ┆ 0.403111 ┆ -0.214554 ┆ 0.086312 │
│ -4   ┆ -0.012016 ┆ 0.075277 ┆ -0.159621 ┆ 0.873212 ┆ -0.159734 ┆ 0.135703 │
│ -3   ┆ -0.013872 ┆ 0.076535 ┆ -0.181249 ┆ 0.856209 ┆ -0.16406  ┆ 0.136316 │
│ -2   ┆ 0.051031  ┆ 0.07703  ┆ 0.662488  ┆ 0.507811 ┆ -0.100128 ┆ 0.20219  │
│ -1   ┆ 0.020225  ┆ 0.075849 ┆ 0.266642  ┆ 0.7898   ┆ -0.128618 ┆ 0.169067 │
│ …    ┆ …         ┆ …        ┆ …         ┆ …        ┆ …         ┆ …        │
│ 16   ┆ 2.880654  ┆ 0.115632 ┆ 24.912355 ┆ 0.0      ┆ 2.654016  ┆ 3.107292 │
│ 17   ┆ 2.993839  ┆ 0.114385 ┆ 26.17336  ┆ 0.0      ┆ 2.769644  ┆ 3.218033 │
│ 18   ┆ 2.646169  ┆ 0.115458 ┆ 22.918909 ┆ 0.0      ┆ 2.419872  ┆ 2.872466 │
│ 19   ┆ 2.875306  ┆ 0.114058 ┆ 25.209072 ┆ 0.0      ┆ 2.651752  ┆ 3.098861 │
│ 20   ┆ 2.904657  ┆ 0.113202 ┆ 25.659015 ┆ 0.0      ┆ 2.68278   ┆ 3.126533 │
└──────┴───────────┴──────────┴───────────┴──────────┴───────────┴──────────┘

Let’s compare the estimates from the event study with the true effects:

def plot_eventstudy(res: pl.DataFrame):
    """Helper to plot event study estimates"""
    p = (
        pn.ggplot(res, pn.aes(x="rel_year", y="estimate", color="group"))
        + pn.geom_point(position=pn.position_dodge(width=(w := 0.3)))
        + pn.geom_errorbar(
            pn.aes(ymin="lower", ymax="upper"),
            width=0.1,
            position=pn.position_dodge(width=w),
        )
        + pn.theme_classic()
        + pn.labs(
            x="Relative Time",
            y="Estimate",
            color="",
        )
        + pn.scale_x_continuous(breaks=list(range(-5, 8)))
        + pn.geom_hline(yintercept=0, color="black", size=0.5, linetype="dotted")
        + pn.geom_vline(xintercept=0, color="black", size=0.5, linetype="dotted")
        + pn.theme(
            legend_position="bottom",
            axis_title=pn.element_text(size=14),
            axis_text=pn.element_text(size=12),
            legend_text=pn.element_text(size=12),
            legend_title=pn.element_text(size=13),
        )
        + pn.scale_color_manual(
            values={
                "DID Imputation Estimate": "blue",
                "True Effect": "purple",
                "TWFE Estimate": "green",
            }
        )
    )
    return p


# did_imp estimates
estimates = result.estimates.select(
    pl.col("term").cast(pl.Int8).alias("rel_year"),
    "estimate",
    "se",
    "lower",
    "upper",
    pl.lit("DID Imputation Estimate").alias("group"),
).filter(pl.col("rel_year").is_between(-5, 7))

# true effects
te_true = (
    df.filter(pl.col("g").gt(0))
    .group_by("rel_year")
    .agg((pl.col("te") + pl.col("te_dynamic")).mean().alias("estimate"))
    .cast({"rel_year": pl.Int8})
    .filter(pl.col("rel_year").is_between(-5, 7))
    .sort("rel_year")
    .with_columns(
        se=0,
        lower=0,
        upper=0,
        group=pl.lit("True Effect"),
    )
)
results = pl.concat([estimates, te_true], how="vertical_relaxed").with_columns(
    pl.col("rel_year").cast(pl.Int8)
)

# Plot the estimates
p = plot_eventstudy(results)
p.save(
    did_imp.utils.proj_folder() / "figs/example_estimates.png",
    width=10,
    height=6,
    transparent=False,
    verbose=False,
    dpi=300,
)

[Figure: event-study estimates from did_imp compared with the true effects (figs/example_estimates.png)]

Now let’s compare with estimates from a classical TWFE model.

mod = pf.feols(
    "dep_var ~ i(rel_year, ref=-99) | unit + year",
    data=df.with_columns(
        pl.col("rel_year").replace(
            {
                # hack to have reference category for never-treated as well
                "Inf": "-99",
                "-1": "-99",
            }
        )
    ).cast({"rel_year": pl.Int8}),
)
print(mod.tidy()[["Estimate", "Std. Error"]].head())
                                               Estimate  Std. Error
Coefficient                                                        
C(rel_year, contr.treatment(base=-99))[T.-20]  0.228396    0.121547
C(rel_year, contr.treatment(base=-99))[T.-19]  0.225683    0.122520
C(rel_year, contr.treatment(base=-99))[T.-18]  0.124711    0.127552
C(rel_year, contr.treatment(base=-99))[T.-17]  0.143929    0.127100
C(rel_year, contr.treatment(base=-99))[T.-16]  0.122706    0.123787

Let’s plot the TWFE estimates together with the previous estimates from did_imp.

twfe = (
    mod.tidy()
    .reset_index()
    .pipe(pl.from_pandas)
    .select(
        pl.col("Coefficient")
        .str.extract(r"\[T\.(.*)\]")
        .cast(pl.Int8)
        .alias("rel_year"),
        pl.col("Estimate").alias("estimate"),
        pl.col("Std. Error").alias("se"),
        pl.col("2.5%").alias("lower"),
        pl.col("97.5%").alias("upper"),
        pl.lit("TWFE Estimate").alias("group"),
    )
    .filter(pl.col("rel_year").is_between(-5, 7))
)

results = pl.concat([results, twfe], how="vertical_relaxed")
p = plot_eventstudy(results)
p.save(
    did_imp.utils.proj_folder() / "figs/example_estimates_wtwfe.png",
    width=10,
    height=6,
    transparent=False,
    verbose=False,
    dpi=300,
)

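[Figure: event-study estimates from did_imp and TWFE compared with the true effects (figs/example_estimates_wtwfe.png)]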


Some theory

Borusyak estimator

TWFE regression

$$Y_{it} = \alpha_{i} + \beta_{t} + \tau D_{it} + \varepsilon_{it}$$

where

  • $Y_{it}$ is the outcome of unit $i$ in period $t$

  • $\alpha_{i}$ is the unit fixed effect

  • $\beta_{t}$ is the time fixed effect

  • $D_{it}$ is the binary treatment indicator
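To make the TWFE regression above concrete, here is a minimal sketch that simulates a small panel and estimates $\tau$ by OLS with explicit unit and time dummies. The data-generating process, the sample sizes, and all variable names are assumptions made up for illustration; this is not how did_imp or pyfixest fit the model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, n_periods, tau_true = 50, 6, 2.0

# Long-format panel: one row per (unit, period) pair.
unit = np.repeat(np.arange(n_units), n_periods)
time = np.tile(np.arange(n_periods), n_units)

# Hypothetical design: the first half of the units is treated from period 3 on.
d = ((unit < n_units // 2) & (time >= 3)).astype(float)

alpha = rng.normal(size=n_units)    # unit fixed effects alpha_i
beta = rng.normal(size=n_periods)   # time fixed effects beta_t
eps = rng.normal(scale=0.1, size=unit.size)
y = alpha[unit] + beta[time] + tau_true * d + eps

# Design matrix: all unit dummies, time dummies without the first period
# (dropped to avoid perfect collinearity), and the treatment indicator.
unit_dummies = (unit[:, None] == np.arange(n_units)).astype(float)
time_dummies = (time[:, None] == np.arange(1, n_periods)).astype(float)
X = np.column_stack([unit_dummies, time_dummies, d])

coef, *_ = np.linalg.lstsq(X, y, rcond=None)
tau_hat = coef[-1]  # coefficient on D_it
print(f"tau_hat = {tau_hat:.3f} (true value {tau_true})")
```

With a homogeneous effect, as simulated here, the TWFE coefficient recovers $\tau$; the example section above shows what happens when effects vary across cohorts and horizons.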

BJS write $it := (i, t)$ and define $\Omega$ as the set of all unit-time pairs. They further define

$$\begin{aligned} \Omega_{1} &= \{it \in \Omega \mid D_{it} = 1\}, \\ \Omega_{0} &= \{it \in \Omega \mid D_{it} = 0\}, \\ |\Omega_{1}| &= N_{1}, \quad |\Omega_{0}| = N_{0}. \end{aligned}$$

Unit-level treatment effects: the causal effects on the treated observations $it \in \Omega_{1}$ are denoted by

$$\tau_{it} = E\left\{ Y_{it}(1) - Y_{it}(0) \right\}$$

Estimation target:

$$\tau_{\omega} = \sum_{it \in \Omega_{1}} \omega_{it}\tau_{it} := \omega'_{1}\tau$$

i.e., a weighted sum of the unit-level effects of the treated observations. Examples of targets:

  • Overall ATT: $\omega_{it} = 1/N_{1}$ for all $it \in \Omega_{1}$.

  • Average effect $h$ periods since treatment for horizon $h \geq 0$: $\omega_{it} = \mathbf{1}\{K_{it} = h\} / |\Omega_{1, h}|$, where $K_{it}$ is the number of periods since treatment and $\Omega_{1, h} = \{it \mid K_{it} = h\}$.

  • The difference between average treatment effects at different horizons or across groups of units, corresponding to weights $\omega_{it}$ that sum to zero: $\sum_{it \in \Omega_{1}} \omega_{it} = 0$.
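The weight choices listed above can be written down directly. The sketch below builds them for a tiny hand-made panel; the arrays `g` and `t` and the helper names `att_weights` and `horizon_weights` are hypothetical and exist only for this illustration.

```python
import numpy as np

# Hypothetical toy panel: treatment date g (0 = never treated) and period t.
g = np.array([0, 0, 0, 2, 2, 2, 3, 3, 3])
t = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])

treated = (g > 0) & (t >= g)  # membership in Omega_1
k = t - g                     # K_it, only meaningful where g > 0


def att_weights(treated):
    """Overall ATT: weight 1/N_1 on every treated observation."""
    return treated / treated.sum()


def horizon_weights(treated, k, h):
    """Average effect h periods since treatment: equal weight on Omega_{1,h}."""
    in_h = treated & (k == h)
    return in_h / in_h.sum()


w_att = att_weights(treated)
w_h0 = horizon_weights(treated, k, 0)
w_h1 = horizon_weights(treated, k, 1)
w_diff = w_h1 - w_h0  # contrast between horizons 1 and 0

print("ATT weights:      ", w_att)
print("Horizon-0 weights:", w_h0)
print("Contrast weights: ", w_diff)
```

Averaging weights sum to one, while contrast weights such as `w_diff` sum to zero, which matches the third bullet.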

BJS Assumptions A1-A4 lead to the model

$$Y_{it} = A'_{it} \lambda_{i} + X_{it}' \delta + D_{it}\Gamma_{it}'\theta + \varepsilon_{it}$$

BJS Theorem 1: the (unique) efficient estimator $\hat{\tau}^{*}_{\omega}$ of $\tau_{\omega}$ can be obtained with the following steps:

  1. Estimate $\theta$ by OLS on the model above; denote the solution by $\hat{\theta}^{*}$ (assuming $\theta$ is identified)

  2. Estimate the vector of treatment effects $\tau$ by $\hat{\tau}^{*} = \Gamma \hat{\theta}^{*}$

  3. Estimate the target $\tau_{\omega}$ by $\hat{\tau}_{\omega}^{*} = \omega_{1}'\hat{\tau}^{*}$

See the paper for more details; in particular for

  • The Imputation Representation of the efficient estimator

  • The Weight representation of the efficient estimator (how the estimator is actually coded in the package)

  • The Conservative variance estimator (how standard errors are computed in the package)
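For intuition, the imputation representation mentioned above can be sketched in a few lines for the simplest case of unit and time fixed effects with no covariates: fit the fixed effects on untreated observations only, impute $Y_{it}(0)$ for treated observations, and average the differences. Everything below (the simulated design, the effect size, the helper `fe_design`) is an assumption for illustration. It is not the package's code, which uses the weight representation, and it computes no standard errors.

```python
import numpy as np

rng = np.random.default_rng(1)
n_units, n_periods = 60, 8

unit = np.repeat(np.arange(n_units), n_periods)
time = np.tile(np.arange(n_periods), n_units)

# Hypothetical staggered design: one third never treated (g = inf),
# one third treated from period 3, one third from period 5.
g_by_unit = np.array([np.inf, 3.0, 5.0])[np.arange(n_units) % 3]
g = g_by_unit[unit]
d = time >= g                          # treatment indicator D_it
k = time - g                           # periods since treatment K_it
tau = np.where(d, 1.0 + 0.5 * k, 0.0)  # assumed effect, growing with horizon

alpha = rng.normal(size=n_units)
beta = rng.normal(size=n_periods)
y = alpha[unit] + beta[time] + tau + rng.normal(scale=0.1, size=unit.size)


def fe_design(unit, time, n_units, n_periods):
    """Unit dummies plus time dummies (first period dropped)."""
    unit_d = (unit[:, None] == np.arange(n_units)).astype(float)
    time_d = (time[:, None] == np.arange(1, n_periods)).astype(float)
    return np.column_stack([unit_d, time_d])


X = fe_design(unit, time, n_units, n_periods)

# Step 1: fit the fixed effects using untreated observations (Omega_0) only.
coef, *_ = np.linalg.lstsq(X[~d], y[~d], rcond=None)

# Step 2: impute Y_it(0) and form unit-level effect estimates on Omega_1.
y0_hat = X @ coef
tau_hat = (y - y0_hat)[d]

# Step 3: aggregate with the chosen weights (here 1/N_1, the overall ATT).
att_hat = tau_hat.mean()
att_true = tau[d].mean()
print(f"att_hat = {att_hat:.3f}, att_true = {att_true:.3f}")
```

Replacing the plain mean in step 3 with horizon-specific weights gives event-study style estimates from the same imputed effects.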

Development

git clone git@github.com:jsr-p/did-imputation.git
cd did-imputation
uv venv
uv sync --all-extras

Testing

  • 🚧🔨⏳Work in progress! 🚧🔨⏳
  • Don’t open the tests folder yet; 🍝 inside

Docs

  • 🚧🔨⏳Work in progress! 🚧🔨⏳

References

Shout-out
