---
license: mit
tags:
- text generation
- RAG
- baichuan2
---

This model is a 7B Chinese version of [Self-RAG](https://huggingface.co/selfrag/selfrag_llama2_7b).

It is trained on Baichuan2-7B-Chat with a sample of [belle](https://github.com/LianjiaTech/BELLE) SFT data, interleaved with passages from zhwiki. The reflection tokens are aligned with the original (English) version, so the usage is the same. Hope you enjoy it.
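
For reference, the reflection-token vocabulary below is the standard Self-RAG set used by the original English release; the `parse_utility` helper is a hypothetical sketch for pulling the utility score out of a generation, not part of this model's code.

```python
import re

# Standard Self-RAG reflection tokens (shared with the original English model):
RETRIEVAL_TOKENS = ["[Retrieval]", "[No Retrieval]", "[Continue to Use Evidence]"]
RELEVANCE_TOKENS = ["[Relevant]", "[Irrelevant]"]
SUPPORT_TOKENS = ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]"]
UTILITY_TOKENS = [f"[Utility:{i}]" for i in range(1, 6)]

def parse_utility(output: str):
    """Hypothetical helper: extract the [Utility:N] score from a generation, if present."""
    match = re.search(r"\[Utility:(\d)\]", output)
    return int(match.group(1)) if match else None
```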

### Usage
I observed some output errors when using vLLM to accelerate generation, and I am not sure whether they come from a precision issue or from vLLM's implementation. The example below therefore uses the standard `generate` method from transformers.
```python
import os, torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(YOUR_TOKENIZER_PATH)
model = AutoModelForCausalLM.from_pretrained(
        YOUR_MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

### (optional) set your retriever if necessary;
### setup_retriever is a placeholder for your own retrieval function
# retriever = setup_retriever(YOUR_RETRIEVER_PATH)


def format_prompt(input, paragraph=None):
    prompt = "### Instruction:\n{0}\n\n### Response:".format(input)
    if paragraph is not None:
        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
    return prompt


while True:
    query = input("[Human]: ")
    prompt = format_prompt(query)
    sequences = model.generate(
        **tokenizer(prompt, return_tensors='pt').to(model.device),
        do_sample=False,
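        # note: with do_sample=False the call runs deterministic beam search, so
        # temperature (and the commented-out top_k / top_p) only matter when do_sample=True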
        num_beams=5,
        # top_k=10,
        # top_p=0.8,
        temperature=0.9,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=1024,
        min_new_tokens=1,
        repetition_penalty=1.5,
    )
    for seq in sequences:
        print(f"[Model]: {tokenizer.decode(seq, skip_special_tokens=False)}")
        print("-"*50)
    print("="*50)

# query_1 = "你好呀"  # "Hi there"
# Model prediction: [No Retrieval] 你好!有什么我可以帮你解答的问题吗? [Utility:5] </s>
#   ("Hello! Is there anything I can help you with?")
# query_2 = "故宫三大殿是哪些?"  # "What are the three great halls of the Forbidden City?"
# Model prediction: [Retrieval] <paragraph> ... (this query requires factual grounding, call a retriever) </paragraph> [Relevant] 太和殿、中和殿、保和殿 [Utility:5] </s>
#   ("The Hall of Supreme Harmony, the Hall of Central Harmony, and the Hall of Preserving Harmony")
```
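
When the model emits `[Retrieval]`, you can retrieve a passage, pass it back through `format_prompt`'s `paragraph` argument, and generate again. Below is a minimal sketch of that two-pass flow; it assumes a `retriever` object with a hypothetical `search(query, k)` method returning passage strings, since the actual retriever interface is up to you.

```python
def generate_with_retrieval(query, retriever, top_k=1):
    """Two-pass Self-RAG style generation: retrieve only when the model asks for it."""
    # First pass: let the model decide whether retrieval is needed.
    prompt = format_prompt(query)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, do_sample=False, num_beams=5, max_new_tokens=1024)
    text = tokenizer.decode(out[0], skip_special_tokens=False)
    if "[Retrieval]" not in text:
        return text

    # Second pass: the model asked for evidence, so ground it on a retrieved passage.
    # retriever.search is a hypothetical interface returning a list of passage strings.
    passages = retriever.search(query, k=top_k)
    grounded = format_prompt(query, paragraph=passages[0])
    inputs = tokenizer(grounded, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, do_sample=False, num_beams=5, max_new_tokens=1024)
    return tokenizer.decode(out[0], skip_special_tokens=False)
```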

### Data
The data used to train the model is also available ([FINAL_OUTPUT_4w.jsonl](https://huggingface.co/Aman/selfrag-zh_baichuan2_7b_chat/blob/main/FINAL_OUTPUT_4w.jsonl)). It was constructed from [Belle](https://github.com/LianjiaTech/BELLE/tree/main/data/1.5M) SFT data and Chinese Wikipedia documents.
Hope you enjoy it!
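
If you want to inspect the data before using it, each line of the file is one JSON record. A minimal sketch (the local path is a placeholder, and the field layout is not documented here, so the snippet just prints the keys of the first record):

```python
import json

DATA_PATH = "FINAL_OUTPUT_4w.jsonl"  # placeholder: wherever you downloaded the file

with open(DATA_PATH, encoding="utf-8") as f:
    first = json.loads(next(f))    # one JSON object per line
    print(sorted(first.keys()))    # inspect the schema before using the data
```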