diff --git a/README.md b/README.md
index 8df8288..ce41770 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 Tools for inspecting sparse circuit models from [Gao et al. 2025](https://openai.com/index/understanding-neural-networks-through-sparse-circuits/).
 Provides code for running inference as well as a Streamlit dashboard that allows you to interact
-with task-specific circuits found by pruning.
+with task-specific circuits found by pruning. Note: this README was AI-generated and lightly edited.
 
 ## Installation
 
@@ -23,6 +23,10 @@ visualizer loads you can choose a model, dataset, pruning sweep, and node budget
 using the controls in the left column. The plots are rendered with Plotly; most
 elements are interactive and support hover/click exploration.
 
+Example view of the Streamlit circuit visualizer (wte/wpe tab) with node ablation deltas and activation previews:
+
+![Streamlit circuit visualizer](annotated-circuit-sparsity-viz.png)
+
 ## Running Model Forward Passes
 
 Transformer definitions live in `circuit_sparsity.inference.gpt`. The module
diff --git a/annotated-circuit-sparsity-viz.png b/annotated-circuit-sparsity-viz.png
new file mode 100644
index 0000000..f2b2b54
Binary files /dev/null and b/annotated-circuit-sparsity-viz.png differ
diff --git a/circuit_sparsity/viz.py b/circuit_sparsity/viz.py
index d974b61..490200f 100644
--- a/circuit_sparsity/viz.py
+++ b/circuit_sparsity/viz.py
@@ -1485,13 +1485,8 @@ def _maybe_item(x):
     cols = st.columns([1, 1])
     model_config = viz_data["importances"]["beeg_model_config"]
     with cols[0]:
-        cols2 = st.columns([1, 2, 1])
+        cols2 = st.columns([2, 1])
         with cols2[0]:
-            use_pca = st.toggle(
-                "PCA components",
-                value=False,
-            )
-        with cols2[1]:
             chidx = st.selectbox(
                 "res channel index",
                 options=[
@@ -1502,14 +1497,10 @@ def _maybe_item(x):
             )
             chidx = int(chidx.split(" ")[0])
-        with cols2[2]:
+        with cols2[1]:
             st.text(f"encname: {model_config.tokenizer_name}")
 
-        if use_pca:
-            U, S, V = get_embed_weights_pca(model_path, q=100)
-            embsort = U[:, chidx].cpu().sort(descending=True)
-        else:
-            embsort = embed_weight[:, chidx].sort(descending=True)
+        embsort = embed_weight[:, chidx].sort(descending=True)
 
         def _filter_embsort(xs):
             return [
@@ -1588,12 +1579,6 @@ def _filter_embsort(xs):
     status_placeholder.html("
 ready
 
 ")
-@cache("get_embed_weights_pca_v1")
-def get_embed_weights_pca(model_path, q):
-    wte = get_embed_weights(model_path).float()
-    return torch.pca_lowrank(wte, niter=10, q=q)
-
-
 @cache("get_embed_weights_v1")
 def get_embed_weights(model_path):
     return get_model_weights(model_path, lambda x: x["transformer.wte.weight"]).half()
@@ -1714,5 +1699,3 @@ def treemap(f, x):
 
     st.set_page_config(page_title="Circuit viz", layout="wide")
     main()
-
-