Spaces:
Running
Running
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_app.ipynb. | |
# %% auto 0 | |
__all__ = ['ConversationBot', 'create_demo'] | |
# %% ../nbs/01_app.ipynb 3 | |
import copy | |
import os | |
import gradio as gr | |
from fastcore.utils import in_jupyter | |
from langchain.chains import ConversationChain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.memory import ConversationBufferMemory | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
HumanMessagePromptTemplate, | |
MessagesPlaceholder, | |
) | |
from PIL import Image | |
import constants | |
from .engineer_prompt import INIT_PROMPT | |
from lv_recipe_chatbot.ingredient_vision import ( | |
SAMPLE_IMG_DIR, | |
BlipImageCaptioning, | |
VeganIngredientFinder, | |
format_image, | |
) | |
# %% ../nbs/01_app.ipynb 4 | |
class ConversationBot:
    """Chat bot that runs a short scripted intake, then an open conversation.

    The first bot turns are fixed questions taken from ``INIT_PROMPT``
    (ingredients, allergies, open-ended recipe parameters). After the third
    user message the collected answers are formatted into the prompt and sent
    to the chat model once; every later turn goes through a
    ``ConversationChain`` whose prompt replays that opening exchange followed
    by the buffered history.

    NOTE(review): all conversation state (memory, chain, image-augmented
    prompt) lives on the instance, so every UI session that shares one
    instance also shares that state.
    """

    def __init__(self, verbose=True):
        # Honor ``verbose`` for the model as well as for the chain.
        self.chat = ChatOpenAI(temperature=1, verbose=verbose)
        self.memory = ConversationBufferMemory(return_messages=True)
        # Private copy: ``run_img`` rewrites one of its message templates.
        self.init_prompt = copy.deepcopy(INIT_PROMPT)
        init_prompt_msgs = self.init_prompt.messages
        # Scripted AI questions, selected from INIT_PROMPT by position.
        self.ai_prompt_questions = {
            "ingredients": init_prompt_msgs[1],
            "allergies": init_prompt_msgs[3],
            "recipe_open_params": init_prompt_msgs[5],
        }
        self.img_cap = BlipImageCaptioning("cpu")
        self.vegan_ingred_finder = VeganIngredientFinder()
        self.verbose = verbose

    def respond(self, user_msg, chat_history):
        """Handle one user message coming from the Gradio UI.

        Appends ``(user_msg, reply)`` to ``chat_history`` in place and returns
        ``("", chat_history)``; the empty string clears the input textbox.
        """
        response = self._get_bot_response(user_msg, chat_history)
        chat_history.append((user_msg, response))
        return "", chat_history

    def init_conversation(self, formatted_chat_prompt):
        """(Re)create the open-ended chain driven by ``formatted_chat_prompt``.

        The chain uses ``self.chat`` as its model and ``self.memory`` as its
        history buffer.
        """
        self.conversation = ConversationChain(
            llm=self.chat,
            memory=self.memory,
            prompt=formatted_chat_prompt,
            verbose=self.verbose,
        )

    def reset(self):
        """Clear buffered history and discard any image-augmented prompt."""
        self.memory.clear()
        self.init_prompt = copy.deepcopy(INIT_PROMPT)

    def run_img(self, image: str):
        """Describe an uploaded image and fold its ingredients into the prompt.

        Args:
            image: path to the uploaded image file (the UI passes the value of
                a ``gr.Image(type="filepath")`` component).

        Returns:
            The message shown to the user. The same text is also prepended to
            the human template at ``self.init_prompt.messages[2]``.
        """
        desc = self.img_cap.inference(format_image(image))
        answer = self.vegan_ingred_finder.list_ingredients(image)
        msg = (
            "I uploaded an image that may contain vegan ingredients.\n"
            f"The description of the image is: `{desc}`.\n"
            "The extracted ingredients are:\n"
            "```\n"
            f"{answer}\n"
            "```\n"
        )
        # Start from the pristine INIT_PROMPT template so that a second upload
        # replaces the earlier image text instead of stacking on top of it.
        base_prompt = INIT_PROMPT.messages[2].prompt.template
        new_prompt = f"{msg}I may type some more ingredients below.\n{base_prompt}"
        self.init_prompt.messages[2].prompt.template = new_prompt
        return msg

    def _get_bot_response(self, user_msg: str, chat_history) -> str:
        """Return the bot's reply for the current turn.

        The turn is inferred from ``len(chat_history)``. The history is
        expected to start with the ``(None, ingredients question)`` seed pair
        that ``create_demo`` places in the chatbot.

        * fewer than 2 pairs: ask about allergies.
        * 2 pairs: ask for the open-ended recipe parameters.
        * 3 pairs: format ``INIT_PROMPT`` with the three answers, query the
          model once, and build the open conversation chain from that
          exchange.
        * more: delegate to the open conversation chain.
        """
        turn = len(chat_history)
        if turn < 2:
            return self.ai_prompt_questions["allergies"].prompt.template
        if turn < 3:
            return self.ai_prompt_questions["recipe_open_params"].prompt.template
        if turn < 4:
            # chat_history[0] is the seed pair; the user's answers to the
            # ingredients and allergies questions follow it.
            ingredients = chat_history[1][0]
            allergies = chat_history[2][0]
            f_init_prompt = self.init_prompt.format_prompt(
                ingredients=ingredients,
                allergies=allergies,
                recipe_freeform_input=user_msg,
            )
            chat_msgs = f_init_prompt.to_messages()
            results = self.chat.generate([chat_msgs])
            first_reply = results.generations[0][0].message
            # The open chain's prompt replays the opening exchange verbatim,
            # then the buffered history, then the next user input.
            chat_msgs.extend(
                [
                    first_reply,
                    MessagesPlaceholder(variable_name="history"),
                    HumanMessagePromptTemplate.from_template("{input}"),
                ]
            )
            open_prompt = ChatPromptTemplate.from_messages(chat_msgs)
            # Drop history left over from an earlier conversation (e.g. after a
            # page refresh); otherwise the new chain would inject it through
            # its "history" placeholder.
            self.memory.clear()
            self.init_conversation(open_prompt)
            return first_reply.content
        return self.conversation.predict(input=user_msg)

    def __del__(self):
        # __init__ may have raised before this attribute was assigned; don't
        # raise AttributeError during teardown in that case.
        if hasattr(self, "vegan_ingred_finder"):
            del self.vegan_ingred_finder
# %% ../nbs/01_app.ipynb 10 | |
def create_demo(bot=ConversationBot):
    """Build the Gradio Blocks app for the recipe chat bot.

    Args:
        bot: a ``ConversationBot`` instance, or a class to instantiate with no
            arguments (the default).

    Returns:
        The ``gr.Blocks`` demo; launching it is left to the caller.

    NOTE(review): the single bot instance is shared by every browser session
    served by this demo.
    """
    # The UI below needs instance attributes (e.g. ``ai_prompt_questions``),
    # so a class argument -- including the default -- must be instantiated.
    if isinstance(bot, type):
        bot = bot()

    # os.listdir order is arbitrary; sort so the same samples (positions 1-3)
    # are offered on every run.
    sample_images = [
        f"{SAMPLE_IMG_DIR}/{img}" for img in sorted(os.listdir(SAMPLE_IMG_DIR))[1:4]
    ]

    with gr.Blocks() as demo:
        gr_img = gr.Image(type="filepath")
        btn = gr.Button(value="Submit image")
        ingredients_msg = gr.Text(label="Ingredients from image")
        btn.click(bot.run_img, inputs=[gr_img], outputs=[ingredients_msg])
        gr.Examples(
            examples=sample_images,
            inputs=gr_img,
        )
        # Seed the chat with the scripted first question; the bot infers the
        # conversation stage from the history length.
        chatbot = gr.Chatbot(
            value=[(None, bot.ai_prompt_questions["ingredients"].prompt.template)]
        )
        msg = gr.Textbox()
        gr.Markdown("**🔃Refresh the page to start from scratch🔃**")
        msg.submit(
            fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        )
    return demo