bababababooey's picture
Upload hotswap.py
eab6e7f verified
raw
history blame contribute delete
No virus
5.41 kB
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
import os
import json
from safetensors import safe_open
import re
# Apologies in advance for the messy, GPT-assisted code.
# This script should also work with the 70B/90B models if you change `cross_attention_layers`
# and the `total_layers` argument of `create_layer_mapping` accordingly,
# but I don't have enough dedicated RAM to test it and I don't feel like spinning up RunPod.
# Layer indices in Llama-3.2-11B-Vision occupied by cross-attention layers:
# every 5th layer starting at index 3 (3, 8, ..., 38).
cross_attention_layers = list(range(3, 40, 5))

# Text-only donor checkpoint whose weights are copied into the vision model's language layers.
# Alternative donor paths:
#   './models/mlabonne_Meta-Llama-3.1-8B-Instruct-abliterated'
#   './models/arcee-ai_Llama-3.1-SuperNova-Lite'
b8 = './models/v000000_L3-8B-Stheno-v3.2-abliterated'
print(b8)

# Base multimodal checkpoint; parameters not overwritten by the donor keep these values.
model_id = "./models/meta-llama_Llama-3.2-11B-Vision-Instruct"
def create_layer_mapping(total_layers=32, cross_attn_layers=None):
    """
    Creates a mapping from llama-3.1-8b layer indices to llama-3.2-11b layer indices.

    The 11B vision model interleaves cross-attention layers into its text decoder, so
    every donor layer X is pushed forward by the number of cross-attention slots that
    come before its new position.

    Args:
        total_layers: number of decoder layers in the donor (text-only) model.
        cross_attn_layers: layer indices, in the 11B numbering, occupied by
            cross-attention layers. Order does not matter. Defaults to the
            module-level ``cross_attention_layers`` list.

    Returns:
        dict mapping donor layer index -> 11B layer index.
    """
    if cross_attn_layers is None:
        cross_attn_layers = cross_attention_layers
    # Sorted and de-duplicated so the slots can be consumed in a single forward pass.
    cross_slots = sorted(set(cross_attn_layers))
    mapping = {}
    shift = 0
    next_slot = 0
    for src in range(total_layers):
        # `while`, not `if`: adjacent cross-attention indices must all be skipped,
        # otherwise a donor layer would land on a cross-attention slot.
        while next_slot < len(cross_slots) and src + shift == cross_slots[next_slot]:
            shift += 1
            next_slot += 1
        mapping[src] = src + shift
    return mapping
def load_sharded_state_dict(model_dir):
    """Load every tensor of the safetensors checkpoint in ``model_dir`` onto the CPU.

    If ``model.safetensors.index.json`` exists, its ``weight_map`` is used to find the
    shard holding each tensor, and each shard file is opened once. Otherwise the
    checkpoint is assumed to be a single ``model.safetensors`` file.

    Args:
        model_dir: directory containing the checkpoint.

    Returns:
        dict mapping parameter name -> tensor.
    """
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    if not os.path.isfile(index_file):
        # Unsharded checkpoint: every tensor lives in one file.
        single_file = os.path.join(model_dir, 'model.safetensors')
        with safe_open(single_file, framework="pt", device="cpu") as f:
            return {name: f.get_tensor(name) for name in f.keys()}

    with open(index_file, 'r', encoding='utf-8') as f:
        weight_map = json.load(f)['weight_map']

    # Group parameter names by shard so each shard file is opened only once.
    shard_to_params = {}
    for param_name, shard_file in weight_map.items():
        shard_to_params.setdefault(shard_file, []).append(param_name)

    state_dict = {}
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)
    return state_dict
def compare_model_states(model, new_state_dict):
    """Classify each entry of ``model.state_dict()`` against ``new_state_dict``.

    Returns a dict with three lists of parameter names, each in
    ``model.state_dict()`` order:
      'unchanged' - present in both and ``torch.equal`` once moved to CPU,
      'changed'   - present in both but not equal,
      'missing'   - absent from ``new_state_dict``.
    Names that exist only in ``new_state_dict`` are not reported.
    """
    buckets = {'unchanged': [], 'changed': [], 'missing': []}
    for key, tensor in model.state_dict().items():
        if key not in new_state_dict:
            bucket = 'missing'
        else:
            same = torch.equal(tensor.cpu(), new_state_dict[key].cpu())
            bucket = 'unchanged' if same else 'changed'
        buckets[bucket].append(key)
    return buckets
# Directory the merged checkpoint (weights and processor files) is written to.
OUTPUT_DIR = "llama-3.2-11b-vision-stheno-abliterated"
# Set to True to report which parameters the swap changed. This clones the whole
# model's state dict before loading, so it roughly doubles peak RAM.
VERIFY_CHANGES = False
# Decoder-layer parameter names, e.g. 'model.layers.12.mlp.up_proj.weight' -> ('12', 'mlp.up_proj.weight').
_LAYER_KEY_RE = re.compile(r'model\.layers\.(\d+)\.(.+)')
# Name of the embedding matrix inside the Llama 3.2 Vision checkpoint files.
_EMBED_KEY_3_2 = 'language_model.model.embed_tokens.weight'


def _load_single_tensor(model_dir, tensor_name):
    """Read one tensor from a sharded safetensors checkpoint, opening only its shard."""
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    with open(index_file, 'r', encoding='utf-8') as f:
        shard_file = json.load(f)['weight_map'][tensor_name]
    with safe_open(os.path.join(model_dir, shard_file), framework="pt", device="cpu") as f:
        return f.get_tensor(tensor_name)


def _extend_embeddings(donor_embed, target_embed):
    """Pad ``donor_embed`` with the rows of ``target_embed`` past the donor vocabulary.

    The result has ``target_embed``'s row count and ``donor_embed``'s dtype. Its first
    ``donor_embed.shape[0]`` rows are the donor's own embeddings.

    Raises:
        ValueError: if the hidden sizes differ or the donor vocabulary is larger.
    """
    donor_vocab, hidden = donor_embed.shape
    target_vocab, target_hidden = target_embed.shape
    if hidden != target_hidden:
        raise ValueError(f"hidden size mismatch: donor {hidden} vs target {target_hidden}")
    if donor_vocab > target_vocab:
        raise ValueError(f"donor vocab ({donor_vocab}) is larger than target vocab ({target_vocab})")
    extra_rows = target_embed[donor_vocab:].to(donor_embed.dtype)
    return torch.cat([donor_embed, extra_rows], dim=0)


def _remap_key(name, mapping, n_cross_layers):
    """Rename a donor (text-only Llama) parameter to its Llama 3.2 Vision language-model name."""
    match = _LAYER_KEY_RE.match(name)
    if match is None:
        # Non-layer parameters (e.g. 'model.embed_tokens.weight') only need the prefix.
        return 'language_model.' + name
    src_layer = int(match.group(1))
    # Layers beyond the mapping are shifted past all cross-attention layers.
    dst_layer = mapping.get(src_layer, src_layer + n_cross_layers)
    return f'language_model.model.layers.{dst_layer}.{match.group(2)}'


layer_mapping = create_layer_mapping()

# Only the embedding matrix is needed from Llama 3.2. Reading just that tensor avoids
# materialising the whole 11B checkpoint in RAM.
llama_3_2_embeddings = _load_single_tensor(model_id, _EMBED_KEY_3_2)  # expected [128264, 4096]

b8dict = load_sharded_state_dict(b8)
# The vision model's embedding table has extra rows past the donor vocabulary
# (expected 128256 -> 128264). Those extra rows are taken from Llama 3.2.
b8dict['model.embed_tokens.weight'] = _extend_embeddings(
    b8dict['model.embed_tokens.weight'], llama_3_2_embeddings
)
del llama_3_2_embeddings

# Rename donor parameters to the Llama 3.2 Vision layout, shifting decoder layers
# past the interleaved cross-attention layers.
n_cross = len(cross_attention_layers)
st8dict = {_remap_key(name, layer_mapping, n_cross): param for name, param in b8dict.items()}

# Write the renamed keys to a file for manual verification.
with open('st8dict.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(st8dict))

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
original_state = (
    {k: v.clone() for k, v in model.state_dict().items()} if VERIFY_CHANGES else None
)

# strict=False because the donor only supplies text weights. Vision-tower and
# cross-attention parameters keep their Llama 3.2 values.
load_result = model.load_state_dict(st8dict, strict=False)
# strict=False silently drops donor keys that match no model parameter. Report them,
# so that a key-layout mismatch cannot quietly save an unmodified model.
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} donor tensors were not loaded, e.g.",
        load_result.unexpected_keys[:5],
    )
    if len(load_result.unexpected_keys) == len(st8dict):
        raise RuntimeError(
            "None of the donor weights matched the model's parameter names; "
            "compare st8dict.txt with model.state_dict().keys()."
        )
b8dict.clear()
st8dict.clear()

if VERIFY_CHANGES:
    result = compare_model_states(model, original_state)
    print("Unchanged parameters:", len(result['unchanged']))
    print("Changed parameters:", len(result['changed']))
    print("Missing parameters:", len(result['missing']))
    with open('result.txt', 'w', encoding='utf-8') as f:
        f.write(json.dumps(result, indent=2))
    original_state = None

processor = AutoProcessor.from_pretrained(model_id)
model.save_pretrained(OUTPUT_DIR)
# Save the processor next to the weights so the output directory is loadable on its own.
processor.save_pretrained(OUTPUT_DIR)