COLAB = False

# from IPython import get_ipython  # type: ignore
# ipython = get_ipython(); assert ipython is not None
# ipython.run_line_magic("load_ext", "autoreload")
# ipython.run_line_magic("autoreload", "2")
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading

PORT = 8000
torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used,
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses the global `PORT` variable defined in the
    previous cell, so that each vis has a unique port without having to define a port
    within the function.
    '''
    if not COLAB:
        webbrowser.open(filename)
    else:
        global PORT

        def serve(directory):
            os.chdir(directory)
            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler
            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        # output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small", device=device)
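# Quick sanity check that the model loaded as expected (a minimal sketch;
# these config attributes are standard on HookedTransformer).
print(f"layers: {model.cfg.n_layers}, d_model: {model.cfg.d_model}")
print(model.to_str_tokens("Hello world"))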
# The cfg dict is returned alongside the SAE since it may contain useful information
# for analysing the SAE (e.g. instantiating an activation store).
# Note that this is not the same as the SAE's config dict: rather, it is whatever was
# in the HF repo, from which we can extract the SAE config dict.
# We also return the feature sparsities, which are stored on HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="blocks.8.hook_resid_pre",  # won't always be a hook point
    device=device,
)
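# It can help to confirm the SAE's basic geometry before using it (a minimal
# sketch; d_in / d_sae are standard fields on the SAE config).
print(f"hook: {sae.cfg.hook_name}, d_in: {sae.cfg.d_in}, d_sae: {sae.cfg.d_sae}")
print(f"decoder weight shape: {tuple(sae.W_dec.shape)}")  # (d_sae, d_in)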
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)
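# The result is a dataset of fixed-length token sequences; a quick shape check
# (a minimal sketch) confirms each row matches the SAE's training context size.
print(token_dataset[:4]["tokens"].shape)  # (4, sae.cfg.context_size)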
sae.eval()  # prevents errors if the SAE expects a dead-neuron mask when computing grads
with torch.no_grad():
    # The activation store can give us tokens.
    batch_tokens = token_dataset[:32]["tokens"]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Use the SAE
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)
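    # How faithful is the reconstruction itself? A quick check while the cache
    # is still in memory (a minimal sketch; "variance explained" here follows
    # one common convention, 1 - MSE / variance of the original activations).
    orig_acts = cache[sae.cfg.hook_name]
    mse = (sae_out - orig_acts).pow(2).mean()
    var = (orig_acts - orig_acts.mean()).pow(2).mean()
    print(f"reconstruction MSE: {mse.item():.4f}, variance explained: {(1 - mse / var).item():.4f}")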
    # save some room
    del cache

    # Ignore the BOS token; get the number of features that activated on each token,
    # averaged across batch and position.
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()
from transformer_lens import utils
from functools import partial

# Next we want to do a reconstruction test: splice the SAE's output back into
# the forward pass and see how much the cross-entropy loss degrades, compared
# to zero-ablating the activations entirely.
def reconstr_hook(activation, hook, sae_out):
    return sae_out

def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)
print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)
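# A single summary number is often handy: the fraction of the loss gap
# (relative to zero-ablation) that the reconstruction recovers. This is a
# sketch of one common convention, not the only way to score an SAE.
orig_loss = model(batch_tokens, return_type="loss").item()
reconstr_loss = model.run_with_hooks(
    batch_tokens,
    fwd_hooks=[(sae.cfg.hook_name, partial(reconstr_hook, sae_out=sae_out))],
    return_type="loss",
).item()
zero_loss = model.run_with_hooks(
    batch_tokens,
    fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    return_type="loss",
).item()
print("CE loss recovered:", (zero_loss - reconstr_loss) / (zero_loss - orig_loss))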
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out = sae(cache[sae.cfg.hook_name])
def reconstr_hook(activations, hook, sae_out):
    return sae_out

def zero_abl_hook(mlp_out, hook):
    return torch.zeros_like(mlp_out)

hook_name = sae.cfg.hook_name
print("Orig", model(tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[
            (
                hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(hook_name, zero_abl_hook)],
    ).item(),
)
with model.hooks(
    fwd_hooks=[
        (
            hook_name,
            partial(reconstr_hook, sae_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
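# Which SAE features fire on the final token of the prompt? A quick look using
# the cache from the prompt run above (a minimal sketch; top-5 is arbitrary).
prompt_feature_acts = sae.encode(cache[hook_name])
top_vals, top_idxs = prompt_feature_acts[0, -1].topk(5)
print("top features at final position:", top_idxs.tolist())
print("their activations:", [round(v, 3) for v in top_vals.tolist()])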
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner

test_feature_idx_gpt = list(range(10)) + [14057]

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_name,
    features=test_feature_idx_gpt,
    minibatch_size_features=64,
    minibatch_size_tokens=256,
    verbose=True,
    device=device,
)

visualization_data_gpt = SaeVisRunner(feature_vis_config_gpt).run(
    encoder=sae,  # type: ignore
    model=model,
    tokens=token_dataset[:10000]["tokens"],  # type: ignore
)
# SaeVisData.create(
#     encoder=sae,
#     model=model,  # type: ignore
#     tokens=token_dataset[:10000]["tokens"],  # type: ignore
#     cfg=feature_vis_config_gpt,
# )
from sae_dashboard.data_writing_fns import save_feature_centric_vis

filename = "demo_feature_dashboards.html"
save_feature_centric_vis(sae_vis_data=visualization_data_gpt, filename=filename)
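# Open the saved dashboard with the helper defined earlier (it serves the file
# over HTTP when COLAB is True, and opens it in your browser otherwise).
display_vis_inline(filename)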
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# This function should open a Neuronpedia quick list for these features in your browser.
neuronpedia_quick_list = get_neuronpedia_quick_list(
    test_feature_idx_gpt,
    layer=sae.cfg.hook_layer,
    model="gpt2-small",
    dataset="res-jb",
    name="A quick list we made",
)
if COLAB:
    # If you're on Colab, click the link below
    print(neuronpedia_quick_list)