File size: 11,126 Bytes
aa7d8c4
5eae7c2
cfa124c
 
 
5eae7c2
 
cfa124c
 
78a42dd
9064b67
6a5793a
a31fc9a
 
2a3fc7b
9064b67
16d66ad
9064b67
73899aa
6a5793a
2db705f
9064b67
73899aa
 
9064b67
872c848
10d6c27
5da2b8f
9064b67
f10c5f8
8b7dba2
a05ded5
16d66ad
2db705f
f10c5f8
fd3fd7b
6a5793a
 
2e01617
6a5793a
2e01617
a31fc9a
f10c5f8
73899aa
 
 
 
 
 
1d42f25
 
a31fc9a
73899aa
1d42f25
 
73899aa
 
70bf8f7
0b33796
10d6c27
 
 
 
59c15ca
7accea0
2db705f
16d66ad
a31fc9a
73899aa
2db705f
9064b67
2de5e80
8153096
2de5e80
9064b67
fc30f91
73899aa
 
8153096
0784f56
fc30f91
73899aa
 
8153096
9064b67
 
73899aa
 
9064b67
 
39b970f
9064b67
2de5e80
8153096
9064b67
 
73899aa
 
9064b67
 
ee9f7f7
9064b67
2de5e80
8153096
9064b67
 
73899aa
9064b67
73899aa
9064b67
 
 
71987ad
 
2de5e80
8153096
2798fff
55021f1
 
b511565
9064b67
 
73899aa
 
9064b67
 
 
 
fc30f91
8153096
fc30f91
 
67b3b6b
 
049fb2a
 
47417a1
 
8b7dba2
1d42f25
67b3b6b
42e9172
7ddca6e
1012371
e7c3210
 
7ddca6e
9064b67
 
7ddca6e
5fcd91e
fc30f91
67b3b6b
 
 
42e9172
0d6f576
e7c3210
 
67b3b6b
e7c3210
7ddca6e
73899aa
 
9064b67
 
2de5e80
8153096
9064b67
6e6e7d5
9064b67
03869b0
2de5e80
9064b67
 
 
7ddca6e
 
29d58d0
7ddca6e
 
2de5e80
29d58d0
9064b67
e7c3210
 
 
c865dde
e7c3210
5e2a083
 
6e17db2
 
 
 
 
 
 
 
 
 
 
e7c3210
c865dde
e7c3210
 
 
5befa9c
50ddfc1
 
 
536b36d
e04bd50
5da2b8f
872c848
73899aa
872c848
42c9326
e893203
73899aa
9064b67
73899aa
9064b67
73899aa
9064b67
73899aa
 
7bf3f3d
536b36d
e7c3210
55b9732
c066ca5
a31fc9a
686539e
 
 
 
 
 
 
 
 
 
 
 
 
73899aa
a31fc9a
 
686539e
a31fc9a
 
73899aa
 
1b1181f
536b36d
3216382
c066ca5
536b36d
686539e
 
 
 
 
 
 
 
 
 
 
 
 
73899aa
536b36d
 
686539e
536b36d
 
73899aa
 
0150bec
9064b67
73899aa
9064b67
29d58d0
29028c0
7ddca6e
 
a2df0ee
19bfc9a
2de5e80
21ce7f1
1057e6a
5e0acb9
c40c2ce
d92a321
edd99e4
70bf8f7
10d6c27
aa7d8c4
 
536b36d
70bf8f7
4be439f
872c848
10d6c27
7fe3430
 
 
3f12f24
1292850
 
9247b68
1292850
da3afee
73899aa
70bf8f7
3f12f24
aa7d8c4
952a213
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# TODO: Gradio session / multi-user thread

# Reference:
#
# https://vimeo.com/990334325/56b552bc7a
# https://platform.openai.com/playground/assistants
# https://cookbook.openai.com/examples/assistants_api_overview_python
# https://platform.openai.com/docs/api-reference/assistants/createAssistant
# https://platform.openai.com/docs/assistants/tools

import gradio as gr
import pandas as pd
import yfinance as yf

import json, openai, os, time

from datetime import date
from openai import OpenAI
from tavily import TavilyClient
from typing import List
from utils import function_to_schema, show_json

openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))

# Id of an assistant previously created with create_assistant(). The built-in
# default can be overridden through the OPENAI_ASSISTANT_ID environment variable
# so a different assistant can be used without editing the source.
assistant_id = os.environ.get("OPENAI_ASSISTANT_ID", "asst_DbCpNsJ0vHSSdl6ePlkKZ8wG")

# Lazily initialised by chat(). These are module-level, so presumably shared by
# every Gradio user of this process (see the multi-user TODO at the top).
assistant, thread = None, None

def today_tool() -> str:
    """Returns today's date. Use this function for any questions related to knowing today's date. 
       There should be no input. This function always returns today's date."""
    # The docstring is kept verbatim: function_to_schema presumably turns it
    # into the tool description the model reads.
    # ISO format (YYYY-MM-DD) is exactly what str() of a date produces.
    return date.today().isoformat()

def yf_download_tool(tickers: List[str], start_date: date, end_date: date) -> pd.DataFrame:
    """Returns historical stock data for a list of given tickers from start date to end date 
       using the yfinance library download function. 
       Use this function for any questions related to getting historical stock data. 
       The input should be the tickers as a List of strings, a start date, and an end date. 
       This function always returns a pandas DataFrame."""
    # Docstring left verbatim: it presumably becomes the model-facing tool
    # description via function_to_schema.
    date_window = {"start": start_date, "end": end_date}
    return yf.download(tickers, **date_window)

def tavily_search_tool(query: str) -> str:
    """Searches the web for a given query and returns an answer,
       ready for use as context in a RAG application, using the Tavily API.
       Use this function for any questions requiring knowledge not available to the model.
       The input should be the query string. This function always returns an answer string."""
    # The docstring doubles as the tool description sent to the model (via
    # function_to_schema), so it must be free of stray characters.
    # Limited to the top 5 results to keep the context handed back small.
    return tavily_client.get_search_context(query=query, max_results=5)

# Name -> callable registry used by execute_tool_call to dispatch the
# assistant's function tool calls; keys are the functions' own names.
tools = {
    tool_fn.__name__: tool_fn
    for tool_fn in (today_tool, yf_download_tool, tavily_search_tool)
}

def create_assistant(openai_client):
    """Create a new "Python Coding Assistant" on the OpenAI Assistants API.

    The assistant uses gpt-4o with the code interpreter plus the three local
    function tools. The created object is logged with show_json and returned;
    its id can then be stored in assistant_id so load_assistant can reuse it.
    """
    instructions = (
        "You are a Python programming language expert that "
        "generates Pylint-compliant code and explains it. "
        "Execute code when explicitly asked to."
    )
    function_specs = [
        {"type": "function", "function": function_to_schema(tool_fn)}
        for tool_fn in (today_tool, yf_download_tool, tavily_search_tool)
    ]
    created = openai_client.beta.assistants.create(
        name="Python Coding Assistant",
        instructions=instructions,
        model="gpt-4o",
        tools=[{"type": "code_interpreter"}, *function_specs],
    )
    show_json("assistant", created)
    return created

def load_assistant(openai_client):
    """Retrieve the existing assistant identified by ``assistant_id``, log it, and return it."""
    loaded = openai_client.beta.assistants.retrieve(assistant_id=assistant_id)
    show_json("assistant", loaded)
    return loaded

def create_thread(openai_client):
    """Open a fresh, empty conversation thread; log it and hand it back."""
    fresh_thread = openai_client.beta.threads.create()
    show_json("thread", fresh_thread)
    return fresh_thread

def create_message(openai_client, thread, msg):
    """Post *msg* to *thread* as a user message; log and return the created message."""
    request = {"thread_id": thread.id, "role": "user", "content": msg}
    posted = openai_client.beta.threads.messages.create(**request)
    show_json("message", posted)
    return posted

def create_run(openai_client, assistant, thread):
    """Start *assistant* on *thread*; log and return the new run object."""
    run_options = {
        "thread_id": thread.id,
        "assistant_id": assistant.id,
        # Disable parallel function calling so tool calls are requested one at a time.
        "parallel_tool_calls": False,
    }
    started = openai_client.beta.threads.runs.create(**run_options)
    show_json("run", started)
    return started

def wait_on_run(openai_client, thread, run, poll_interval=1.0):
    """Poll *run* until it leaves the ``queued`` / ``in_progress`` states.

    Args:
        openai_client: the OpenAI client used to re-fetch the run.
        thread: the thread the run belongs to (only ``thread.id`` is used).
        run: the run object to wait on.
        poll_interval: seconds to sleep between status checks (default 1).

    Returns:
        The latest run object, e.g. ``completed`` or ``requires_action``.

    Raises:
        gr.Error: if the run reports a ``last_error``.
    """
    while run.status in ("queued", "in_progress"):
        # Sleep *before* re-fetching so the loop exits as soon as a terminal
        # status is seen, instead of sleeping once more after it.
        time.sleep(poll_interval)
        run = openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )

    show_json("run", run)

    last_error = getattr(run, "last_error", None)
    if last_error:
        # last_error is an object (code + message); gr.Error needs readable text.
        error_text = getattr(last_error, "message", None) or str(last_error)
        error_code = getattr(last_error, "code", None)
        raise gr.Error(f"{error_code}: {error_text}" if error_code else error_text)

    return run

def get_run_steps(openai_client, thread, run):
    """List the steps of *run* in ascending (chronological) order; log and return the result."""
    steps_api = openai_client.beta.threads.runs.steps
    listed_steps = steps_api.list(thread_id=thread.id, run_id=run.id, order="asc")
    show_json("run_steps", listed_steps)
    return listed_steps

def execute_tool_call(tool_call, registry=None):
    """Run the local Python function requested by one function tool call.

    Args:
        tool_call: object exposing ``function.name`` and ``function.arguments``
            (a JSON-encoded object of keyword arguments, possibly empty).
        registry: optional mapping of tool name -> callable; defaults to the
            module-level ``tools`` dict.

    Returns:
        Whatever the selected tool returns.

    Raises:
        KeyError: if no tool with that name is registered.
        json.JSONDecodeError: if the arguments are not valid JSON.
    """
    registry = tools if registry is None else registry
    name = tool_call.function.name

    # Decide from the content rather than its length: only an empty or blank
    # payload means "no arguments". Short but valid JSON such as '{"q":"ai"}'
    # must still be decoded.
    raw_args = (tool_call.function.arguments or "").strip()
    parsed = json.loads(raw_args) if raw_args else None
    kwargs = parsed or {}

    return registry[name](**kwargs)

def execute_tool_calls(run_steps):
    """Execute every function tool call found in *run_steps*.

    Each step in ``run_steps.data`` is logged and scanned without any status
    filtering. If the same run's steps are listed again after outputs were
    submitted, the earlier calls are executed again. Tool calls lacking a
    ``function`` attribute (e.g. code-interpreter calls) are logged but skipped.

    Returns:
        (tool_call_ids, tool_call_results): two parallel lists in step order.
    """
    collected_ids, collected_results = [], []

    for step in run_steps.data:
        details = step.step_details
        show_json("step_details", details)

        if not hasattr(details, "tool_calls"):
            continue

        for call in details.tool_calls:
            show_json("tool_call", call)
            if not hasattr(call, "function"):
                continue
            collected_ids.append(call.id)
            collected_results.append(execute_tool_call(call))

    return collected_ids, collected_results

def get_messages(openai_client, thread):
    """Fetch the messages of *thread* using the API's default listing options; log and return them."""
    listing = openai_client.beta.threads.messages.list(thread_id=thread.id)
    show_json("messages", listing)
    return listing
                        
def extract_content_values(data):
    """Split the content parts of a message listing into texts and image file ids.

    Args:
        data: object whose ``.data`` is a sequence of messages, each with a
            ``.content`` sequence of parts carrying a ``type`` attribute.

    Returns:
        (text_values, image_values): the ``text.value`` of every "text" part and
        the ``image_file.file_id`` of every "image_file" part, in listing order.
    """
    parts = [part for message in data.data for part in message.content]
    texts = [part.text.value for part in parts if part.type == "text"]
    images = [part.image_file.file_id for part in parts if part.type == "image_file"]
    return texts, images

###
def generate_tool_outputs(tool_call_ids, tool_call_results):
    """Build the ``tool_outputs`` payload for ``runs.submit_tool_outputs``.

    Args:
        tool_call_ids: ids of the tool calls, parallel to *tool_call_results*.
        tool_call_results: the Python values returned by the tools.

    Returns:
        A list of ``{"tool_call_id": ..., "output": str}`` dicts, one per pair.
        The API only accepts string outputs, so each result is converted:
        strings pass through; objects with a ``to_json`` method (such as the
        DataFrame from yf_download_tool) use it; anything else is JSON-encoded,
        falling back to ``str`` for values json cannot encode.
    """
    return [
        {"tool_call_id": call_id, "output": _tool_output_text(result)}
        for call_id, result in zip(tool_call_ids, tool_call_results)
    ]


def _tool_output_text(result):
    """Convert one tool result into the string the Assistants API expects."""
    if isinstance(result, str):
        return result
    to_json = getattr(result, "to_json", None)
    if callable(to_json):
        return to_json()
    # Deliberate catch-all so an unexpected result type never produces a
    # non-string output that the API would reject.
    return json.dumps(result, default=str)
###

def chat(message, history):
    """Gradio ChatInterface callback: send *message* to the assistant and return its reply.

    A new thread is started when none exists yet or when *history* is empty
    (a fresh or cleared chat). The run is then driven to completion. Each time
    it stops in ``requires_action``, exactly the tool calls it is waiting on are
    executed locally and their outputs submitted. This repeats for as many
    rounds as the assistant needs, up to a safety limit.

    Args:
        message: the user's new message text.
        history: prior chat turns supplied by Gradio; only its emptiness is checked.

    Returns:
        The thread's text messages in reversed listing order (oldest first, since
        messages.list returns newest first by default) with the first one dropped,
        joined by ``<hr>``. If the listing contains an image file, a download link
        for the first image id found is appended.

    Raises:
        gr.Error: on an empty message, a run error, or when the assistant keeps
            requesting tools beyond the round limit.
    """
    global assistant, thread

    if not message:
        raise gr.Error("Message is required.")

    if assistant is None:
        # On the very first deployment call create_assistant(openai_client) here
        # instead and copy the new id into assistant_id
        # (see https://platform.openai.com/playground/assistants).
        assistant = load_assistant(openai_client)

    if thread is None or not history:
        thread = create_thread(openai_client)

    create_message(openai_client, thread, message)

    run = create_run(openai_client, assistant, thread)
    run = wait_on_run(openai_client, thread, run)

    # Answer only the calls listed in run.required_action, so a tool that has
    # already run is never executed twice, and keep going for any number of
    # sequential tool requests.
    # https://platform.openai.com/docs/api-reference/runs/submitToolOutputs
    max_tool_rounds = 10
    rounds = 0
    while run.status == "requires_action":
        if rounds >= max_tool_rounds:
            # Cancel so the thread is not left blocked by a run awaiting outputs.
            openai_client.beta.threads.runs.cancel(thread_id=thread.id, run_id=run.id)
            raise gr.Error("The assistant requested too many tool calls.")
        rounds += 1

        pending = run.required_action.submit_tool_outputs.tool_calls
        for tool_call in pending:
            show_json("tool_call", tool_call)
        results = [execute_tool_call(tool_call) for tool_call in pending]

        run = openai_client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=generate_tool_outputs(
                [tool_call.id for tool_call in pending], results
            ),
        )
        run = wait_on_run(openai_client, thread, run)

    # Debug trace of the finished run's steps (logged via show_json).
    get_run_steps(openai_client, thread, run)

    messages = get_messages(openai_client, thread)
    text_values, image_values = extract_content_values(messages)

    download_link = ""
    if image_values:
        download_link = f"<p>Download: https://platform.openai.com/storage/files/{image_values[0]}</p>"

    return f"{'<hr>'.join(list(reversed(text_values))[1:])}{download_link}"

# The interface is bound to a module-level name so tooling that looks for
# `demo` (e.g. Gradio's reload mode) can find it. The server only starts when
# this file is run as a script, so importing the module has no side effects.
demo = gr.ChatInterface(
    fn=chat,
    chatbot=gr.Chatbot(height=350),
    textbox=gr.Textbox(placeholder="Ask anything", container=False, scale=7),
    title="Python Coding Assistant",
    description=(
        "The assistant can **generate, explain, fix, optimize,** and **document Python code, "
        "create unit test cases,** and **answer general coding-related questions.** "
        "It can also **execute code**. "
        "The assistant has access to a <b>today tool</b> (get current date), to a "
        "**yfinance download tool** (get stock data), and to a "
        "**tavily search tool** (web search)."
    ),
    clear_btn="Clear",
    retry_btn=None,
    undo_btn=None,
    examples=[
        ["Generate: Python code to fine-tune model meta-llama/Meta-Llama-3.1-8B on dataset gretelai/synthetic_text_to_sql using QLoRA"],
        ["Explain: r\"^(?=.*[A-Z])(?=.*[a-z])(?=.*[0-9])(?=.*[\\W]).{8,}$\""],
        ["Fix: x = [5, 2, 1, 3, 4]; print(x.sort())"],
        ["Optimize: x = []; for i in range(0, 10000): x.append(i)"],
        ["Execute: First 25 Fibonacci numbers"],
        ["Execute with tools: Create a plot showing stock gain QTD for NVDA and AMD, x-axis is \"Day\" and y-axis is \"Gain %\""],
        ["Execute with tools: Get key announcements from the latest OpenAI Dev Day"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()