Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
from dotenv import load_dotenv, find_dotenv | |
from langchain.utilities.sql_database import SQLDatabase | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain.chat_models.anthropic import ChatAnthropic | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain.agents import create_sql_agent, AgentType | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.tracers import ConsoleCallbackHandler | |
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit | |
from huggingface_hub import login | |
from langchain.globals import set_verbose | |
from sqlalchemy import create_engine | |
from prompts import agent_template, table_info | |
# Turn on LangChain's global verbose flag so chain/agent activity is logged.
set_verbose(True)
# API keys are read from the environment (GOOGLE_API_KEY, ANTHROPIC_API_KEY).
# For local runs with a .env file, call load_dotenv(find_dotenv(r".env")) here.
def load_model(model_id):
    """Build one of the supported LangChain chat models.

    Args:
        model_id: ``"gemini"`` for Google's ``gemini-pro`` or ``"claude"`` for
            Anthropic's ``claude-2``.

    Returns:
        A configured chat model. API keys come from the ``GOOGLE_API_KEY`` /
        ``ANTHROPIC_API_KEY`` environment variables.

    Raises:
        ValueError: if ``model_id`` is not a supported identifier. Failing
            here avoids handing ``None`` to a chain or agent, where the error
            would surface much later and far from its cause.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model="gemini-pro",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            # The chains send a system message; have it folded into the human turn.
            convert_system_message_to_human=True,
            temperature=0.05,
            verbose=True,
        )
    if model_id == "claude":
        return ChatAnthropic(
            model_name="claude-2",
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            temperature=0.05,
            streaming=True,
            verbose=True,
        )
    raise ValueError(
        f"Unsupported model_id {model_id!r}; only 'gemini' and 'claude' are supported"
    )
def chain(db, llm, dialect=None):
    """Build a two-step text-to-SQL runnable over ``db``.

    Step 1 asks ``llm`` to translate ``{"question": ...}`` into a SQL query.
    Step 2 runs that query against ``db`` and asks ``llm`` to phrase the
    result as a natural-language answer.

    Args:
        db: A LangChain ``SQLDatabase``. Its table info is used as the schema,
            and its ``run`` method executes the generated query.
        llm: A LangChain chat model used for both steps.
        dialect: SQL dialect named in the prompts. Defaults to ``db.dialect``
            so the model writes SQL the connected engine can execute. For
            example, MS SQL syntax such as ``TOP`` or ``DATEPART`` fails on
            SQLite.

    Returns:
        A runnable. ``invoke({"question": ...})`` returns the chat model's
        message containing the final answer.
    """
    if dialect is None:
        dialect = db.dialect

    def get_schema(_):
        return db.get_table_info()

    def clean_sql(text):
        # Chat models may wrap the query in a markdown ``` fence (optionally
        # tagged, e.g. ```sql), which the database cannot execute.
        sql = text.strip()
        if sql.startswith("```"):
            newline = sql.find("\n")
            sql = sql[newline + 1:] if newline != -1 else sql[3:]
            if sql.endswith("```"):
                sql = sql[:-3]
        return sql.strip()

    def run_query(inputs):
        return db.run(inputs["query"])

    # Only the {dialect} part is interpolated now; {schema}/{question}/...
    # remain prompt-template variables filled at invoke time.
    query_template = (
        f"Based on the table schema below, write a {dialect} query that would "
        "answer the user's question:\n"
        "{schema}\n"
        "Question: {question}\n"
        "Query:"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", f"Given an input question, convert it to a {dialect} query. No pre-amble."),
            ("human", query_template),
        ]
    )

    response_template = (
        f"Based on the table schema below, question, {dialect} query, and "
        f"{dialect} response, write a natural language response:\n"
        "{schema}\n"
        "Question: {question}\n"
        f"{dialect} Query: {{query}}\n"
        f"{dialect} Response: {{response}}"
    )
    prompt_response = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                f"Given an input question and {dialect} response, convert it to a "
                "natural language answer. No pre-amble.",
            ),
            ("human", response_template),
        ]
    )

    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | clean_sql
    )

    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(schema=get_schema, response=run_query)
        | prompt_response
        | llm
    )
    return full_chain
# Name LangChain's SQLDatabaseToolkit gives the tool that actually executes SQL
# (its other tools list tables, fetch schemas, or check a query without running it).
_SQL_QUERY_TOOL = "sql_db_query"


def _message_text(message):
    """Return the text of a chat-model message; other values go through ``str``.

    The table chains end with the chat model itself, so ``invoke`` yields a
    message object. ``str()`` of that object is a field dump
    (``content='...' ...``), not the answer text.
    """
    return str(getattr(message, "content", message))


def _extract_final_query(intermediate_steps):
    """Return the SQL the agent last executed, or a placeholder if it ran none.

    Args:
        intermediate_steps: The ``(AgentAction, observation)`` pairs returned by
            an ``AgentExecutor`` built with ``return_intermediate_steps=True``.
    """
    for action, _observation in reversed(intermediate_steps):
        if getattr(action, "tool", None) != _SQL_QUERY_TOOL:
            continue
        query = action.tool_input
        # Structured tool inputs may arrive as a dict rather than a plain string.
        if isinstance(query, dict):
            query = query.get("query", query)
        return str(query).strip()
    return "(no SQL query was executed)"


def main():
    """Build and launch the Gradio app.

    Three tabs share one SQLite file:
    - a LangChain SQL agent (Claude) over three tables;
    - a Gemini text-to-SQL chain restricted to ``ShiftDownTimeDetails``;
    - a Gemini text-to-SQL chain restricted to ``ShiftProductionDetails``.
    """
    gemini = load_model("gemini")
    agent_llm = load_model("claude")

    # Relative path: the SQLite file is resolved against the working directory.
    path = "OPPI_shift.db"
    db_uri = f"sqlite:///{path}"

    def open_db(tables):
        # sample_rows_in_table_info=0: schema text sent to the LLM has no sample rows.
        return SQLDatabase.from_uri(db_uri, include_tables=tables, sample_rows_in_table_info=0)

    down_db = open_db(["ShiftDownTimeDetails"])
    prod_db = open_db(["ShiftProductionDetails"])
    agent_db = open_db(["ShiftDownTimeDetails", "ShiftProductionDetails", "Focas_AlarmHistory"])

    down_chain = chain(db=down_db, llm=gemini)
    prod_chain = chain(db=prod_db, llm=gemini)

    def table_chat(table_chain):
        """Return a ``gr.ChatInterface`` callback that answers via ``table_chain``."""

        def respond(message, history):  # history is part of the ChatInterface signature; unused
            ans = table_chain.invoke(
                {"question": message},
                config={"callbacks": [ConsoleCallbackHandler()]},
            )
            return _message_text(ans)

        return respond

    prompt_agent = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble."+agent_template),
            ("human", "{question}"+table_info),
        ]
    )
    examples = [
        "calculate total Prod quantity in Second Shift for 2024",
        "Calculate total accepted parts in shift 2 for 2024",
        "How many accepted parts were produced in October 2023 in each machine",
        "How likely is the Turrent index aborted alarm expected on machine k-1",
        "List all the distinct reasons behind DownTime in machine K-2",
        "Calculate the total Downtime experienced by machine K-8 due to the reason of No Shift",
        "What was the most common reason for Downtime in the year 2023?",
        "Calculate the average downtime for Machine M-2 in for every month in later half of 2023",
        "return all the reasons for Downcategory in Nov and dec on machine L-7 in 3rd shift",
    ]
    sql_toolkit = SQLDatabaseToolkit(db=agent_db, llm=agent_llm)
    agent = create_sql_agent(
        toolkit=sql_toolkit,
        llm=agent_llm,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        agent_executor_kwargs={"handle_parsing_errors": True, "return_intermediate_steps": True},
    )

    def agent_chat(message, history):  # history is part of the ChatInterface signature; unused
        # Send the rendered prompt as plain text under the executor's "input" key,
        # rather than a PromptValue object that would be stringified via its repr.
        agent_input = prompt_agent.format_prompt(question=message).to_string()
        answer = agent.invoke({"input": agent_input})
        # Locate the executed query by tool name instead of a fixed step index,
        # which breaks when the agent takes a different number of steps.
        final_query = _extract_final_query(answer.get("intermediate_steps", []))
        return f"Final Query:- {final_query}\n\nAnswer:- {answer['output']}"

    downtime = gr.ChatInterface(fn=table_chat(down_chain), title="SQL-Chatbot", description="Q/A on Downtime details table")
    production = gr.ChatInterface(fn=table_chat(prod_chain), title="SQL-Chatbot", description="Q/A on Production details table")
    agent_tab = gr.ChatInterface(fn=agent_chat, examples=examples, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")
    demo = gr.TabbedInterface([agent_tab, downtime, production], ['DB_bot-both tables', 'ShiftDownTimeDetails', 'ShiftProductionDetails'])
    demo.launch(debug=True, share=True)
if __name__ == "__main__":
    # Launch the app only when executed as a script, not when imported.
    main()