# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False)

# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

# device setup
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

from transformer_lens import HookedTransformer
from sae_lens import SAE

# Choose a layer you want to focus on.
# tiny-stories-1L-21M only has a single transformer block, so we use layer 0.
layer = 0

# get model
model = HookedTransformer.from_pretrained("tiny-stories-1L-21M", device=device)

# get the SAE for this layer
sae = SAE.load_from_pretrained(
    "sae_tiny-stories-1L-21M_blocks.0.hook_mlp_out_16384", device=device
)

# get hook point
hook_point = sae.cfg.hook_name
print(hook_point)

# quick check: which features fire on " Lily"?
sv_prompt = " Lily"
sv_logits, activation_cache = model.run_with_cache(sv_prompt, prepend_bos=True)
sv_feature_acts = sae.encode(activation_cache[hook_point])
print(torch.topk(sv_feature_acts, 3).indices.tolist())

# Generate
sv_prompt = " Lily"
sv_logits, activation_cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(activation_cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print(torch.topk(sv_feature_acts, 3))

# get the feature indices to use
print(torch.topk(sv_feature_acts, 3).indices.tolist())

# choose the vector -- pick a feature index from the list printed above
steering_vector = sae.W_dec[10284]

example_prompt = "Once upon a time"
coeff = 1000
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)


# apply steering vector when the model generates
def steering_hook(resid_pre, hook):
    if resid_pre.shape[1] == 1:
        return
    position = sae_out.shape[1]
    if steering_on:
        # using our steering vector and applying the coefficient
        resid_pre[:, : position - 1, :] += coeff * steering_vector


def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)
    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=50,
            do_sample=True,
            **kwargs,
        )
    return result


def run_generate(example_prompt):
    model.reset_hooks()
    editing_hooks = [(f"blocks.{layer}.hook_resid_post", steering_hook)]
    res = hooked_generate(
        [example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs
    )

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, 1:])
    print(("\n\n" + "-" * 80 + "\n\n").join(res_str))


steering_on = True
run_generate(example_prompt)

# evaluate features
import pandas as pd

# Let's start by getting the top 10 logits for each feature
projection_onto_unembed = sae.W_dec @ model.W_U

# get the top 10 logits
vals, inds = torch.topk(projection_onto_unembed, 10, dim=1)

# get 10 random features
random_indices = torch.randint(0, projection_onto_unembed.shape[0], (10,))

# Show the top 10 logits promoted by those features
top_10_logits_df = pd.DataFrame(
    [model.to_str_tokens(i) for i in inds[random_indices]],
    index=random_indices.tolist(),
).T
top_10_logits_df

# [7195, 5910, 2041]
# See the words associated with feature 7195 (should be "Golden")
top_10_associated_words = model.to_str_tokens(inds[7195])
print(top_10_associated_words)
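
# --- Optional follow-ups (a minimal sketch, not part of the original tutorial) ---
# Two quick sanity checks, reusing only objects defined above:
#   1. generate with steering switched off, to get an unsteered baseline to
#      compare the steered completions against;
#   2. look up the tokens most promoted by the steering feature itself
#      (10284, the feature whose decoder row we used as the steering vector).

# 1. Baseline generations: `run_generate` re-registers the hook on every call,
#    and the hook only adds the vector when `steering_on` is True, so toggling
#    the global flag is enough to disable the intervention.
steering_on = False
run_generate(example_prompt)

# 2. Top tokens promoted by the steering feature, via the same logit projection
#    computed above (`inds` holds the top-10 promoted token ids per feature).
print(model.to_str_tokens(inds[10284]))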