# gradio_101 / app.py
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents import create_sql_agent, AgentType
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from huggingface_hub import login
from langchain.globals import set_verbose
from sqlalchemy import create_engine
from prompts import agent_template, table_info
set_verbose(True)
# load_dotenv(find_dotenv(r".env"))
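# API keys (GOOGLE_API_KEY, ANTHROPIC_API_KEY) are read from the environment below,
# presumably provided as Space secrets since the .env loading above is commented out.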
def load_model(model_id):
    """Return a chat model for the given provider id ("gemini" or "claude")."""
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model="gemini-pro",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            convert_system_message_to_human=True,
            temperature=0.05,
            verbose=True,
        )
    elif model_id == "claude":
        return ChatAnthropic(
            model_name="claude-2",
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            temperature=0.05,
            streaming=True,
            verbose=True,
        )
    else:
        raise ValueError("only 'gemini' and 'claude' are supported as of now")

def chain(db, llm):
    """Build a two-step LCEL chain: question -> SQL query -> query result -> natural-language answer."""

    def get_schema(_):
        return db.get_table_info()

    def run_query(query):
        return db.run(query)

    # Step 1: generate a SQL query from the table schema and the user's question.
    template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
{schema}
Question: {question}
Query:"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", template),
        ]
    )

    # Step 2: turn the query and its result into a natural-language answer.
    template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
{schema}
Question: {question}
MS-SQL Query: {query}
MS-SQL Response: {response}"""
    prompt_response = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", template),
        ]
    )

    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )
    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(
            schema=get_schema,
            response=lambda x: run_query(x["query"]),
        )
        | prompt_response
        | llm
    )
    return full_chain
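
# A minimal usage sketch for chain(): this helper is not called anywhere in the app,
# and the table name and question below are assumptions for illustration only.
def _example_chain_usage():
    example_db = SQLDatabase.from_uri(
        "sqlite:///OPPI_shift.db",
        include_tables=["ShiftProductionDetails"],
        sample_rows_in_table_info=0,
    )
    qa = chain(db=example_db, llm=load_model("gemini"))  # requires GOOGLE_API_KEY in the environment
    return qa.invoke({"question": "How many accepted parts were produced in October 2023?"})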

def main():
    gemini = load_model("gemini")
    agent_llm = load_model("claude")

    path = r"OPPI_shift.db"  # \OPPI_down.db
    db1 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails'], sample_rows_in_table_info=0)
    db2 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftProductionDetails'], sample_rows_in_table_info=0)
    db3 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails', 'ShiftProductionDetails', 'Focas_AlarmHistory'], sample_rows_in_table_info=0)

    # Per-table question-answering chains (Gemini).
    down_chain = chain(db=db1, llm=gemini)
    prod_chain = chain(db=db2, llm=gemini)

    def echo1(message, history):
        ans = down_chain.invoke({"question": message}, config={"callbacks": [ConsoleCallbackHandler()]})
        return str(ans)

    def echo2(message, history):
        ans = prod_chain.invoke({"question": message}, config={"callbacks": [ConsoleCallbackHandler()]})
        return str(ans)

    prompt_agent = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble." + agent_template),
            ("human", "{question}" + table_info),
        ]
    )
    examples = [
        "Calculate total Prod quantity in Second Shift for 2024",
        "Calculate total accepted parts in shift 2 for 2024",
        "How many accepted parts were produced in October 2023 in each machine",
        "How likely is the Turret index aborted alarm expected on machine K-1",
        "List all the distinct reasons behind DownTime in machine K-2",
        "Calculate the total Downtime experienced by machine K-8 due to the reason of No Shift",
        "What was the most common reason for Downtime in the year 2023?",
        "Calculate the average downtime for machine M-2 for every month in the later half of 2023",
        "Return all the reasons for Downcategory in Nov and Dec on machine L-7 in 3rd shift",
    ]

    # ReAct-style SQL agent (Claude) over all three tables, returning intermediate steps.
    sql_toolkit = SQLDatabaseToolkit(db=db3, llm=agent_llm)
    agent = create_sql_agent(
        toolkit=sql_toolkit,
        llm=agent_llm,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        agent_executor_kwargs={"handle_parsing_errors": True, "return_intermediate_steps": True},
    )

    def echo3(message, history):
        answer = agent.invoke(prompt_agent.format_prompt(question=message))
        # Pull the generated SQL out of the agent's second intermediate step and show it alongside the final answer.
        final_answer = f"Final Query:- {list(answer['intermediate_steps'][1][0])[-2][1].split('Action Input: ')[-1]}\n\nAnswer:- {answer['output']}"
        return final_answer

    downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
    production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
    agent_tab = gr.ChatInterface(fn=echo3, examples=examples, title="SQL-Chatbot", description="General chatbot with self-reasoning (agent) capability; more robust to varied questions.")
    demo = gr.TabbedInterface([agent_tab, downtime, production], ['DB_bot-both tables', 'ShiftDownTimeDetails', 'ShiftProductionDetails'])
    demo.launch(debug=True, share=True)


if __name__ == "__main__":
    main()
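
# Note on launch(debug=True, share=True): debug keeps the server in the foreground and surfaces
# errors in the console; share requests a temporary public Gradio link in addition to the local URL.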