SmolLM and mergekit-moe: is lm_head missing?
Hello,
I am trying to play with mergekit-moe and SmolLM, but I am facing a problem that I can't solve. I am not sure whether the problem is SmolLM-related or mergekit-moe-related.
Using a dummy merging `config.yaml` such as
```yaml
base_model: HuggingFaceTB/SmolLM-135M
gate_mode: random
dtype: bfloat16
experts:
  - source_model: HuggingFaceTB/SmolLM-135M
  - source_model: HuggingFaceTB/SmolLM-135M
```
and running the command

```sh
mergekit-moe config.yaml merge --copy-tokenizer
```
I am getting the following error:
```
Fetching 7 files: 100%|████████████████████████████████████████| 7/7 [00:00<00:00, 70407.98it/s]
Fetching 7 files: 100%|████████████████████████████████████████| 7/7 [00:00<00:00, 8063.75it/s]
Fetching 7 files: 100%|████████████████████████████████████████| 7/7 [00:00<00:00, 18213.48it/s]
Warm up loaders: 100%|████████████████████████████████████████| 3/3 [00:00<00:00, 4.13it/s]
Weights: 100%|████████████████████████████████████████| 272/273 [00:00<00:00, 2120.41it/s]
Traceback (most recent call last):
  File "/home/ubuntu/merging/.venv/bin/mergekit-moe", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/merging/.venv/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/home/ubuntu/merging/.venv/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/home/ubuntu/merging/.venv/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ubuntu/merging/.venv/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/home/ubuntu/merging/mergekit/mergekit/options.py", line 82, in wrapper
    f(*args, **kwargs)
  File "/home/ubuntu/merging/mergekit/mergekit/scripts/moe.py", line 211, in main
    build(
  File "/home/ubuntu/merging/mergekit/mergekit/scripts/moe.py", line 82, in build
    out_arch.write_model(
  File "/home/ubuntu/merging/mergekit/mergekit/moe/mixtral.py", line 160, in write_model
    tensor = base_loader.get_tensor(
  File "/home/ubuntu/merging/mergekit/mergekit/io/lazy_tensor_loader.py", line 127, in get_tensor
    raise KeyError(key)
KeyError: 'lm_head.weight'
```
Somehow `lm_head.weight` seems to be missing. But the layer is there when I load SmolLM and inspect it.
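A minimal sketch of the inspection, assuming only `transformers` is installed:

```python
from transformers import AutoModelForCausalLM

# load SmolLM-135M and print its module tree
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
print(model)
```

which prints: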
```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
```
indicating that `lm_head` is right where it should be. However, when I inspect the checkpoint in the "Files and versions" tab on the Hugging Face Hub, `lm_head` does not appear, as the following screenshot shows: somehow `lm_head.weight` really is missing from the stored weights...
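The same can be checked programmatically; a minimal sketch, assuming `huggingface_hub` and `safetensors` are installed:

```python
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# download the checkpoint and list the tensor names stored on disk
path = hf_hub_download("HuggingFaceTB/SmolLM-135M", "model.safetensors")
with safe_open(path, framework="pt") as f:
    keys = list(f.keys())

print(any(k.startswith("lm_head") for k in keys))  # False: no lm_head tensor in the file
```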
Any thoughts?
Hey, it's due to the use of the `tie_word_embeddings=true` parameter: the `lm_head` weight is tied to the `embed_tokens` weight (they are one and the same tensor), so no separate `lm_head.weight` is saved in the checkpoint. You probably have to replace `AutoModel` with `AutoModelForCausalLM` somewhere in mergekit-moe to make it work.
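A quick sanity check of the tying, as a sketch assuming `transformers` is installed:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")

print(model.config.tie_word_embeddings)  # True
# tied weights: lm_head.weight and embed_tokens.weight are the same Parameter
print(model.lm_head.weight is model.model.embed_tokens.weight)  # True
```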
Hello, yes that was it!
Pointing toward `model.embed_tokens.weight` when asked for `lm_head.weight` solves the merging problem.
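For reference, the workaround boils down to a fallback along these lines at the `get_tensor` call in `mergekit/moe/mixtral.py` (a hypothetical sketch, not the exact patch):

```python
# sketch: when the checkpoint stores no separate lm_head.weight
# (tie_word_embeddings=true), fall back to the tied embedding weight
try:
    tensor = base_loader.get_tensor("lm_head.weight")
except KeyError:
    tensor = base_loader.get_tensor("model.embed_tokens.weight")
```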
mergekit has JSON files that define the architectures of the models it supports; I will add one for SmolLM to avoid interfering with the source code.
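For anyone curious, those definitions live under `mergekit/_data/architectures/`; an abridged, hypothetical sketch of what a SmolLM (Llama-style) entry could look like, with `lm_head.weight` marked optional and aliased to the embedding weight (field names should be double-checked against the real Llama definition in that directory):

```json
{
  "model_type": "llama",
  "architectures": ["LlamaForCausalLM"],
  "pre_weights": [
    {"name": "model.embed_tokens.weight", "is_embed": true}
  ],
  "post_weights": [
    {"name": "model.norm.weight"},
    {"name": "lm_head.weight", "is_embed": true, "optional": true, "aliases": ["model.embed_tokens.weight"]}
  ],
  "num_layers_config_key": "num_hidden_layers"
}
```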
Thanks for the help!