# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_ingredient_vision.ipynb.

# %% auto 0
__all__ = ['SAMPLE_IMG_DIR', 'format_image', 'BlipImageCaptioning', 'BlipVQA', 'VeganIngredientFinder']

# %% ../nbs/03_ingredient_vision.ipynb 3
from pathlib import Path

import constants
import numpy as np
import torch
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
)

# %% ../nbs/03_ingredient_vision.ipynb 7
def format_image(image) -> Image.Image:
    """Open an image and resize it, preserving aspect ratio, so that it fits
    inside a 512x512 box and both dimensions are multiples of 64."""
    img = Image.open(image)
    width, height = img.size
    # Scale factor that fits the image inside a 512x512 box.
    ratio = min(512 / width, 512 / height)
    width_new, height_new = (round(width * ratio), round(height * ratio))
    # Snap both dimensions to the nearest multiple of 64, as expected by
    # many vision models.
    width_new = int(np.round(width_new / 64.0)) * 64
    height_new = int(np.round(height_new / 64.0)) * 64
    img = img.resize((width_new, height_new))
    img = img.convert("RGB")
    return img
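
# A minimal usage sketch for `format_image` (illustrative; "sample.jpg" is a
# hypothetical path):
#
#     img = format_image("sample.jpg")
#     img.size  # both dimensions are multiples of 64, at most 512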

# %% ../nbs/03_ingredient_vision.ipynb 8
class BlipImageCaptioning:
    """
    Useful when you want to know what is inside the photo.
    """

    def __init__(self, device: str):
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    def inference(self, image: Image.Image) -> str:
        inputs = self.processor(image, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs, max_new_tokens=50)
        caption = self.processor.decode(out[0], skip_special_tokens=True)
        return caption
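
# A minimal usage sketch for BlipImageCaptioning (illustrative; downloads the
# BLIP captioning weights from the Hugging Face Hub on first use):
#
#     captioner = BlipImageCaptioning(device="cpu")
#     img = format_image("sample.jpg")  # hypothetical path
#     print(captioner.inference(img))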

# %% ../nbs/03_ingredient_vision.ipynb 9
class BlipVQA:
    """
                                                BLIP Visual Question Answering
                                                Useful when you need an answer for a question based on an image.
                                        Examples:
    what is the background color of this image, how many cats are in this figure, what is in this figure?
    """

    def __init__(self, device: str):
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
        self.model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    def inference(self, image: Image.Image, question: str) -> str:
        image = image.convert("RGB")
        inputs = self.processor(image, question, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs, max_new_tokens=100)
        answer = self.processor.decode(out[0], skip_special_tokens=True)
        return answer
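
# A minimal usage sketch for BlipVQA (illustrative; downloads the BLIP VQA
# weights from the Hugging Face Hub on first use):
#
#     vqa = BlipVQA(device="cpu")
#     img = format_image("sample.jpg")  # hypothetical path
#     print(vqa.inference(img, "What is in this figure?"))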

# %% ../nbs/03_ingredient_vision.ipynb 10
SAMPLE_IMG_DIR = Path(constants.ROOT_DIR) / "assets" / "images" / "vegan_ingredients"

# %% ../nbs/03_ingredient_vision.ipynb 17
class VeganIngredientFinder:
    """Lists vegan ingredients visible in an image using BLIP VQA."""

    def __init__(self):
        self.vqa = BlipVQA("cpu")

    def list_ingredients(self, img: str) -> str:
        image = format_image(img)
        answer = self.vqa.inference(
            image, "What are three of the vegetables seen in the image if any?"
        )
        answer += "\n" + self.vqa.inference(
            image, "What are three of the fruits seen in the image if any?"
        )
        answer += "\n" + self.vqa.inference(
            image, "What grains and starches are in the image if any?"
        )
        if (
            "yes"
            in self.vqa.inference(
                image, "Is there plant-based milk in the image?"
            ).lower()
        ):
            answer += "\n" + "plant-based milk"
        return answer
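
# A minimal usage sketch for VeganIngredientFinder (illustrative; "tofu.png"
# is a hypothetical file under SAMPLE_IMG_DIR):
#
#     finder = VeganIngredientFinder()
#     print(finder.list_ingredients(str(SAMPLE_IMG_DIR / "tofu.png")))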