import os
import gradio as gr
from openai import OpenAI
import jinja2
from transformers import AutoTokenizer

# OpenAI-compatible client pointed at Hyperbolic's inference API
client = OpenAI(
    base_url="https://api.hyperbolic.xyz/v1",
    api_key=os.environ["HYPERBOLIC_API_KEY"],
)

# The tokenizers library complains about parallelism after Gradio forks worker processes; disable it up front.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use an ungated mirror of the Llama 3.1 tokenizer to avoid the access restrictions on the official repo.
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated")

# Preset initial prompts (name -> [model label, prompt text]); not currently wired into the UI
initial_prompts = {
    "Default": ["405B", """A chat between a person and the Llama 3.1 405B base model.

"""],
}

# Prompt templates: ChatML-style and a plain "<role>: message" transcript
chatml_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""
chat_template = """{% for message in messages %}{{'<' + message['role'] + '>: ' + message['content'] + '\n'}}{% endfor %}"""
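# For messages=[{"role": "user", "content": "hi"}], the ChatML template renders
# "<|im_start|>user\nhi<|im_end|>\n" and the plain template renders "<user>: hi\n".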

def format_chat(messages, use_chatml=False):
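    """Render a list of {"role", "content"} dicts with either the ChatML template or the plain "<role>: content" template."""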
    if use_chatml:
        template = jinja2.Template(chatml_template)
    else:
        template = jinja2.Template(chat_template)
    formatted = template.render(messages=messages)
    return formatted

def count_tokens(text):
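    """Approximate token count using the Llama tokenizer (may include special tokens it adds, e.g. BOS)."""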
    return len(tokenizer.encode(text))

def limit_history(initial_prompt, history, new_message, max_tokens):
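    """Keep only the most recent (user, assistant) turns that fit within max_tokens.

    Walks the history from newest to oldest, accumulating token counts on top of the
    initial prompt and the new message, and stops before the budget would be exceeded.
    Raises ValueError if the initial prompt plus the new message alone overflow the budget.
    Counts are taken on the raw text, without template wrapping (see TODO below).
    """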
    limited_history = []

    token_count = count_tokens(new_message) + count_tokens(initial_prompt)
    if token_count > max_tokens:
        raise ValueError("message too large for context window")

    for user_msg, assistant_msg in reversed(history):
        # TODO: apply ChatML wrapping here for a more accurate token count?
        user_tokens = count_tokens(user_msg)
        assistant_tokens = count_tokens(assistant_msg)
        if token_count + user_tokens + assistant_tokens > max_tokens:
            break
        token_count += user_tokens + assistant_tokens
        limited_history.insert(0, (user_msg, assistant_msg))
    return limited_history


def generate_response(message, history, initial_prompt, user_role, assistant_role, use_chatml):
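    """Build a raw completion prompt from the trimmed history and query the base model.

    Token budget: 8192 context - 1000 response - 300 slop = 6892 tokens for the
    initial prompt, history, and new message. The prompt is rendered either as
    ChatML (with the initial prompt as a system block) or as a plain "<role>:"
    transcript, then sent to the completions (not chat) endpoint, stopping at the
    next role tag or <|im_end|>.
    """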
    context_length = 8192
    response_length = 1000
    slop_length = 300  # headroom for ChatML wrapping, role tags, etc. -- TODO: count this precisely

    # trim history based on token count
    history_tokens = context_length - response_length - slop_length
    limited_history = limit_history(initial_prompt, history, message, max_tokens=history_tokens)

    # Flatten the (user, assistant) pairs, append the new message, and assign alternating roles
    chat_history = [{"role": user_role if i % 2 == 0 else assistant_role, "content": m}
                for i, m in enumerate([item for sublist in limited_history for item in sublist] + [message])]
    formatted_input = format_chat(chat_history, use_chatml)

    if use_chatml:
        full_prompt = "<|im_start|>system\n" + initial_prompt + "<|im_end|>\n" + formatted_input + f"<|im_start|>{assistant_role}\n"
    else:
        full_prompt = initial_prompt + "\n\n" + formatted_input + f"<{assistant_role}>:"
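    # Example (plain format, assistant_role="model"):
    #   "<initial prompt>\n\n<user>: hi\n<model>:" -- the model continues after the
    #   assistant tag and generation stops at the next role tag or ChatML end token.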

    completion = client.completions.create(
        model="meta-llama/Meta-Llama-3.1-405B",
        prompt=full_prompt,
        temperature=0.7,
        frequency_penalty=0.1,
        max_tokens=response_length,
        stop=[f'<{user_role}>:', f'<{assistant_role}>:'] if not use_chatml else [f'<|im_end|>']
    )
    
    assistant_response = completion.choices[0].text.strip()
    return assistant_response

with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Row():
        initial_prompt = gr.Textbox(
            value="Please respond in whatever manner comes most naturally to you. You do not need to act as an assistant.",
            label="Initial Prompt",
            lines=3
        )
    with gr.Column():
        user_role = gr.Textbox(value="user", label="User Role")
        assistant_role = gr.Textbox(value="model", label="Assistant Role")
        use_chatml = gr.Checkbox(label="Use ChatML", value=True)


    chatbot = gr.ChatInterface(
        generate_response,
        title="Chat with 405B",
        additional_inputs=[initial_prompt, user_role, assistant_role, use_chatml],
        concurrency_limit=10,
        chatbot=gr.Chatbot(height=600)
    )

    gr.Markdown("""
This chat interface is powered by the Llama 3.1 405B base model, served by [Hyperbolic](https://hyperbolic.xyz), The Open Access AI Cloud.

Thank you to Hyperbolic for making this base model available!
""")


# Launch the interface
iface.launch(share=True, max_threads=40)