# steer-hexter / loading_analyzing.py
COLAB = False
# from IPython import get_ipython # type: ignore
# ipython = get_ipython(); assert ipython is not None
# ipython.run_line_magic("load_ext", "autoreload")
# ipython.run_line_magic("autoreload", "2")
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000
torch.set_grad_enabled(False)
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays an HTML visualization, either locally in the browser or served from
    within Colab. Uses the global `PORT` variable defined above, so that each
    vis gets a unique port without having to pass one into the function.
    '''
    if not COLAB:
        webbrowser.open(filename)
    else:
        global PORT

        def serve(directory):
            os.chdir(directory)
            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler
            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()
        # output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)
        PORT += 1
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
# The cfg dict is returned alongside the SAE, since it may contain useful information for analysing the SAE (e.g. instantiating an activation store).
# Note that this is not the same as the SAE's own config dict; rather, it is whatever was in the HF repo, from which the SAE config dict can be extracted.
# We also return the feature sparsities, which are stored on HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="blocks.8.hook_resid_pre",  # won't always be a hook point
    device=device,
)
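# Quick sanity check on what was loaded. `cfg_dict` mirrors the HF repo config,
# while `sae.cfg` is the parsed SAE config; the fields printed below exist in
# current sae_lens releases, but names can drift between versions.
print(f"SAE hook: {sae.cfg.hook_name}, d_in: {sae.cfg.d_in}, d_sae: {sae.cfg.d_sae}")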
from transformer_lens.utils import tokenize_and_concatenate
dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)
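# Sanity check: each row of the tokenized dataset should be a fixed-length
# window of token ids matching the SAE's training context size.
print(token_dataset[:2]["tokens"].shape)  # expected: (2, sae.cfg.context_size)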
sae.eval()  # inference only: avoids errors from training-time logic (e.g. a dead-neuron mask, which expects grads)
with torch.no_grad():
    # activation store can give us tokens.
    batch_tokens = token_dataset[:32]["tokens"]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Use the SAE
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)
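
    # Added sketch: quantify reconstruction quality before the cache is freed.
    # MSE between the SAE output and the original activations, plus the
    # fraction of variance unexplained (FVU); lower is better for both.
    mse = (sae_out - cache[sae.cfg.hook_name]).pow(2).mean()
    fvu = mse / cache[sae.cfg.hook_name].var()
    print(f"reconstruction MSE: {mse.item():.4f}, FVU: {fvu.item():.4f}")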
    # save some room
    del cache

    # ignore the BOS token; count the features that activated at each token, averaged across batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()
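# Sketch: `sparsity` from SAE.from_pretrained records how often each feature
# fired during training (assumed here to be stored as log10 firing rates, as in
# most sae_lens releases; check the repo if the x-axis looks off). Dead
# features pile up on the far left of this histogram.
if sparsity is not None:
    px.histogram(sparsity.flatten().cpu().numpy(), nbins=100).show()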
from transformer_lens import utils
from functools import partial
# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out

def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)
print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)
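# Sketch: a single-number summary of the three losses above is the fraction of
# cross-entropy loss recovered by the SAE relative to zero-ablation (1.0 would
# mean perfect reconstruction). Variable names here are illustrative.
orig_loss = model(batch_tokens, return_type="loss").item()
reconstr_loss = model.run_with_hooks(
    batch_tokens,
    fwd_hooks=[(sae.cfg.hook_name, partial(reconstr_hook, sae_out=sae_out))],
    return_type="loss",
).item()
zero_loss = model.run_with_hooks(
    batch_tokens,
    fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    return_type="loss",
).item()
print("CE loss recovered:", (zero_loss - reconstr_loss) / (zero_loss - orig_loss))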
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out = sae(cache[sae.cfg.hook_name])
def reconstr_hook(activations, hook, sae_out):
    return sae_out

def zero_abl_hook(activation, hook):
    # note: we hook the residual stream here, not the MLP output
    return torch.zeros_like(activation)
hook_name = sae.cfg.hook_name
print("Orig", model(tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[
            (
                hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(hook_name, zero_abl_hook)],
    ).item(),
)
with model.hooks(
    fwd_hooks=[
        (
            hook_name,
            partial(reconstr_hook, sae_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
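# Sketch: beyond the loss comparison, we can ask which SAE features fire most
# strongly at the final token of the example prompt, reusing the `cache`
# captured from run_with_cache above. Feature indices can then be looked up on
# Neuronpedia.
prompt_feature_acts = sae.encode(cache[sae.cfg.hook_name])
top_vals, top_idx = prompt_feature_acts[0, -1].topk(5)
for val, idx in zip(top_vals.tolist(), top_idx.tolist()):
    print(f"feature {idx}: activation {val:.3f}")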
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner
test_feature_idx_gpt = list(range(10)) + [14057]
feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_name,
    features=test_feature_idx_gpt,
    minibatch_size_features=64,
    minibatch_size_tokens=256,
    verbose=True,
    device=device,
)
visualization_data_gpt = SaeVisRunner(feature_vis_config_gpt).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset[:10000]["tokens"],  # type: ignore
)
# SaeVisData.create(
# encoder=sae,
# model=model, # type: ignore
# tokens=token_dataset[:10000]["tokens"], # type: ignore
# cfg=feature_vis_config_gpt,
# )
from sae_dashboard.data_writing_fns import save_feature_centric_vis
filename = "demo_feature_dashboards.html"
save_feature_centric_vis(sae_vis_data=visualization_data_gpt, filename=filename)
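# Open the saved dashboard using the helper defined at the top of this file
# (a local browser tab, or a served file when COLAB is True).
display_vis_inline(filename)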
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
# this function should open a Neuronpedia quick list in your browser and return its URL
neuronpedia_quick_list = get_neuronpedia_quick_list(
    test_feature_idx_gpt,
    layer=sae.cfg.hook_layer,
    model="gpt2-small",
    dataset="res-jb",
    name="A quick list we made",
)
if COLAB:
    # If you're on Colab, click the link below
    print(neuronpedia_quick_list)