File size: 3,523 Bytes
cfa124c
 
 
5eae7c2
 
cfa124c
 
78a42dd
9064b67
a31fc9a
88ad70a
4ea00c6
764e0ce
ca87d3a
932df04
89fe551
932df04
89fe551
764e0ce
 
 
 
 
 
 
d1f4761
766f004
764e0ce
 
 
212331b
 
34e885b
39b2ad4
a5f2386
50ddfc1
 
5fa0e98
88ad70a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2df7342
88ad70a
 
 
 
 
b89a040
88ad70a
1057e6a
5e0acb9
c40c2ce
8b35ca4
edd99e4
70bf8f7
18eed82
7fe3430
 
 
3f12f24
bb05240
 
 
 
901e011
bb05240
 
3f12f24
d08307a
34e885b
cc3cc81
a5f2386
952a213
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Reference:
#
# https://vimeo.com/990334325/56b552bc7a
# https://platform.openai.com/playground/assistants
# https://cookbook.openai.com/examples/assistants_api_overview_python
# https://platform.openai.com/docs/api-reference/assistants/createAssistant
# https://platform.openai.com/docs/assistants/tools

import gradio as gr

import os, threading

from assistants import (
    set_openai_client,
    get_assistant,
    set_assistant,
    get_thread,
    set_thread,
    create_assistant,
    load_assistant,
    create_thread,
    create_message,
    create_run,
    wait_on_run,
    get_run_steps,
    recurse_execute_tool_calls,
    get_messages,
    extract_content_values,
)

# Module-wide lock: chat() holds it for the whole request, so only one chat
# turn is processed at a time.
# NOTE(review): presumably required because the assistants module keeps the
# client/assistant/thread as shared state behind set_*/get_* helpers - confirm.
lock = threading.Lock()

def chat(message, history, openai_api_key):
    """Run one chat turn through the assistant and return the reply as markdown.

    Args:
        message: The user's prompt text.
        history: Conversation history supplied by the chat interface. An empty
            history starts a new thread.
        openai_api_key: OpenAI API key entered in the additional-inputs textbox.

    Returns:
        The extracted text values joined with ``<hr>``, followed by a download
        link for the first extracted image/file value when there is one.

    Raises:
        gr.Error: If the API key or the message is missing, or if any step of
            the assistant round trip fails.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required (see additional inputs below).")
    if not message:
        raise gr.Error("Message is required.")

    # Hold the lock for the whole turn so requests are handled one at a time.
    with lock:
        text_values, image_values = [], []
        download_link = ""

        try:
            # .get() instead of [] so an unset variable is treated as "key
            # changed" rather than surfacing as a KeyError on the first request.
            if os.environ.get("OPENAI_API_KEY") != openai_api_key:
                os.environ["OPENAI_API_KEY"] = openai_api_key

                set_openai_client()

                #set_assistant(create_assistant()) # first run
                set_assistant(load_assistant()) # subsequent runs

            # Start a fresh thread when none exists yet or the chat was cleared.
            if get_thread() is None or not history:
                set_thread(create_thread())

            thread = get_thread()

            create_message(thread, message)
            run = create_run(get_assistant(), thread)
            run = wait_on_run(thread, run)
            run_steps = get_run_steps(thread, run)
            recurse_execute_tool_calls(thread, run, run_steps, 0)
            messages = get_messages(thread)
            text_values, image_values = extract_content_values(messages)

            # TODO: Handle multiple images and other file types
            if image_values:
                download_link = f"<hr>[Download](https://platform.openai.com/storage/files/{image_values[0]})"
        except Exception as e:
            # Surface any failure in the UI; pass a plain string as the message
            # and keep the original exception as the cause for debugging.
            raise gr.Error(str(e)) from e

        # Reverse the extracted texts and drop the first one after reversal.
        # NOTE(review): presumably this skips the user's own prompt - confirm
        # against the ordering returned by extract_content_values.
        return f"{'<hr>'.join(list(reversed(text_values))[1:])}{download_link}"

# Example prompts offered in the UI, one prompt per entry.
_example_prompts = [
    ["Generate: Code to fine-tune model meta-llama/Meta-Llama-3.1-8B on dataset gretelai/synthetic_text_to_sql using QLoRA"],
    ["Explain: r\"^(?=.*[A-Z])(?=.*[a-z])(?=.*[0-9])(?=.*[\\W]).{8,}$\""],
    ["Fix: x = [5, 2, 1, 3, 4]; print(x.sort())"],
    ["Optimize: x = []; for i in range(0, 10000): x.append(i)"],
    ["1. Execute: Calculate the first 25 Fibbonaci numbers. 2. Show the code."],
    ["1. Execute with tools: Create a plot showing stock gain QTD for NVDA and AMD, x-axis is \"Day\" and y-axis is \"Gain %\". 2. Show the code."],
    ["1. Execute with tools: Get key announcements from latest OpenAI Dev Day. 2. Show the web references."],
]

# Build the chat UI around chat(); the API key textbox is passed to chat() as
# its extra argument. The description text comes from the DESCRIPTION
# environment variable (None when unset).
demo = gr.ChatInterface(
    fn=chat,
    chatbot=gr.Chatbot(height=250),
    textbox=gr.Textbox(placeholder="Ask anything", container=False, scale=7),
    title="Python Coding Assistant",
    description=os.environ.get("DESCRIPTION"),
    clear_btn="Clear",
    retry_btn=None,
    undo_btn=None,
    examples=_example_prompts,
    cache_examples=False,
    additional_inputs=[
        gr.Textbox("", label="OpenAI API Key"),
    ],
)

# Start serving the interface.
demo.launch()