maldv's picture
Upload folder using huggingface_hub
b59223f verified
raw
history blame
1.36 kB
# ztrain/model.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
from collections import defaultdict
import re
def generate_merge_group(group_data: list, parents: "list[int] | None" = None):
    """Walk a nested list and yield every non-list leaf with its index path.

    Args:
        group_data: Arbitrarily nested list; any element that is not a
            ``list`` is treated as a leaf.
        parents: Index path of ``group_data`` within the outermost list.
            Callers normally omit it; it is filled in during recursion.

    Yields:
        ``(leaf, path)`` pairs, where ``path`` is a new list of integer
        indices leading from the outermost list to the leaf (the last entry
        is the leaf's own position in its immediate parent).
    """
    # None instead of a mutable ``[]`` default, so no list object is shared
    # between calls.
    prefix = [] if parents is None else parents
    for index, item in enumerate(group_data):
        # Concatenation builds a fresh list, so yielded paths never alias.
        path = prefix + [index]
        if isinstance(item, list):
            yield from generate_merge_group(item, path)
        else:
            yield item, path
def merge_groups(group_data: list):
    """Group the leaves of a nested list by the index path of their parent.

    Returns:
        A ``defaultdict(list)`` keyed by the tuple of indices that locates
        the list directly containing each leaf; each value holds that
        list's non-list elements in iteration order.
    """
    grouped = defaultdict(list)
    for leaf, path in generate_merge_group(group_data):
        # Drop the leaf's own position; what remains identifies its parent.
        *parent_path, _ = path
        grouped[tuple(parent_path)].append(leaf)
    return grouped
# Patterns for per-layer parameter names, compiled once at import time.
# The literal dots in the prefix are escaped so they only match ".".
_LAYER_KEY_WITH_SUBMODULE = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)\.(.+)")
_LAYER_KEY_DIRECT = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)")

# Parameter names that sit outside the numbered layers, mapped to the
# module name reported for them.
_TOP_LEVEL_KEYS = {
    "model.norm.weight": "norm",
    "model.embed_tokens.weight": "embed_tokens",
    "lm_head.weight": "lm_head",
}


def get_layer_type(k: str) -> tuple[int, str, str, str]:
    """Split a parameter key into (layer index, module, submodule, name).

    Args:
        k: Parameter key, e.g. ``model.layers.3.self_attn.q_proj.weight``.

    Returns:
        A 4-tuple ``(layer, module, submodule, name)``:

        * ``model.layers.<n>.<a>.<b>.<c>`` -> ``(n, a, b, c)``
        * ``model.layers.<n>.<a>.<b>``     -> ``(n, a, "", b)``
        * ``model.norm.weight``, ``model.embed_tokens.weight`` and
          ``lm_head.weight`` -> ``(-1, <module>, "", "weight")``
        * anything else -> ``(-1, "unknown", "unknown", "unknown")`` after
          printing a notice to stdout.
    """
    m = _LAYER_KEY_WITH_SUBMODULE.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), m.group(3), m.group(4)

    # The shorter pattern must actually be matched here; previously the
    # stale (None) result of the first match was re-tested, which made this
    # branch unreachable.
    m = _LAYER_KEY_DIRECT.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), "", m.group(3)

    module = _TOP_LEVEL_KEYS.get(k)
    if module is not None:
        return -1, module, "", "weight"

    print(f"Unknown key {k}")
    return -1, "unknown", "unknown", "unknown"