---
license: cc-by-nc-4.0
---

## License

This model is released under a non-commercial license (CC BY-NC 4.0).

## Chat Vector

```
Tora-7B-v0.2 = NTQAI/chatntq-ja-7b-v1.0 + (NousResearch/Hermes-2-Pro-Mistral-7B - mistralai/Mistral-7B-v0.1)
```
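To make the arithmetic concrete, here is a minimal sketch of the same formula applied to plain Python lists. The function name `add_chat_vector` and all numbers are hypothetical and chosen only for illustration; the real models hold tensors with billions of parameters.

```python
# Toy illustration of the chat-vector arithmetic on plain Python lists.
# All names and values are made up for illustration only.

def add_chat_vector(target, inst, base):
    """Return target + (inst - base), computed element by element."""
    return [t + (i - b) for t, i, b in zip(target, inst, base)]


base_weights = [1.0, 2.0, 3.0]       # stands in for mistralai/Mistral-7B-v0.1
inst_weights = [1.5, 2.0, 2.0]       # stands in for NousResearch/Hermes-2-Pro-Mistral-7B
target_weights = [10.0, 20.0, 30.0]  # stands in for NTQAI/chatntq-ja-7b-v1.0

merged_weights = add_chat_vector(target_weights, inst_weights, base_weights)
print(merged_weights)  # [10.5, 20.0, 29.0]
```

The difference `inst - base` captures what instruction tuning changed, and adding it to the target shifts the Japanese model by the same amount without any further training.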

## Implementation

The model was built with the code below, which is based on @jovyan's implementation.

```python
import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
):
    # English base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Instruction-tuned model derived from the base model
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Japanese continued-pretraining model that receives the chat vector
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    # Print parameter names and shapes of the English base model
    for k, v in base_model.state_dict().items():
        print(k, v.shape)

    # Print parameter names and shapes of the Japanese continued-pretraining model
    for k, v in target_model.state_dict().items():
        print(k, v.shape)

    # Fetch the state dicts once instead of rebuilding them for every parameter
    base_state = base_model.state_dict()
    inst_state = inst_model.state_dict()

    for k, v in target_model.state_dict().items():
        # Skip the layers excluded by the caller, as well as all layernorm weights
        if (k in skip_layers) or ("layernorm" in k):
            continue
        chat_vector = inst_state[k] - base_state[k]
        new_v = v + chat_vector.to(v.device)
        # state_dict() tensors share storage with the model, so this updates it in place
        v.copy_(new_v)

    target_model.save_pretrained("./chat_model")


if __name__ == '__main__':

    base_model_name = "mistralai/Mistral-7B-v0.1"
    inst_model_name = "NousResearch/Hermes-2-Pro-Mistral-7B"
    target_model_name = "NTQAI/chatntq-ja-7b-v1.0"

    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers
    )

```
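The skip condition in the loop above can be checked in isolation, without loading any model. The following is a minimal sketch: the helper name `is_skipped` is hypothetical, and the parameter names are illustrative assumptions modeled on Mistral-style naming, not output captured from the actual checkpoints.

```python
def is_skipped(name, skip_layers):
    """Mirror the loop's filter: exact-name match or 'layernorm' substring."""
    return (name in skip_layers) or ("layernorm" in name)


skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

# Illustrative parameter names (assumed, not read from a real checkpoint).
example_names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.post_attention_layernorm.weight",
    "model.norm.weight",
    "lm_head.weight",
]

# Names that would receive the chat vector under this filter.
updated = [n for n in example_names if not is_skipped(n, skip_layers)]
print(updated)
```

With these assumed names, a key such as `model.norm.weight` does not contain the substring `layernorm` and would therefore still receive the chat vector. The names printed by the script show which parameters are actually affected.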

## Benchmark (Japanese MT bench)

|model|category|score|setting|
|:---|:---|:---|:---|
|Tora-7B-v0.2|Writing|3.8|single-turn|
|Tora-7B-v0.2|Roleplay|7.1|single-turn|
|Tora-7B-v0.2|Reasoning|6.3|single-turn|
|Tora-7B-v0.2|Math|3.0|single-turn|
|Tora-7B-v0.2|Coding|2.2|single-turn|
|Tora-7B-v0.2|Extraction|6.6|single-turn|
|Tora-7B-v0.2|STEM|7.2|single-turn|
|Tora-7B-v0.2|Humanities|8.2|single-turn|

![image/png](https://cdn-uploads.huggingface.co/production/uploads/651e3f30ca333f3c8df692b8/_CBS90NRrYUMXzsFC1LIV.png)


## Acknowledgements

Many thanks to @jovyan for writing the article on Chat Vector.

## References

[Chat Vectorを使って日本語LLMをチャットモデルに改造する (Converting a Japanese LLM into a chat model with Chat Vector)](https://qiita.com/jovyan/items/ee6affa5ee5bdaada6b4)