import gradio as gr
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
from constant import js_code_label, HEADER_MD

# Log INFO-level messages to the console.
logging.basicConfig(level=logging.INFO)



URIAL_VERSION = "inst_1k_v4.help"
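# URIAL_VERSION pins the in-context alignment prompt fetched below from the Re-Align/URIAL repo.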

URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
urial_prompt = urllib.request.urlopen(URIAL_URL).read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
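# Strings that mark the end of a URIAL-style answer; generation is cut off once any of them appears.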
STOP_STRS = ['"""', '# Query:', '# Answer:']


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key
):  
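    """Stream a URIAL-prompted completion from the selected base LLM.

    Called by gr.ChatInterface with the latest user message, the chat history,
    and the additional inputs wired up below; partial responses are yielded so
    the chat window updates as tokens stream in.
    """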
    global STOP_STRS, urial_prompt
    # The UI text boxes return strings, so cast the sampling parameters before use.
    max_tokens, temperature, top_p = int(max_tokens), float(temperature), float(top_p)
    rp = 1.0  # NOTE: the repetition penalty from the UI is ignored and fixed to 1.0.
    prompt = urial_template(urial_prompt, history, message)
    
    # Map the UI display name to the provider's model id (e.g. "meta-llama/Llama-3-8b-hf").
    _model_name = model_name_mapping(model_name)

    # Use the caller-supplied Together API key if it looks valid (64 characters);
    # otherwise fall back to the shared default key, which has limited usage.
    if together_api_key and len(together_api_key) == 64:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY

    # Stream a completion for the assembled prompt from the Together-hosted base model.
    request = openai_base_request(prompt=prompt, model=_model_name,
                                  temperature=temperature,
                                  max_tokens=max_tokens,
                                  top_p=top_p,
                                  repetition_penalty=rp,
                                  stop=STOP_STRS, api_key=api_key)
    response = ""
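    # Stream tokens, stopping as soon as any stop string shows up in the accumulated output.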
    for msg in request:
        # The OpenAI v1 streaming client returns delta as an object (attribute access);
        # fall back to item access in case an older dict-style response is returned.
        delta = msg.choices[0].delta
        token = delta["content"] if isinstance(delta, dict) else delta.content
        if not token:
            continue
        should_stop = False
        for _stop in STOP_STRS:
            if _stop in response + token:
                should_stop = True
                break
        if should_stop:
            break
        response += token
        # Trim trailing quote characters left by a partially streamed stop string ('"""').
        if response.endswith('\n"'):
            response = response[:-1]
        elif response.endswith('\n""'):
            response = response[:-2]
        yield response
 
# Gradio UI: base-model selector, Together API key box, sampling parameters, and the chat panel.
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1",
                                   "Mixtral-8x22B", "Yi-6B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"],
                                  value="Llama-3-8B", label="Base LLM name")
        with gr.Column():
            together_api_key = gr.Textbox(label="🔑 Together API Key", placeholder="Enter your Together API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key")
            with gr.Column():
                with gr.Row():
                    max_tokens = gr.Textbox(value=256, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
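    # The order of additional_inputs below must match the extra parameters of `respond`.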
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
        # additional_inputs_accordion="⚙️ Parameters",
        # fill_height=True, 
    )
    chat.chatbot.height = 550

if __name__ == "__main__":
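    # show_api=False hides Gradio's auto-generated API docs for this demo.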
    demo.launch(show_api=False)