The `did_imp` package is a Python implementation of the BJS (Borusyak, Jaravel and Spiess) imputation estimator for difference-in-differences designs.
See:
uv pip install git+https://github.com/jsr-p/did-imputation
The following example is taken from Kyle Butts’ R implementation (the `didimputation` package).
First we load the data and plot the average outcome by group and year.
import did_imp
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
import plotnine as pn
import pyfixest as pf
# Read the example panel from the project's data folder and preview the
# four columns used throughout the example.
panel_path = did_imp.utils.proj_folder() / "data/dfhetfull.csv"
df = pl.read_csv(panel_path)
print(df.select("dep_var", "year", "g", "unit").head())
shape: (5, 4)
┌───────────┬──────┬─────┬──────┐
│ dep_var ┆ year ┆ g ┆ unit │
│ --- ┆ --- ┆ --- ┆ --- │
│ f64 ┆ i64 ┆ i64 ┆ i64 │
╞═══════════╪══════╪═════╪══════╡
│ 0.221443 ┆ 1990 ┆ 0 ┆ 1 │
│ 0.796955 ┆ 1991 ┆ 0 ┆ 1 │
│ 0.604479 ┆ 1992 ┆ 0 ┆ 1 │
│ -0.935092 ┆ 1993 ┆ 0 ┆ 1 │
│ 2.806541 ┆ 1994 ┆ 0 ┆ 1 │
└───────────┴──────┴─────┴──────┘
# Display label for each value of `g`. `replace_strict` raises if `g`
# contains a value that is missing from this mapping.
cohort_labels = {
    0: "Group 3",
    2000: "Group 1",
    2010: "Group 2",
}
# Average outcome within each (g, year) cell.
mean_outcome = pl.col("dep_var").mean()
gp = (
    df.group_by("g", "year")
    .agg(mean_outcome)
    .with_columns(pl.col("g").replace_strict(cohort_labels))
)
# One line colour per cohort label.
cohort_colors = {
    "Group 1": "red",
    "Group 2": "blue",
    "Group 3": "purple",
}
# Legend sits below the axes, in a single row, without a frame.
legend_options = dict(
    title="Treatment Cohort",
    loc="upper center",
    bbox_to_anchor=(0.5, -0.1),
    ncol=3,
    frameon=False,
)

fig, ax = plt.subplots(figsize=(10, 6))
ax = sns.lineplot(
    data=gp.to_pandas(),
    x="year",
    y="dep_var",
    hue="g",
    palette=cohort_colors,
    ax=ax,
)
# Dashed vertical markers at 2000 and 2010.
for marker_year in (2000, 2010):
    ax.axvline(marker_year, color="black", linestyle="--")
ax.set(ylabel="Outcome", xlabel="Year")
ax.legend(**legend_options)
sns.despine()
_ = fig.savefig(
    did_imp.utils.proj_folder() / "figs/example_avg.png", dpi=300, bbox_inches="tight"
)
plt.close(fig)  # quarto
Let’s estimate a static DID
# Static specification with unit and year fixed effects; the printed
# estimates table contains a single `treat` row.
result_static = did_imp.estimate(
    # data and the columns identifying the panel
    data=df,
    unit="unit",
    time="year",
    group="g",
    outcome="dep_var",
    # model specification
    fes="unit + year",
    horizons="static",
)
print(result_static.estimates)
shape: (1, 7)
┌───────┬──────────┬──────────┬───────────┬──────┬──────────┬─────────┐
│ term ┆ estimate ┆ se ┆ tstat ┆ pval ┆ lower ┆ upper │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞═══════╪══════════╪══════════╪═══════════╪══════╪══════════╪═════════╡
│ treat ┆ 2.262952 ┆ 0.031397 ┆ 72.075777 ┆ 0.0 ┆ 2.201414 ┆ 2.32449 │
└───────┴──────────┴──────────┴───────────┴──────┴──────────┴─────────┘
Next, let’s estimate an event study DID.
# Event-study specification: same data and fixed effects as the static
# model, plus pre-trend terms for relative periods -5, ..., -1.
result = did_imp.estimate(
    # data and the columns identifying the panel
    data=df,
    unit="unit",
    time="year",
    group="g",
    outcome="dep_var",
    # model specification
    fes="unit + year",
    horizons="event",
    pretrends=list(range(-5, 0)),
)
print(result.estimates)
shape: (26, 7)
┌──────┬───────────┬──────────┬───────────┬──────────┬───────────┬──────────┐
│ term ┆ estimate ┆ se ┆ tstat ┆ pval ┆ lower ┆ upper │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞══════╪═══════════╪══════════╪═══════════╪══════════╪═══════════╪══════════╡
│ -5 ┆ -0.064121 ┆ 0.07666 ┆ -0.836434 ┆ 0.403111 ┆ -0.214554 ┆ 0.086312 │
│ -4 ┆ -0.012016 ┆ 0.075277 ┆ -0.159621 ┆ 0.873212 ┆ -0.159734 ┆ 0.135703 │
│ -3 ┆ -0.013872 ┆ 0.076535 ┆ -0.181249 ┆ 0.856209 ┆ -0.16406 ┆ 0.136316 │
│ -2 ┆ 0.051031 ┆ 0.07703 ┆ 0.662488 ┆ 0.507811 ┆ -0.100128 ┆ 0.20219 │
│ -1 ┆ 0.020225 ┆ 0.075849 ┆ 0.266642 ┆ 0.7898 ┆ -0.128618 ┆ 0.169067 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 16 ┆ 2.880654 ┆ 0.115632 ┆ 24.912355 ┆ 0.0 ┆ 2.654016 ┆ 3.107292 │
│ 17 ┆ 2.993839 ┆ 0.114385 ┆ 26.17336 ┆ 0.0 ┆ 2.769644 ┆ 3.218033 │
│ 18 ┆ 2.646169 ┆ 0.115458 ┆ 22.918909 ┆ 0.0 ┆ 2.419872 ┆ 2.872466 │
│ 19 ┆ 2.875306 ┆ 0.114058 ┆ 25.209072 ┆ 0.0 ┆ 2.651752 ┆ 3.098861 │
│ 20 ┆ 2.904657 ┆ 0.113202 ┆ 25.659015 ┆ 0.0 ┆ 2.68278 ┆ 3.126533 │
└──────┴───────────┴──────────┴───────────┴──────────┴───────────┴──────────┘
Let’s compare the estimates from the event study with the true effects
def plot_eventstudy(res: pl.DataFrame) -> pn.ggplot:
    """Build an event-study plot of point estimates with error bars.

    Args:
        res: Frame providing the columns ``rel_year`` (x axis),
            ``estimate`` (y axis), ``lower`` / ``upper`` (error-bar ends)
            and ``group`` (colour). Colours are defined for the labels
            "DID Imputation Estimate", "True Effect" and "TWFE Estimate".

    Returns:
        The assembled plotnine plot; nothing is drawn or saved here.
    """
    # Points and error bars of different groups are dodged by the same
    # amount so they line up at each relative period.
    dodge_width = 0.3
    group_colors = {
        "DID Imputation Estimate": "blue",
        "True Effect": "purple",
        "TWFE Estimate": "green",
    }
    # Layers are added in this order on top of the base plot.
    layers = [
        pn.geom_point(position=pn.position_dodge(width=dodge_width)),
        pn.geom_errorbar(
            pn.aes(ymin="lower", ymax="upper"),
            width=0.1,
            position=pn.position_dodge(width=dodge_width),
        ),
        pn.theme_classic(),
        pn.labs(
            x="Relative Time",
            y="Estimate",
            color="",
        ),
        # One axis break per relative period from -5 to 7.
        pn.scale_x_continuous(breaks=list(range(-5, 8))),
        # Dotted reference lines through zero on both axes.
        pn.geom_hline(yintercept=0, color="black", size=0.5, linetype="dotted"),
        pn.geom_vline(xintercept=0, color="black", size=0.5, linetype="dotted"),
        pn.theme(
            legend_position="bottom",
            axis_title=pn.element_text(size=14),
            axis_text=pn.element_text(size=12),
            legend_text=pn.element_text(size=12),
            legend_title=pn.element_text(size=13),
        ),
        pn.scale_color_manual(values=group_colors),
    ]
    plot = pn.ggplot(res, pn.aes(x="rel_year", y="estimate", color="group"))
    for layer in layers:
        plot = plot + layer
    return plot
# did_imp estimates: turn the `term` strings into integer relative years,
# tag the rows for the plot legend and keep relative years -5, ..., 7.
estimates = (
    result.estimates.select(
        pl.col("term").cast(pl.Int8).alias("rel_year"),
        "estimate",
        "se",
        "lower",
        "upper",
        group=pl.lit("DID Imputation Estimate"),
    )
    .filter(pl.col("rel_year").is_between(-5, 7))
)
# True effects: average of `te + te_dynamic` per relative year over the
# rows with g > 0, restricted to relative years -5, ..., 7.
te_true = (
    df.filter(pl.col("g").gt(0))
    .group_by("rel_year")
    .agg((pl.col("te") + pl.col("te_dynamic")).mean().alias("estimate"))
    .cast({"rel_year": pl.Int8})
    .filter(pl.col("rel_year").is_between(-5, 7))
    .sort("rel_year")
    .with_columns(
        # Column order must match `estimates` so the frames can be stacked.
        se=0,
        # A true effect has no sampling uncertainty, so its interval is
        # degenerate at the effect itself. Using 0 here would draw the
        # error bar at y = 0 instead of at the point.
        lower=pl.col("estimate"),
        upper=pl.col("estimate"),
        group=pl.lit("True Effect"),
    )
)
# Stack the estimated and true effects; "vertical_relaxed" lets columns
# whose dtypes differ between the two frames be cast to a common type.
stacked = pl.concat([estimates, te_true], how="vertical_relaxed")
results = stacked.with_columns(pl.col("rel_year").cast(pl.Int8))
# Plot the estimates
p = plot_eventstudy(results)
estimates_png = did_imp.utils.proj_folder() / "figs/example_estimates.png"
p.save(
    estimates_png,
    width=10,
    height=6,
    dpi=300,
    transparent=False,
    verbose=False,
)
Now let’s compare with estimates from a classical TWFE model.
# Recode `rel_year` before the regression: both "Inf" and "-1" are mapped
# to "-99", which the formula then uses as the single reference level.
reference_recode = {
    # hack to have reference category for never-treated as well
    "Inf": "-99",
    "-1": "-99",
}
twfe_data = df.with_columns(
    pl.col("rel_year").replace(reference_recode)
).cast({"rel_year": pl.Int8})
mod = pf.feols(
    "dep_var ~ i(rel_year, ref=-99) | unit + year",
    data=twfe_data,
)
print(mod.tidy()[["Estimate", "Std. Error"]].head())
Estimate Std. Error
Coefficient
C(rel_year, contr.treatment(base=-99))[T.-20] 0.228396 0.121547
C(rel_year, contr.treatment(base=-99))[T.-19] 0.225683 0.122520
C(rel_year, contr.treatment(base=-99))[T.-18] 0.124711 0.127552
C(rel_year, contr.treatment(base=-99))[T.-17] 0.143929 0.127100
C(rel_year, contr.treatment(base=-99))[T.-16] 0.122706 0.123787
Let’s plot the estimates together with the previous estimates from `did_imp`.
# Bring the TWFE coefficient table into the same layout as `results`.
# The relative year is the text between "[T." and "]" in each coefficient
# name, e.g. "...[T.-20]" -> -20.
tidy_table = pl.from_pandas(mod.tidy().reset_index())
rel_year_from_name = (
    pl.col("Coefficient").str.extract(r"\[T\.(.*)\]").cast(pl.Int8).alias("rel_year")
)
twfe = (
    tidy_table.select(
        rel_year_from_name,
        pl.col("Estimate").alias("estimate"),
        pl.col("Std. Error").alias("se"),
        pl.col("2.5%").alias("lower"),
        pl.col("97.5%").alias("upper"),
        pl.lit("TWFE Estimate").alias("group"),
    )
    .filter(pl.col("rel_year").is_between(-5, 7))
)
# Append the TWFE rows to the frame plotted before and redraw the figure.
results = pl.concat([results, twfe], how="vertical_relaxed")
p = plot_eventstudy(results)
comparison_png = did_imp.utils.proj_folder() / "figs/example_estimates_wtwfe.png"
p.save(
    comparison_png,
    width=10,
    height=6,
    dpi=300,
    transparent=False,
    verbose=False,
)
TWFE regression

$$Y_{it} = \alpha_{i} + \beta_{t} + \tau D_{it} + \varepsilon_{it}$$

where

- $Y_{it}$ is the outcome of unit $i$ in period $t$
- $\alpha_{i}$ is the unit fixed effect
- $\beta_{t}$ is the time fixed effect
- $D_{it}$ is the binary treatment indicator
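BJS write the unit-level treatment effects (causal effects on the treated) as

$$\tau_{it} = E[Y_{it} - Y_{it}(0)], \quad it \in \Omega_{1},$$

where $\Omega_{1} = \{it \mid D_{it} = 1\}$ is the set of treated observations.

Estimation target:

$$\tau_{w} = \sum_{it \in \Omega_{1}} \omega_{it} \tau_{it} \equiv w_{1}'\tau,$$

i.e. a weighted sum of the treated units’ unit-level effects.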
- Overall ATT: $\omega_{it} = 1/N_{1}$ for all $it \in \Omega_{1}$.
- Average effect $h$ periods since treatment, for horizon $h \geq 0$: $\omega_{it} = \mathbf{1}\{K_{it} = h\} / |\Omega_{1, h}|$, where $\Omega_{1, h} = \{it \mid K_{it} = h\}$.
- The difference between average treatment effects at different horizons or across some groups of units, corresponding to weights $\omega_{it}$ with $\sum_{it \in \Omega_{1}} \omega_{it} = 0$.
BJS assumptions A1-A4 lead to the model

$$Y_{it} = A_{it}'\lambda_{i} + X_{it}'\delta + D_{it}\Gamma_{it}'\theta + \varepsilon_{it}$$
BJS Theorem 1: the (unique) efficient estimator

- Estimate $\theta$ by OLS; solution $\hat{\theta}^{*}$ (assuming $\theta$ is identified).
- Estimate the vector of treatment effects $\tau$ by $\hat{\tau}^{*} = \Gamma \hat{\theta}^{*}$.
- Estimate the target $\tau_{w}$ by $\hat{\tau}_{w}^{*} = w_{1}'\hat{\tau}^{*}$.
See the paper for more details; in particular for

- The Imputation Representation of the efficient estimator
- The Weight representation of the efficient estimator (how the estimator is actually coded in the package)
- The Conservative variance estimator (how standard errors are computed in the package)
# Clone the repository and set up a development environment with uv.
git clone git@github.com:jsr-p/did-imputation.git
# The clone creates a directory named after the repository.
cd did-imputation
uv venv
uv sync --all-extras
- 🚧🔨⏳Work in progress! 🚧🔨⏳
- Don’t open the `tests` folder yet; 🍝 inside
- 🚧🔨⏳Work in progress! 🚧🔨⏳
- Shout-out to pyfixest; makes everything easier
- Code inspiration taken from: