99from sklearn .metrics import roc_auc_score , roc_curve
1010
1111
12- # --- 1. NEW PLOTTING FUNCTIONS ---
1312
1413def plot_fuzz_vs_intervention (latent_df : pd .DataFrame , out_dir : Path , run_label : str ):
1514 """
1615 Replicates the Scatter Plot from the paper (Figure 3/Appendix G).
1716 Plots Fuzz Score vs. Intervention Score for the same latents.
1817 """
19- # We need to merge the rows for 'fuzz' and 'surprisal_intervention'
20- # 1. Pivot the table so we have columns: 'latent_idx', 'fuzz_score', 'intervention_score'
2118
2219 # Extract Fuzz Scores (using F1 or Accuracy as the metric)
2320 fuzz_df = latent_df [latent_df ["score_type" ] == "fuzz" ].copy ()
@@ -32,12 +29,10 @@ def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label:
3229 int_df = latent_df [latent_df ["score_type" ] == "surprisal_intervention" ].copy ()
3330 if int_df .empty : return
3431
35- # Deduplicate intervention scores
3632 int_metrics = int_df .drop_duplicates (subset = ["module" , "latent_idx" ])[
3733 ["module" , "latent_idx" , "avg_kl_divergence" , "final_score" ]
3834 ]
3935
40- # Merge them
4136 merged = pd .merge (fuzz_metrics , int_metrics , on = ["module" , "latent_idx" ])
4237
4338 if merged .empty :
@@ -63,7 +58,8 @@ def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label:
6358 y = "final_score" ,
6459 hover_data = ["latent_idx" ],
6560 title = f"Correlation vs. Causation (Score) - { run_label } " ,
66- labels = {"fuzz_score" : "Fuzzing Score (Correlation)" , "final_score" : "Intervention Score (Surprisal)" },
61+ labels = {"fuzz_score" : "Fuzzing Score (Correlation)" ,
62+ "final_score" : "Intervention Score (Surprisal)" },
6763 trendline = "ols"
6864 )
6965 fig_score .write_image (out_dir / "scatter_fuzz_vs_score.pdf" )
@@ -87,7 +83,6 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str):
8783 counts = df ["status" ].value_counts ().reset_index ()
8884 counts .columns = ["Status" , "Count" ]
8985
90- # Get percentage
9186 total = counts ["Count" ].sum ()
9287 live = counts [counts ["Status" ] == "Decoder-Live" ]["Count" ].sum () if "Decoder-Live" in counts ["Status" ].values else 0
9388 pct = (live / total * 100 ) if total > 0 else 0
@@ -99,7 +94,7 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str):
9994 )
10095 fig_bar .write_image (out_dir / "intervention_live_dead_split.pdf" )
10196
102- # 2. "Live Features Only" Histogram (The "Pretty" one)
97+ # 2. "Live Features Only" Histogram
10398 live_df = df [df ["avg_kl_divergence" ] > threshold ]
10499 if not live_df .empty :
105100 fig_live = px .histogram (
@@ -124,7 +119,6 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str):
124119 fig_all .write_image (out_dir / "intervention_kl_dist_log_scale.pdf" )
125120
126121
127- # --- 2. STANDARD PLOTTING HELPERS ---
128122
129123def plot_firing_vs_f1 (latent_df , num_tokens , out_dir , run_label ):
130124 out_dir .mkdir (parents = True , exist_ok = True )
@@ -168,7 +162,6 @@ def plot_roc_curve(df, out_dir):
168162 fig .write_image (out_dir / "roc_curve.pdf" )
169163
170164
171- # --- 3. METRIC COMPUTATION ---
172165
173166def compute_confusion (df , threshold = 0.5 ):
174167 df_valid = df [df ["prediction" ].notna ()]
@@ -208,7 +201,6 @@ def add_latent_f1(df):
208201 return df .merge (f1s , on = ["module" , "latent_idx" ])
209202
210203
211- # --- 4. DATA LOADING ---
212204
213205def load_data (scores_path , modules ):
214206 def parse_file (path ):
@@ -248,7 +240,6 @@ def parse_file(path):
248240 return (pd .concat (dfs , ignore_index = True ) if dfs else pd .DataFrame ()), counts
249241
250242
251- # --- 5. MAIN LOGIC ---
252243
253244def log_results (scores_path : Path , viz_path : Path , modules : list [str ], scorer_names : list [str ], model_name : str = "Unknown" ):
254245 import_plotly ()
@@ -299,6 +290,5 @@ def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_na
299290 plot_intervention_stats (unique_latents , viz_path , model_name )
300291
301292 # 3. Generate Scatter Plot (Fuzz vs. Intervention)
302- # Only works if we have BOTH types of data
303293 if not class_df .empty and not int_df .empty :
304294 plot_fuzz_vs_intervention (latent_df , viz_path , scores_path .name )
0 commit comments