File size: 2,766 Bytes
e55bd08
 
 
 
f60e921
e55bd08
 
 
 
 
 
 
 
 
 
 
 
 
 
f60e921
78da2d6
f60e921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e55bd08
f60e921
131a07a
 
 
 
 
 
 
e55bd08
131a07a
f60e921
131a07a
 
f60e921
131a07a
f60e921
15d1015
 
f60e921
 
e55bd08
f60e921
 
 
 
e55bd08
 
78da2d6
e55bd08
fd8b009
78da2d6
e55bd08
78da2d6
 
 
 
 
 
 
 
e55bd08
 
0474700
78da2d6
e55bd08
e7e3b25
e55bd08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import spaces
from threading import Thread
from typing import Iterator

# Load the tokenizer and model once at import time; every request reuses them.
model_name = "Magpie-Align/MagpieLM-4B-Chat-v0.1"

device = "cuda"  # the device to load the model onto
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# torch_dtype="auto" loads the weights in the checkpoint's own dtype rather than
# forcing float32. Module.to() returns the module itself, so the device move is
# chained directly onto the load.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").to(device)

# Prompts longer than this many tokens are truncated from the left in respond().
MAX_INPUT_TOKEN_LENGTH = 4096  # You may need to adjust this value

@spaces.GPU
def respond(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    Args:
        message: The new user turn.
        chat_history: Prior ``(user, assistant)`` turn pairs, as supplied by
            ``gr.ChatInterface``.
        system_prompt: Optional system message; omitted when empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling probability mass forwarded to ``model.generate``.
        top_k: Number of highest-probability tokens kept during sampling.
        repetition_penalty: Penalty forwarded to ``model.generate``.

    Yields:
        The accumulated reply text so far, growing with each decoded chunk.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant-turn header, so the model
    # answers as the assistant instead of continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest context (possibly
        # including the system prompt) is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt=True streams back only newly generated text. The timeout bounds
    # how long the loop below waits for each chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        # UI sliders may deliver a float; HF's top-k warper requires an int.
        top_k=int(top_k),
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # generate() blocks until completion, so run it in a worker thread and
    # consume the streamer on this one.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generate() finishes; reap the thread.
    t.join()

# Chat UI. Gradio passes the values of ``additional_inputs`` positionally to
# ``respond`` after (message, chat_history). This list must therefore mirror the
# order of respond's remaining parameters: system_prompt, max_new_tokens,
# temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are Magpie, a helpful AI assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        # Without this slider, the repetition-penalty value below would be bound
        # to respond's ``top_k`` parameter.
        gr.Slider(minimum=1, maximum=1000, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=0.5, maximum=1.5, value=1.0, step=0.1, label="Repetition Penalty"),
    ],
)


if __name__ == "__main__":
    # queue() returns the Blocks app itself, so enabling the request queue and
    # launching (with a public share link requested) chain in one expression.
    demo.queue().launch(share=True)