Spaces:
Running
Running
fixing image2text
Browse files- image2text.py +18 -8
- requirements.txt +2 -1
image2text.py
CHANGED
@@ -4,7 +4,8 @@ from utils import text_encoder, image_encoder
|
|
4 |
from PIL import Image
|
5 |
from jax import numpy as jnp
|
6 |
import pandas as pd
|
7 |
-
|
|
|
8 |
|
9 |
def app():
|
10 |
st.title("From Image to Text")
|
@@ -17,23 +18,31 @@ def app():
|
|
17 |
image classification task!
|
18 |
|
19 |
π€ Italian mode on! π€
|
|
|
|
|
|
|
20 |
|
21 |
"""
|
22 |
)
|
23 |
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
26 |
)
|
27 |
|
|
|
28 |
MAX_CAP = 4
|
29 |
|
30 |
col1, col2 = st.beta_columns([3, 1])
|
31 |
|
32 |
with col2:
|
33 |
captions_count = st.selectbox(
|
34 |
-
"Number of labels", options=range(1, MAX_CAP + 1)
|
35 |
)
|
36 |
-
compute = st.button("
|
37 |
|
38 |
with col1:
|
39 |
captions = list()
|
@@ -43,7 +52,7 @@ def app():
|
|
43 |
if compute:
|
44 |
captions = [c for c in captions if c != ""]
|
45 |
|
46 |
-
if not captions or not
|
47 |
st.error("Please choose one image and at least one label")
|
48 |
else:
|
49 |
with st.spinner("Computing..."):
|
@@ -55,13 +64,14 @@ def app():
|
|
55 |
text_embeds.extend(text_encoder(c, model, tokenizer))
|
56 |
|
57 |
text_embeds = jnp.array(text_embeds)
|
|
|
58 |
|
59 |
-
image = Image.open(
|
60 |
transform = get_image_transform(model.config.vision_config.image_size)
|
61 |
image_embed = image_encoder(transform(image), model)
|
62 |
|
63 |
# we could have a softmax here
|
64 |
-
cos_similarities = jnp.matmul(image_embed, text_embeds.T)
|
65 |
|
66 |
chart_data = pd.Series(cos_similarities[0], index=captions)
|
67 |
|
|
|
4 |
from PIL import Image
|
5 |
from jax import numpy as jnp
|
6 |
import pandas as pd
|
7 |
+
import requests
|
8 |
+
import jax
|
9 |
|
10 |
def app():
|
11 |
st.title("From Image to Text")
|
|
|
18 |
image classification task!
|
19 |
|
20 |
π€ Italian mode on! π€
|
21 |
+
|
22 |
+
For example, try to write "cat" in the space for label1 and "dog" in the space for label2 and the run
|
23 |
+
"classify"!
|
24 |
|
25 |
"""
|
26 |
)
|
27 |
|
28 |
+
image_url = st.text_input(
|
29 |
+
|
30 |
+
"You can input the URL of an image",
|
31 |
+
|
32 |
+
value="https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/Domestic_Cat_Face_Shot.jpg/1280px-Domestic_Cat_Face_Shot.jpg",
|
33 |
+
|
34 |
)
|
35 |
|
36 |
+
|
37 |
MAX_CAP = 4
|
38 |
|
39 |
col1, col2 = st.beta_columns([3, 1])
|
40 |
|
41 |
with col2:
|
42 |
captions_count = st.selectbox(
|
43 |
+
"Number of labels", options=range(1, MAX_CAP + 1), index=1
|
44 |
)
|
45 |
+
compute = st.button("Classify")
|
46 |
|
47 |
with col1:
|
48 |
captions = list()
|
|
|
52 |
if compute:
|
53 |
captions = [c for c in captions if c != ""]
|
54 |
|
55 |
+
if not captions or not image_url:
|
56 |
st.error("Please choose one image and at least one label")
|
57 |
else:
|
58 |
with st.spinner("Computing..."):
|
|
|
64 |
text_embeds.extend(text_encoder(c, model, tokenizer))
|
65 |
|
66 |
text_embeds = jnp.array(text_embeds)
|
67 |
+
image_raw = requests.get(image_url, stream=True).raw
|
68 |
|
69 |
+
image = Image.open(image_raw).convert("RGB")
|
70 |
transform = get_image_transform(model.config.vision_config.image_size)
|
71 |
image_embed = image_encoder(transform(image), model)
|
72 |
|
73 |
# we could have a softmax here
|
74 |
+
cos_similarities = jax.nn.softmax(jnp.matmul(image_embed, text_embeds.T))
|
75 |
|
76 |
chart_data = pd.Series(cos_similarities[0], index=captions)
|
77 |
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ torch
|
|
5 |
torchvision
|
6 |
natsort
|
7 |
stqdm
|
8 |
-
pandas
|
|
|
|
5 |
torchvision
|
6 |
natsort
|
7 |
stqdm
|
8 |
+
pandas
|
9 |
+
requests
|