Spaces:
Running
Running
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_app.ipynb. | |
# %% auto 0 | |
__all__ = ['ConversationBot', 'launch_demo'] | |
# %% ../nbs/01_app.ipynb 3 | |
import os | |
import gradio as gr | |
from fastcore.utils import in_jupyter | |
from langchain.chains import ConversationChain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.memory import ConversationBufferMemory | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
HumanMessagePromptTemplate, | |
MessagesPlaceholder, | |
) | |
from .engineer_prompt import init_prompt | |
# %% ../nbs/01_app.ipynb 4 | |
class ConversationBot:
    """Recipe assistant that runs a scripted intake before free-form chat.

    The bot first asks a fixed series of questions taken from ``init_prompt``
    (ingredients, allergies, free-form recipe wishes). Once all three answers
    are in, it sends the filled-in prompt to the chat model, returns the
    model's reply, and builds a ``ConversationChain`` that handles every
    later message.
    """

    def __init__(
        self,
    ):
        self.chat = ChatOpenAI(temperature=1, verbose=True)
        self.memory = ConversationBufferMemory(return_messages=True)
        self.init_prompt_msgs = init_prompt.messages
        # Scripted questions are read from positions 1, 3 and 5 of the
        # initial prompt.
        # NOTE(review): presumably the other positions hold the matching
        # human-answer templates -- confirm against engineer_prompt.
        self.ai_prompt_questions = {
            "ingredients": self.init_prompt_msgs[1],
            "allergies": self.init_prompt_msgs[3],
            "recipe_open_params": self.init_prompt_msgs[5],
        }
        # Built by init_conversation() once the scripted intake is complete.
        self.conversation = None

    def respond(self, user_msg, chat_history):
        """Answer ``user_msg`` and record the exchange in ``chat_history``.

        Args:
            user_msg: Text the user just submitted.
            chat_history: List of ``(user, bot)`` pairs shown so far. It is
                appended to in place; ``None`` is treated as an empty list.

        Returns:
            ``("", chat_history)`` -- ``launch_demo`` writes the empty string
            back to the textbox and the history back to the chatbot.
        """
        if chat_history is None:
            chat_history = []
        response = self._get_bot_response(user_msg, chat_history)
        chat_history.append((user_msg, response))
        return "", chat_history

    def init_conversation(self, formatted_chat_prompt):
        """Create the open-ended conversation chain from the given prompt."""
        self.conversation = ConversationChain(
            llm=self.chat,
            memory=self.memory,
            prompt=formatted_chat_prompt,
            verbose=True,
        )

    def reset(self):
        """Clear the conversation memory."""
        self.memory.clear()

    def _get_bot_response(self, user_msg: str, chat_history) -> str:
        """Return the bot's reply to ``user_msg`` given the prior history.

        The stage is derived from how many user messages the history already
        holds: 0 -> ask about allergies, 1 -> ask for recipe wishes,
        2 -> generate the first model answer and build the conversation
        chain, 3 or more -> delegate to the conversation chain.

        Raises:
            RuntimeError: If free-form chat is reached before the
                conversation chain has been built.
        """
        # The greeting row has None in the user slot, so count only real user
        # turns. Keying on the raw history length would shift every stage by
        # one whenever the greeting row is missing (e.g. after a clear).
        user_msgs = [pair[0] for pair in chat_history if pair[0] is not None]
        answered = len(user_msgs)

        if answered == 0:
            return self.ai_prompt_questions["allergies"].prompt.template
        if answered == 1:
            return self.ai_prompt_questions["recipe_open_params"].prompt.template
        if answered == 2:
            f_init_prompt = init_prompt.format_prompt(
                ingredients=user_msgs[0],
                allergies=user_msgs[1],
                recipe_freeform_input=user_msg,
            )
            chat_msgs = f_init_prompt.to_messages()
            results = self.chat.generate([chat_msgs])
            first_reply = results.generations[0][0].message
            chat_msgs.extend(
                [
                    first_reply,
                    MessagesPlaceholder(variable_name="history"),
                    HumanMessagePromptTemplate.from_template("{input}"),
                ]
            )
            open_prompt = ChatPromptTemplate.from_messages(chat_msgs)
            # prepare the open conversation chain from this point
            self.init_conversation(open_prompt)
            return first_reply.content

        if self.conversation is None:
            raise RuntimeError(
                "Conversation chain has not been initialised; "
                "the scripted intake must complete first."
            )
        return self.conversation.predict(input=user_msg)
# %% ../nbs/01_app.ipynb 5 | |
def launch_demo():
    """Build and launch the password-protected Gradio chat demo.

    The chat window opens with the bot's ingredients question. Submitting
    the textbox routes the message through ``ConversationBot.respond``; the
    Clear button restores the opening question and resets the bot's memory.

    NOTE(review): a single ``ConversationBot`` (and its memory) is created
    here and its bound methods serve as the handlers -- confirm that sharing
    it across browser sessions is intended.

    Raises:
        KeyError: If ``GRADIO_DEMO_USERNAME`` or ``GRADIO_DEMO_PASSWORD`` is
            not set in the environment.
    """
    with gr.Blocks() as demo:
        bot = ConversationBot()
        greeting = bot.ai_prompt_questions["ingredients"].prompt.template

        def opening_history():
            # Return a fresh list each time so the history is never aliased.
            return [(None, greeting)]

        chatbot = gr.Chatbot(value=opening_history())
        msg = gr.Textbox()
        clear = gr.Button("Clear")
        msg.submit(
            fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        )
        # Clearing must bring back the opening question: emptying the chatbot
        # would leave the user without a prompt and drop the greeting row
        # that the scripted flow starts from.
        clear.click(opening_history, None, chatbot, queue=False).then(bot.reset)
    demo.launch(
        auth=(
            os.environ["GRADIO_DEMO_USERNAME"],
            os.environ["GRADIO_DEMO_PASSWORD"],
        )
    )