heysho's picture
Create app.py
1c2e898 verified
raw
history blame
6.19 kB
import streamlit as st
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_community.callbacks import StreamlitCallbackHandler
# models
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
# custom tools
from tools.search_ddg import search_ddg
from tools.fetch_page import fetch_page
###### dotenv を利用しない場合は消してください ######
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
import warnings
warnings.warn("dotenv not found. Please make sure to set your environment variables manually.", ImportWarning)
################################################
CUSTOM_SYSTEM_PROMPT = """
あなたは、ユーザーのリクエストに基づいてインターネットで調べ物を行うアシスタントです。
利用可能なツールを使用して、調査した情報を説明してください。
既に知っていることだけに基づいて答えないでください。回答する前にできる限り検索を行ってください。
(ユーザーが読むページを指定するなど、特別な場合は、検索する必要はありません。)
検索結果ページを見ただけでは情報があまりないと思われる場合は、次の2つのオプションを検討して試してみてください。
- 検索結果のリンクをクリックして、各ページのコンテンツにアクセスし、読んでみてください。
- 1ページが長すぎる場合は、3回以上ページ送りしないでください(メモリの負荷がかかるため)。
- 検索クエリを変更して、新しい検索を実行してください。
- ユーザーは日本に住んでいます。情報は日本の地域に限定してください。
- 時期の指定がない場合は、2024年の最新情報を検索ください。
ユーザーは非常に忙しく、あなたほど自由ではありません。
そのため、ユーザーの労力を節約するために、直接的な回答を提供してください。
=== 悪い回答の例 ===
- これらのページを参照してください。
- これらのページを参照してコードを書くことができます。
- 次のページが役立つでしょう。
=== 良い回答の例 ===
- これはサンプルコードです。 -- サンプルコードをここに --
- あなたの質問の答えは -- 回答をここに --
回答の最後には、参照したページのURLを**必ず**記載してください。(これにより、ユーザーは回答を検証することができます)
ユーザーが使用している言語で回答するようにしてください。
ユーザーが日本語で質問した場合は、日本語で回答してください。ユーザーがスペイン語で質問した場合は、スペイン語で回答してください。
"""
def init_page():
st.set_page_config(
page_title="Web Browsing Agent",
page_icon="🤗"
)
st.header("Web Browsing Agent 🤗")
# st.sidebar.title("Options")
def init_messages():
# clear_button = st.button("Clear Conversation", key="clear")
if "messages" not in st.session_state:
st.session_state.messages = [
{"role": "assistant", "content": "こんにちは!なんでも質問をどうぞ!"}
]
st.session_state['memory'] = ConversationBufferWindowMemory(
return_messages=True,
memory_key="chat_history",
k=10
)
# このようにも書ける
# from langchain_community.chat_message_histories import StreamlitChatMessageHistory
# msgs = StreamlitChatMessageHistory(key="special_app_key")
# st.session_state['memory'] = ConversationBufferMemory(memory_key="history", chat_memory=msgs)
def select_model(temperature=0):
models = ("GPT-4o","GPT-4o-mini", "Claude 3.5 Sonnet", "Gemini 1.5 Pro")
model_choice = st.radio("Choose a model:", models)
if model_choice == "GPT-4o":
return ChatOpenAI(temperature=temperature, model_name="gpt-4o")
elif model_choice == "GPT-4o-mini":
return ChatOpenAI(temperature=temperature, model_name="gpt-4o-mini")
elif model_choice == "Claude 3.5 Sonnet":
return ChatAnthropic(temperature=temperature, model_name="claude-3-5-sonnet-20240620")
elif model_choice == "Gemini 1.5 Pro":
return ChatGoogleGenerativeAI(temperature=temperature, model="gemini-1.5-pro-latest")
def create_agent():
tools = [search_ddg, fetch_page]
prompt = ChatPromptTemplate.from_messages([
("system", CUSTOM_SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad")
])
llm = select_model()
agent = create_tool_calling_agent(llm, tools, prompt)
return AgentExecutor(
agent=agent,
tools=tools,
verbose=True,
memory=st.session_state['memory']
)
def main():
init_page()
init_messages()
web_browsing_agent = create_agent()
for msg in st.session_state['memory'].chat_memory.messages:
st.chat_message(msg.type).write(msg.content)
if prompt := st.chat_input(placeholder="2024年9月の東京のイベント情報を教えて"):
st.chat_message("user").write(prompt)
with st.chat_message("assistant"):
# コールバック関数の設定 (エージェントの動作の可視化用)
st_cb = StreamlitCallbackHandler(
st.container(), expand_new_thoughts=True)
# エージェントを実行
response = web_browsing_agent.invoke(
{'input': prompt},
config=RunnableConfig({'callbacks': [st_cb]})
)
st.write(response["output"])
if __name__ == '__main__':
main()