{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# app\n", "\n", "> Gradio app.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| default_exp app" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "import copy\n", "import os\n", "\n", "import gradio as gr\n", "from fastcore.utils import in_jupyter\n", "from langchain.chains import ConversationChain\n", "from langchain.chat_models import ChatOpenAI\n", "from langchain.memory import ConversationBufferMemory\n", "from langchain.prompts.chat import (\n", " ChatPromptTemplate,\n", " HumanMessagePromptTemplate,\n", " MessagesPlaceholder,\n", ")\n", "from PIL import Image\n", "\n", "import constants\n", "from lv_recipe_chatbot.engineer_prompt import INIT_PROMPT\n", "from lv_recipe_chatbot.ingredient_vision import (\n", " SAMPLE_IMG_DIR,\n", " BlipImageCaptioning,\n", " VeganIngredientFinder,\n", " format_image,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# | export\n", "\n", "\n", "class ConversationBot:\n", " def __init__(self, verbose=True):\n", " self.chat = ChatOpenAI(temperature=1, verbose=True)\n", " self.memory = ConversationBufferMemory(return_messages=True)\n", " self.init_prompt = copy.deepcopy(INIT_PROMPT)\n", " init_prompt_msgs = self.init_prompt.messages\n", " self.ai_prompt_questions = {\n", " \"ingredients\": init_prompt_msgs[1],\n", " \"allergies\": init_prompt_msgs[3],\n", " \"recipe_open_params\": init_prompt_msgs[5],\n", " }\n", " self.img_cap = BlipImageCaptioning(\"cpu\")\n", " self.vegan_ingred_finder = VeganIngredientFinder()\n", " self.verbose = verbose\n", "\n", " def respond(self, user_msg, chat_history):\n", " response = self._get_bot_response(user_msg, chat_history)\n", " chat_history.append((user_msg, response))\n", " return \"\", chat_history\n", "\n", " def init_conversation(self, formatted_chat_prompt):\n", " self.conversation = ConversationChain(\n", " llm=self.chat,\n", " memory=self.memory,\n", " prompt=formatted_chat_prompt,\n", " verbose=self.verbose,\n", " )\n", "\n", " def reset(self):\n", " self.memory.clear()\n", " self.init_prompt = copy.deepcopy(INIT_PROMPT)\n", "\n", " def run_img(self, image: str):\n", " desc = self.img_cap.inference(format_image(image))\n", " answer = self.vegan_ingred_finder.list_ingredients(image)\n", " msg = f\"\"\"I uploaded an image that may contain vegan ingredients.\n", "The description of the image is: `{desc}`.\n", "The extracted ingredients are:\n", "```\n", "{answer}\n", "```\n", "\"\"\"\n", " base_prompt = INIT_PROMPT.messages[2].prompt.template\n", " new_prompt = f\"{msg}I may type some more ingredients below.\\n{base_prompt}\"\n", " self.init_prompt.messages[2].prompt.template = new_prompt\n", " return msg\n", "\n", " def _get_bot_response(self, user_msg: str, chat_history) -> str:\n", " if len(chat_history) < 2:\n", " return self.ai_prompt_questions[\"allergies\"].prompt.template\n", "\n", " if len(chat_history) < 3:\n", " return self.ai_prompt_questions[\"recipe_open_params\"].prompt.template\n", "\n", " if len(chat_history) < 4:\n", " user = 0\n", " ai = 1\n", " user_msgs = [msg_pair[user] for msg_pair in chat_history[1:]]\n", " f_init_prompt = self.init_prompt.format_prompt(\n", " ingredients=user_msgs[0],\n", " allergies=user_msgs[1],\n", " recipe_freeform_input=user_msg,\n", " )\n", " chat_msgs = f_init_prompt.to_messages()\n", " results = self.chat.generate([chat_msgs])\n", " chat_msgs.extend(\n", " [\n", " results.generations[0][0].message,\n", " MessagesPlaceholder(variable_name=\"history\"),\n", " HumanMessagePromptTemplate.from_template(\"{input}\"),\n", " ]\n", " )\n", " open_prompt = ChatPromptTemplate.from_messages(chat_msgs)\n", " # prepare the open conversation chain from this point\n", " self.init_conversation(open_prompt)\n", " return results.generations[0][0].message.content\n", "\n", " response = self.conversation.predict(input=user_msg)\n", " return response\n", "\n", " def __del__(self):\n", " del self.vegan_ingred_finder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('/home/evylz/AnimalEquality/lv-recipe-chatbot/assets/images/vegan_ingredients')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(SAMPLE_IMG_DIR)\n", "SAMPLE_IMG_DIR" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from dotenv import load_dotenv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#: eval: false\n", "load_dotenv()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.19 s, sys: 1.47 s, total: 7.66 s\n", "Wall time: 4.68 s\n" ] } ], "source": [ "#| eval: false\n", "%time bot = ConversationBot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "I uploaded an image that may contain vegan ingredients.\n", "The description of the image is: `a refrigerator with food inside`.\n", "The extracted ingredients are:\n", "```\n", "cabbage lettuce onion\n", "apples\n", "rice\n", "plant-based milk\n", "```\n", "\n", "CPU times: user 56.7 s, sys: 63.6 ms, total: 56.8 s\n", "Wall time: 5.95 s\n" ] } ], "source": [ "#| eval: false\n", "%time print(bot.run_img(SAMPLE_IMG_DIR / \"veggie-fridge.jpeg\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "\n", "\n", "def create_demo(bot=ConversationBot):\n", " sample_images = []\n", " all_imgs = [f\"{SAMPLE_IMG_DIR}/{img}\" for img in os.listdir(SAMPLE_IMG_DIR)]\n", " for i, img in enumerate(all_imgs):\n", " if i in [\n", " 1,\n", " 2,\n", " 3,\n", " ]:\n", " sample_images.append(img)\n", " with gr.Blocks() as demo:\n", " gr_img = gr.Image(type=\"filepath\")\n", " btn = gr.Button(value=\"Submit image\")\n", " ingredients_msg = gr.Text(label=\"Ingredients from image\")\n", " btn.click(bot.run_img, inputs=[gr_img], outputs=[ingredients_msg])\n", " gr.Examples(\n", " examples=sample_images,\n", " inputs=gr_img,\n", " )\n", "\n", " chatbot = gr.Chatbot(\n", " value=[(None, bot.ai_prompt_questions[\"ingredients\"].prompt.template)]\n", " )\n", "\n", " msg = gr.Textbox()\n", " # clear = gr.Button(\"Clear\")\n", " gr.Markdown(\"**🔃Refresh the page to start from scratch🔃**\")\n", "\n", " msg.submit(\n", " fn=bot.respond, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False\n", " )\n", " # clear.click(lambda: None, None, chatbot, queue=False).then(bot.reset)\n", " return demo" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closing server running on port: 7860\n", "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "