# Author: Evan Lesmez
# Commit 5f3a430: "Cleanup notebooks to be nbdev_test friendly"
# (source page: raw | history | blame — 5.07 kB)
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_app.ipynb.
# %% auto 0
__all__ = ['ConversationBot', 'create_demo']
# %% ../nbs/01_app.ipynb 3
import copy
import os
import gradio as gr
from fastcore.utils import in_jupyter
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from PIL import Image
import constants
from .engineer_prompt import INIT_PROMPT
from lv_recipe_chatbot.ingredient_vision import (
SAMPLE_IMG_DIR,
BlipImageCaptioning,
VeganIngredientFinder,
format_image,
)
# %% ../nbs/01_app.ipynb 4
class ConversationBot:
    """Recipe chatbot: three scripted questions, then an open LLM conversation.

    The first bot turns are fixed questions taken from ``INIT_PROMPT``
    (ingredients -> allergies -> open recipe parameters). Once the user has
    answered them, the answers are formatted into ``INIT_PROMPT``, the chat
    model produces a first reply, and an open-ended ``ConversationChain``
    handles every later turn.

    All conversation state (memory, prompt copy, chain) lives on the instance.
    NOTE(review): if a single instance backs a multi-user UI, all sessions
    share this state.
    """

    def __init__(self, verbose=True):
        """Create the chat model, memory, prompt copy and image models.

        Args:
            verbose: forwarded to the chat model and to the conversation chain.
        """
        self.verbose = verbose
        # Previously hard-coded to verbose=True, ignoring the argument.
        self.chat = ChatOpenAI(temperature=1, verbose=verbose)
        self.memory = ConversationBufferMemory(return_messages=True)
        self._load_init_prompt()
        self.img_cap = BlipImageCaptioning("cpu")
        self.vegan_ingred_finder = VeganIngredientFinder()

    def _load_init_prompt(self):
        """Take a fresh private copy of INIT_PROMPT and index its scripted questions.

        A deep copy is used because ``run_img`` rewrites one of the prompt's
        templates in place; the module-level INIT_PROMPT must stay pristine.
        """
        self.init_prompt = copy.deepcopy(INIT_PROMPT)
        msgs = self.init_prompt.messages
        # Messages 1, 3 and 5 are the bot's scripted questions, asked in order.
        self.ai_prompt_questions = {
            "ingredients": msgs[1],
            "allergies": msgs[3],
            "recipe_open_params": msgs[5],
        }

    def respond(self, user_msg, chat_history):
        """Chat submit handler: append ``(user_msg, reply)`` to ``chat_history``.

        Args:
            user_msg: the text the user just submitted.
            chat_history: list of (user, bot) pairs; mutated in place.

        Returns:
            ``("", chat_history)``: the empty string clears the input box.
        """
        response = self._get_bot_response(user_msg, chat_history)
        chat_history.append((user_msg, response))
        return "", chat_history

    def init_conversation(self, formatted_chat_prompt):
        """Create the open-ended ConversationChain over ``self.memory``.

        Args:
            formatted_chat_prompt: chat prompt expecting ``history`` and ``input``.
        """
        self.conversation = ConversationChain(
            llm=self.chat,
            memory=self.memory,
            prompt=formatted_chat_prompt,
            verbose=self.verbose,
        )

    def reset(self):
        """Clear conversation memory and restore a pristine initial prompt.

        Any image context previously added by ``run_img`` is discarded too.
        """
        self.memory.clear()
        # Rebuild the question index as well, so it refers to the new copy.
        self._load_init_prompt()

    def run_img(self, image: str):
        """Describe an uploaded image and add its ingredients to the prompt.

        Args:
            image: path of the uploaded image file.

        Returns:
            The user-facing message with the caption and extracted ingredients.
            The same text is prepended to ``init_prompt.messages[2]``'s template.
        """
        desc = self.img_cap.inference(format_image(image))
        answer = self.vegan_ingred_finder.list_ingredients(image)
        msg = (
            "I uploaded an image that may contain vegan ingredients.\n"
            f"The description of the image is: `{desc}`.\n"
            "The extracted ingredients are:\n"
            "```\n"
            f"{answer}\n"
            "```\n"
        )
        # The template is formatted later by format_prompt. Assuming the
        # langchain default f-string format, literal braces in the caption or
        # ingredient list would be read as template variables, so escape them.
        escaped_msg = msg.replace("{", "{{").replace("}", "}}")
        # Build on the pristine INIT_PROMPT template (not our copy), so several
        # uploads replace the image context instead of stacking it.
        base_prompt = INIT_PROMPT.messages[2].prompt.template
        self.init_prompt.messages[2].prompt.template = (
            f"{escaped_msg}I may type some more ingredients below.\n{base_prompt}"
        )
        return msg

    def _get_bot_response(self, user_msg: str, chat_history) -> str:
        """Return the bot's reply for the current turn.

        ``chat_history`` holds the (user, bot) pairs shown so far. The UI
        seeds it with ``(None, <ingredients question>)``, so its length
        identifies the stage:

        * 1 pair:  the user answered the ingredients question, so ask about allergies.
        * 2 pairs: the user answered allergies, so ask for recipe parameters.
        * 3 pairs: start the open conversation from the three answers.
        * more:    continue the open conversation.
        """
        turns = len(chat_history)
        if turns < 2:
            return self.ai_prompt_questions["allergies"].prompt.template
        if turns < 3:
            return self.ai_prompt_questions["recipe_open_params"].prompt.template
        if turns < 4:
            return self._start_open_conversation(user_msg, chat_history)
        return self.conversation.predict(input=user_msg)

    def _start_open_conversation(self, user_msg: str, chat_history) -> str:
        """Fill INIT_PROMPT with the scripted answers and set up the open chain.

        Returns:
            The chat model's first free-form reply.
        """
        # chat_history[1][0] answered the ingredients question and
        # chat_history[2][0] the allergies question; user_msg is the freeform reply.
        ingredients, allergies = (pair[0] for pair in chat_history[1:3])
        f_init_prompt = self.init_prompt.format_prompt(
            ingredients=ingredients,
            allergies=allergies,
            recipe_freeform_input=user_msg,
        )
        chat_msgs = f_init_prompt.to_messages()
        first_reply = self.chat.generate([chat_msgs]).generations[0][0].message
        # The open prompt already contains the whole scripted exchange, so any
        # memory left over from an earlier session (e.g. after a page refresh)
        # must not leak into this conversation.
        self.memory.clear()
        open_prompt = ChatPromptTemplate.from_messages(
            [
                *chat_msgs,
                first_reply,
                # Later turns arrive through memory ("history") and "{input}".
                MessagesPlaceholder(variable_name="history"),
                HumanMessagePromptTemplate.from_template("{input}"),
            ]
        )
        self.init_conversation(open_prompt)
        return first_reply.content

    def __del__(self):
        """Explicitly drop the ingredient finder on teardown.

        Guarded because ``__del__`` also runs when ``__init__`` raised before
        the attribute was assigned.
        """
        if hasattr(self, "vegan_ingred_finder"):
            del self.vegan_ingred_finder
# %% ../nbs/01_app.ipynb 10
def create_demo(bot=ConversationBot):
    """Build (without launching) the Gradio UI for the recipe chatbot.

    Args:
        bot: the ConversationBot instance whose methods back the UI. If a
            class is given (including the default, ``ConversationBot``
            itself), it is instantiated with no arguments first. Previously
            the default crashed because the handlers and the greeting need
            instance attributes.

    Returns:
        The ``gr.Blocks`` app.
    """
    if isinstance(bot, type):
        bot = bot()

    # os.listdir returns names in arbitrary order; sort so the same sample
    # images are offered on every machine (positions 1-3, as before).
    all_imgs = [f"{SAMPLE_IMG_DIR}/{img}" for img in sorted(os.listdir(SAMPLE_IMG_DIR))]
    sample_images = all_imgs[1:4]

    with gr.Blocks() as demo:
        # Image upload -> caption + ingredient extraction shown in a text box.
        gr_img = gr.Image(type="filepath")
        btn = gr.Button(value="Submit image")
        ingredients_msg = gr.Text(label="Ingredients from image")
        btn.click(bot.run_img, inputs=[gr_img], outputs=[ingredients_msg])
        gr.Examples(
            examples=sample_images,
            inputs=gr_img,
        )

        # The chat opens with the bot's scripted ingredients question.
        chatbot = gr.Chatbot(
            value=[(None, bot.ai_prompt_questions["ingredients"].prompt.template)]
        )
        msg = gr.Textbox()
        gr.Markdown("**🔃Refresh the page to start from scratch🔃**")
        # respond returns ("", history): clears the textbox and updates the chat.
        msg.submit(
            fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        )
        # TODO: a "Clear" button wired to bot.reset was sketched but is disabled.
    return demo