---
language:
- ja
- en
license: mit
library_name: transformers
---

![SpiralAI Spiral-RetNet-3b-base](logo.png)

# SpiralAI Spiral-RetNet-3b-base

We pre-trained a 3B-parameter RetNet (https://arxiv.org/abs/2307.08621) model from scratch on a mixed dataset of Japanese and English.
This model is released primarily for basic research on the retention mechanism.

# Model Description

- **Developed by:** [SpiralAI](https://go-spiral.ai/)
- **Model type:** The `SpiralAI Spiral-RetNet-3b-base` is a language model equipped with a retention mechanism. It uses the `cyberagent/calm2-7b-chat` tokenizer.
- **Languages:** Japanese, English.
- **License:** MIT
- **Training:** Trained on 80B tokens.
- **Context Length:** 2,048 tokens.

# Installation

```bash
pip install transformers==4.38  # The top_k_top_p_filtering feature has been removed in later versions.

```

Clone the repository from **`https://github.com/syncdoth/RetNet`** and follow the *Getting Started* guide provided there.

Example:

```bash
git clone https://github.com/syncdoth/RetNet.git
pip install torch transformers timm
cd RetNet

```

# Usage

```python
from transformers import AutoTokenizer

from retnet.modeling_retnet import RetNetForCausalLM

tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
tokenizer.pad_token = tokenizer.eos_token

model = RetNetForCausalLM.from_pretrained(
    "Spiral-AI/Spiral-RetNet-3b-base", device_map="auto"
)
inputs = tokenizer("最近、秋葉原周辺で興味深い", return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)
generated = model.generate(
    input_ids,
    max_new_tokens=32,
    repetition_penalty=1.2,  # recommended for a model of this size (3B parameters)
)
print(tokenizer.decode(generated[0]))

```

## Examples
```
input: 最近、秋葉原周辺で興味深い
output: お店がいくつかあります。
1. 神田カレー街「カレーハウスCoCo壱番屋」
2016年7月3日オープン
```

```
input: 近年、AI技術の進歩によって
output: 人間の仕事が奪われるのではないかという懸念がある。
しかしながら、AIは人間に取って代わるものではなく、「人間がコンピュータに仕事をさせる」という考え方
```

```
input: When I was a child, I used to play with
output: 3-D glasses. They were so much fun!
I have been playing around in the world of video games for years now and it is amazing how
```

# Basic study

## Visualization of the retention mechanism

![retention](retention.gif)
This visualization shows the retention mechanism in action. The token being generated is marked with `*`.
The blue bars show how the tokens are weighted during generation.

Using the mathematical equivalence between the "recurrent mode" and the "parallel mode", we apply a visualization technique similar to the one used for attention mechanisms: the absolute values of the inner products between queries and keys are summed over all heads.
The result shown here is for the last layer.
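
The following is a minimal NumPy sketch of the idea, not the code used to produce the figure. It assumes the simplified retention formulation from the RetNet paper, where the parallel form is `(Q K^T * D) V` with `D[n, m] = gamma**(n - m)` for `n >= m`, and the recurrent form is `S_n = gamma * S_{n-1} + k_n^T v_n`, `o_n = q_n S_n`. The rotation applied to queries and keys, the scaling, and the group normalization of the real model are omitted. The function names are hypothetical, and whether the decay mask is included in the plotted weights is an assumption.

```python
import numpy as np


def decay_mask(seq_len, gamma):
    """Causal decay matrix D with D[n, m] = gamma**(n - m) for n >= m, else 0."""
    idx = np.arange(seq_len)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)


def retention_parallel(q, k, v, gamma):
    """Parallel mode for one head: (Q K^T * D) V. q, k: (seq_len, d_k); v: (seq_len, d_v)."""
    return ((q @ k.T) * decay_mask(len(q), gamma)) @ v


def retention_recurrent(q, k, v, gamma):
    """Recurrent mode for one head: S_n = gamma * S_{n-1} + k_n^T v_n, o_n = q_n S_n."""
    state = np.zeros((q.shape[1], v.shape[1]))
    outputs = []
    for n in range(len(q)):
        state = gamma * state + np.outer(k[n], v[n])
        outputs.append(q[n] @ state)
    return np.stack(outputs)


def retention_weights(q, k, gammas):
    """Sum |Q K^T * D| over heads (assumption: decay mask included).

    q, k: (heads, seq_len, d_k); gammas: (heads,). Returns (seq_len, seq_len),
    where row n holds the weight of every token m <= n when token n is produced.
    """
    heads, seq_len, _ = q.shape
    total = np.zeros((seq_len, seq_len))
    for h in range(heads):
        scores = (q[h] @ k[h].T) * decay_mask(seq_len, gammas[h])
        total += np.abs(scores)
    return total
```

Because both modes produce the same outputs, the token-to-token scores that the parallel mode makes explicit can be read as the weights the recurrent mode applies implicitly through its state.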

## Test loss comparison

We compared the test loss of `Spiral-AI/Spiral-RetNet-3b-base` and `cyberagent/open-calm-3b` at different sequence lengths.
The test dataset consists of the first 100 examples from `wikipedia-ja`.

![test_loss](loss_comparison.png)

Key findings are:

- The test loss of `Spiral-AI/Spiral-RetNet-3b-base` is as low as that of `cyberagent/open-calm-3b`, showing the effectiveness of the retention mechanism.
- The explosion of test loss is suppressed in `Spiral-AI/Spiral-RetNet-3b-base` when the context length exceeds 2,048 tokens (the maximum context length of the training data; note that `cyberagent/open-calm-3b` was trained with the same context length).
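
How the loss was measured is not described here, so the following is only a sketch of one plausible way to do it. It assumes the test loss is the mean next-token cross-entropy and that, for a causal model, the loss at a shorter length can be read off the first positions of a longer sequence. The function names and the `lengths` argument are hypothetical; `logits` stands for the per-position output of either model on one test example.

```python
import numpy as np


def next_token_losses(logits, token_ids):
    """Cross-entropy of predicting token t+1 from position t.

    logits: (seq_len, vocab_size); token_ids: (seq_len,). Returns (seq_len - 1,).
    """
    shifted = logits[:-1]
    targets = token_ids[1:]
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = shifted - shifted.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets]


def loss_by_length(logits, token_ids, lengths):
    """Mean loss over the first n tokens, for each n in lengths that fits the example."""
    losses = next_token_losses(logits, token_ids)
    return {
        n: float(losses[: n - 1].mean())
        for n in lengths
        if 2 <= n <= len(token_ids)
    }
```

Averaging the per-length values over all test examples gives one curve per model. Lengths beyond 2,048 probe the extrapolation behavior described in the second finding.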

# Training Datasets

- [izumi-lab/cc100-ja-filter-ja-normal](https://huggingface.co/datasets/izumi-lab/cc100-ja-filter-ja-normal) (Japanese)
- [izumi-lab/wikipedia-ja-20230720](https://huggingface.co/datasets/izumi-lab/wikipedia-ja-20230720) (Japanese)
- [wikipedia](https://huggingface.co/datasets/wikipedia/tree/main/data/20220301.en) (English)
- [uonlp/CulturaX](https://huggingface.co/datasets/uonlp/CulturaX) (English, Japanese)

# Limitations

This model is designed for broad applicability, but it may not fully meet the specific needs or contexts of all uses.
Pre-training data may contain inappropriate content, which could be reflected in the texts generated by the model. Therefore, when using this model, it is important to carefully review its output and avoid situations where it might cause discomfort or harm to individuals or groups.

There are no specific restrictions on commercial use, but users are responsible for addressing any ethical or legal issues that may arise in connection with the use of the model.