import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
import os
import json
from safetensors import safe_open
import re

# apologies in advance for the rough GPT-assisted code

# this script should also work with 70B/90B if you change `cross_attention_layers` and
# `total_layers` accordingly, but I don't have enough RAM to test it and I don't feel
# like spinning up RunPod.

# Indices of the cross-attention blocks inside the 11B vision model's language tower
# (matches `cross_attention_layers` in its text config).
cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]

# Donor text model (pick one); its weights will replace the 11B language tower.
#b8 = './models/mlabonne_Meta-Llama-3.1-8B-Instruct-abliterated'
b8 = './models/v000000_L3-8B-Stheno-v3.2-abliterated'
#b8 = './models/arcee-ai_Llama-3.1-SuperNova-Lite'
print(b8)

# Base vision model; its vision encoder and cross-attention weights are kept as-is.
model_id = "./models/meta-llama_Llama-3.2-11B-Vision-Instruct"

def create_layer_mapping(total_layers=32, cross_attn_layers=cross_attention_layers):
    """

    Creates a mapping from llama-3.1-8b layer indices to llama-3.2-11b layer indices.

    """
    mapping = {}
    shift = 0
    next_cross_attn_idx = 0
    for X in range(total_layers):
        # Check if a cross-attention layer is inserted before this layer
        if next_cross_attn_idx < len(cross_attn_layers) and (X + shift) == cross_attn_layers[next_cross_attn_idx]:
            shift += 1
            next_cross_attn_idx += 1
        Y = X + shift
        mapping[X] = Y
    return mapping
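
# Optional sanity check (illustrative): with the first cross-attention block at index 3,
# self-attention layer 3 of the 8B model should land at index 4 in the merged layout,
# while layers 0-2 keep their indices.
assert create_layer_mapping()[0] == 0
assert create_layer_mapping()[3] == 4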

def load_sharded_state_dict(model_dir):
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    with open(index_file, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    state_dict = {}
    shard_to_params = {}
    for param_name, shard_file in weight_map.items():
        if shard_file not in shard_to_params:
            shard_to_params[shard_file] = []
        shard_to_params[shard_file].append(param_name)
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)
    return state_dict
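
# Note: this loader assumes a sharded checkpoint with a model.safetensors.index.json.
# A single-file checkpoint (just model.safetensors) would need a small fallback along
# these lines (sketch, untested here):
#
#   if not os.path.exists(index_file):
#       with safe_open(os.path.join(model_dir, 'model.safetensors'), framework="pt", device="cpu") as f:
#           return {name: f.get_tensor(name) for name in f.keys()}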

def compare_model_states(model, new_state_dict):
    current_state = model.state_dict()
    unchanged_params = []
    changed_params = []
    missing_params = []

    for name, param in current_state.items():
        if name not in new_state_dict:
            missing_params.append(name)
        elif torch.equal(param.cpu(), new_state_dict[name].cpu()):
            unchanged_params.append(name)
        else:
            changed_params.append(name)

    return {
        'unchanged': unchanged_params,
        'changed': changed_params,
        'missing': missing_params
    }


layer_mapping = create_layer_mapping()

# Load Llama 3.2 state dict
llama_3_2_state_dict = load_sharded_state_dict(model_id)

# Extract the embedding matrix from Llama 3.2
llama_3_2_embeddings = llama_3_2_state_dict['language_model.model.embed_tokens.weight']  # Shape: [128264, 4096]

llama_3_2_state_dict.clear()  # free memory; only the embeddings are needed from this state dict

b8dict = load_sharded_state_dict(b8)

embed_tokens_weight = b8dict['model.embed_tokens.weight']  # Shape: [128256, 4096]
new_vocab_size = 128264  # Llama 3.2 vision pads the vocab with 8 extra special tokens (e.g. <|image|>)
new_embed_tokens_weight = torch.zeros((new_vocab_size, 4096), dtype=embed_tokens_weight.dtype)

# Copy the donor model's embeddings for the shared vocabulary
new_embed_tokens_weight[:128256, :] = embed_tokens_weight
# Copy the embeddings for the additional special tokens from Llama 3.2
new_embed_tokens_weight[128256:, :] = llama_3_2_embeddings[128256:, :]
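
# Optional sanity check (illustrative): the padded table matches the 11B vocab size and
# still contains the donor embeddings in the shared range.
assert new_embed_tokens_weight.shape == (new_vocab_size, 4096)
assert torch.equal(new_embed_tokens_weight[:128256], embed_tokens_weight)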

b8dict['model.embed_tokens.weight'] = new_embed_tokens_weight


llama_3_2_embeddings = None

# Adjust Llama 3.1 parameter names to match Llama 3.2 language model
st8dict = {}
for name, param in b8dict.items():
    # Prefix non-layer parameters with 'language_model.'
    if not re.match(r'model\.layers\.\d+\.', name):
        new_name = 'language_model.' + name
    else:
        # Extract the layer index X from 'model.layers.X.'
        match = re.match(r'model\.layers\.(\d+)\.(.+)', name)
        if match:
            X = int(match.group(1))
            suffix = match.group(2)
            # Get the corresponding Y in llama-3.2-11b
            Y = layer_mapping.get(X, X + len(cross_attention_layers))
            new_name = f'language_model.model.layers.{Y}.{suffix}'
        else:
            # If the pattern doesn't match, just prefix with 'language_model.'
            new_name = 'language_model.' + name
    st8dict[new_name] = param
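
# Example of the renaming above (illustrative):
#   model.layers.3.self_attn.q_proj.weight  ->  language_model.model.layers.4.self_attn.q_proj.weight
#   model.embed_tokens.weight               ->  language_model.model.embed_tokens.weight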

#write st8dict keys to file for verification
with open('st8dict.txt', 'w') as f:
    f.write('\n'.join(st8dict.keys()))


# Load the stock 11B vision model on CPU; its language-model weights are replaced below.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
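
# Optional check (illustrative): list any remapped keys that the 11B model does not
# actually have, since load_state_dict(strict=False) below silently skips them.
target_keys = set(model.state_dict().keys())
unexpected = [k for k in st8dict if k not in target_keys]
if unexpected:
    print(f"{len(unexpected)} remapped keys not found in the target model, e.g. {unexpected[:5]}")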

#original_state = {k: v.clone() for k, v in model.state_dict().items()}

# strict=False: the vision tower and cross-attention weights are absent from st8dict,
# so they keep the values loaded from the 11B checkpoint.
model.load_state_dict(st8dict, strict=False)

b8dict.clear()
st8dict.clear()


'''
result = compare_model_states(model, original_state)

print("Unchanged parameters:", len(result['unchanged']))
print("Changed parameters:", len(result['changed']))
print("Missing parameters:", len(result['missing']))

# write result to file
with open('result.txt', 'w') as f:
    f.write(json.dumps(result, indent=2))
'''


# Load the processor (tokenizer + image processor) from the base vision model
processor = AutoProcessor.from_pretrained(model_id)


model.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
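
# Note: the processor is loaded above but never written out; to make the output directory
# directly loadable with AutoProcessor, you could also save it alongside the model
# (optional, untested here):
#
#   processor.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")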