Skip to content

Commit dd3956d

Browse files
committed
Clean plots and results file
1 parent c81b861 commit dd3956d

File tree

1 file changed

+3
-13
lines changed

1 file changed

+3
-13
lines changed

delphi/log/result_analysis.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,12 @@
99
from sklearn.metrics import roc_auc_score, roc_curve
1010

1111

12-
# --- 1. NEW PLOTTING FUNCTIONS ---
1312

1413
def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label: str):
1514
"""
1615
Replicates the Scatter Plot from the paper (Figure 3/Appendix G).
1716
Plots Fuzz Score vs. Intervention Score for the same latents.
1817
"""
19-
# We need to merge the rows for 'fuzz' and 'surprisal_intervention'
20-
# 1. Pivot the table so we have columns: 'latent_idx', 'fuzz_score', 'intervention_score'
2118

2219
# Extract Fuzz Scores (using F1 or Accuracy as the metric)
2320
fuzz_df = latent_df[latent_df["score_type"] == "fuzz"].copy()
@@ -32,12 +29,10 @@ def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label:
3229
int_df = latent_df[latent_df["score_type"] == "surprisal_intervention"].copy()
3330
if int_df.empty: return
3431

35-
# Deduplicate intervention scores
3632
int_metrics = int_df.drop_duplicates(subset=["module", "latent_idx"])[
3733
["module", "latent_idx", "avg_kl_divergence", "final_score"]
3834
]
3935

40-
# Merge them
4136
merged = pd.merge(fuzz_metrics, int_metrics, on=["module", "latent_idx"])
4237

4338
if merged.empty:
@@ -63,7 +58,8 @@ def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label:
6358
y="final_score",
6459
hover_data=["latent_idx"],
6560
title=f"Correlation vs. Causation (Score) - {run_label}",
66-
labels={"fuzz_score": "Fuzzing Score (Correlation)", "final_score": "Intervention Score (Surprisal)"},
61+
labels={"fuzz_score": "Fuzzing Score (Correlation)",
62+
"final_score": "Intervention Score (Surprisal)"},
6763
trendline="ols"
6864
)
6965
fig_score.write_image(out_dir / "scatter_fuzz_vs_score.pdf")
@@ -87,7 +83,6 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str):
8783
counts = df["status"].value_counts().reset_index()
8884
counts.columns = ["Status", "Count"]
8985

90-
# Get percentage
9186
total = counts["Count"].sum()
9287
live = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() if "Decoder-Live" in counts["Status"].values else 0
9388
pct = (live / total * 100) if total > 0 else 0
@@ -99,7 +94,7 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str):
9994
)
10095
fig_bar.write_image(out_dir / "intervention_live_dead_split.pdf")
10196

102-
# 2. "Live Features Only" Histogram (The "Pretty" one)
97+
# 2. "Live Features Only" Histogram
10398
live_df = df[df["avg_kl_divergence"] > threshold]
10499
if not live_df.empty:
105100
fig_live = px.histogram(
@@ -124,7 +119,6 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str):
124119
fig_all.write_image(out_dir / "intervention_kl_dist_log_scale.pdf")
125120

126121

127-
# --- 2. STANDARD PLOTTING HELPERS ---
128122

129123
def plot_firing_vs_f1(latent_df, num_tokens, out_dir, run_label):
130124
out_dir.mkdir(parents=True, exist_ok=True)
@@ -168,7 +162,6 @@ def plot_roc_curve(df, out_dir):
168162
fig.write_image(out_dir / "roc_curve.pdf")
169163

170164

171-
# --- 3. METRIC COMPUTATION ---
172165

173166
def compute_confusion(df, threshold=0.5):
174167
df_valid = df[df["prediction"].notna()]
@@ -208,7 +201,6 @@ def add_latent_f1(df):
208201
return df.merge(f1s, on=["module", "latent_idx"])
209202

210203

211-
# --- 4. DATA LOADING ---
212204

213205
def load_data(scores_path, modules):
214206
def parse_file(path):
@@ -248,7 +240,6 @@ def parse_file(path):
248240
return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts
249241

250242

251-
# --- 5. MAIN LOGIC ---
252243

253244
def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str], model_name: str = "Unknown"):
254245
import_plotly()
@@ -299,6 +290,5 @@ def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_na
299290
plot_intervention_stats(unique_latents, viz_path, model_name)
300291

301292
# 3. Generate Scatter Plot (Fuzz vs. Intervention)
302-
# Only works if we have BOTH types of data
303293
if not class_df.empty and not int_df.empty:
304294
plot_fuzz_vs_intervention(latent_df, viz_path, scores_path.name)

0 commit comments

Comments
 (0)