Spaces:
Running
on
T4
Running
on
T4
eliphatfs
commited on
Commit
•
22e326e
1
Parent(s):
7c9515a
Updates.
Browse files
app.py
CHANGED
@@ -69,11 +69,10 @@ def sq(kc, vc):
|
|
69 |
|
70 |
|
71 |
def reset_3d_shape_input(key):
|
72 |
-
|
73 |
model_key = key + "_model"
|
74 |
npy_key = key + "_npy"
|
75 |
swap_key = key + "_swap"
|
76 |
-
sq(objaid_key, "")
|
77 |
sq(model_key, None)
|
78 |
sq(npy_key, None)
|
79 |
sq(swap_key, "Y is up (for most Objaverse shapes)")
|
@@ -121,43 +120,40 @@ def image_examples(samples, ncols, return_key=None):
|
|
121 |
return trigger
|
122 |
|
123 |
|
124 |
-
def text_examples(samples):
|
125 |
-
return st.selectbox("Or pick an example", samples)
|
126 |
-
|
127 |
-
|
128 |
def demo_classification():
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
161 |
if image_examples(samples_index.classification, 3):
|
162 |
queue_auto_submit("clsauto")
|
163 |
|
@@ -226,18 +222,25 @@ def demo_retrieval():
|
|
226 |
with tab_text:
|
227 |
with st.form("rtextform"):
|
228 |
k = st.slider("Shapes to Retrieve", 1, 100, 16, key='rtext')
|
229 |
-
text = st.text_input("Input Text")
|
230 |
-
|
231 |
-
if st.form_submit_button("Run with Text"):
|
232 |
prog.progress(0.49, "Computing Embeddings")
|
233 |
device = clip_model.device
|
234 |
tn = clip_prep(
|
235 |
-
text=[text
|
236 |
).to(device)
|
237 |
enc = clip_model.get_text_features(**tn).float().cpu()
|
238 |
prog.progress(0.7, "Running Retrieval")
|
239 |
retrieval_results(retrieval.retrieve(enc, k))
|
240 |
prog.progress(1.0, "Idle")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
with tab_img:
|
243 |
submit = False
|
@@ -246,19 +249,21 @@ def demo_retrieval():
|
|
246 |
pic = st.file_uploader("Upload an Image", key='rimageinput')
|
247 |
if st.form_submit_button("Run with Image"):
|
248 |
submit = True
|
|
|
249 |
sample_got = image_examples(samples_index.iret, 4, 'rimageinput')
|
250 |
if sample_got:
|
251 |
pic = sample_got
|
252 |
if sample_got or submit:
|
253 |
img = Image.open(pic)
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
|
|
262 |
|
263 |
with tab_pc:
|
264 |
with st.form("rpcform"):
|
|
|
69 |
|
70 |
|
71 |
def reset_3d_shape_input(key):
|
72 |
+
# this is not working due to streamlit problems, don't use it
|
73 |
model_key = key + "_model"
|
74 |
npy_key = key + "_npy"
|
75 |
swap_key = key + "_swap"
|
|
|
76 |
sq(model_key, None)
|
77 |
sq(npy_key, None)
|
78 |
sq(swap_key, "Y is up (for most Objaverse shapes)")
|
|
|
120 |
return trigger
|
121 |
|
122 |
|
|
|
|
|
|
|
|
|
123 |
def demo_classification():
|
124 |
+
with st.form("clsform"):
|
125 |
+
load_data = misc_utils.input_3d_shape('cls')
|
126 |
+
cats = st.text_input("Custom Categories (64 max, separated with comma)")
|
127 |
+
cats = [a.strip() for a in cats.split(',')]
|
128 |
+
if len(cats) > 64:
|
129 |
+
st.error('Maximum 64 custom categories supported in the demo')
|
130 |
+
return
|
131 |
+
lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
|
132 |
+
custom_run = st.form_submit_button("Run Classification on Custom Categories")
|
133 |
+
if lvis_run or auto_submit("clsauto"):
|
134 |
+
pc = load_data(prog)
|
135 |
+
col2 = misc_utils.render_pc(pc)
|
136 |
+
prog.progress(0.5, "Running Classification")
|
137 |
+
pred = classification.pred_lvis_sims(model_g14, pc)
|
138 |
+
with col2:
|
139 |
+
for i, (cat, sim) in zip(range(5), pred.items()):
|
140 |
+
st.text(cat)
|
141 |
+
st.caption("Similarity %.4f" % sim)
|
142 |
+
prog.progress(1.0, "Idle")
|
143 |
+
if custom_run:
|
144 |
+
pc = load_data(prog)
|
145 |
+
col2 = misc_utils.render_pc(pc)
|
146 |
+
prog.progress(0.5, "Computing Category Embeddings")
|
147 |
+
device = clip_model.device
|
148 |
+
tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76).to(device)
|
149 |
+
feats = clip_model.get_text_features(**tn).float().cpu()
|
150 |
+
prog.progress(0.5, "Running Classification")
|
151 |
+
pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
|
152 |
+
with col2:
|
153 |
+
for i, (cat, sim) in zip(range(5), pred.items()):
|
154 |
+
st.text(cat)
|
155 |
+
st.caption("Similarity %.4f" % sim)
|
156 |
+
prog.progress(1.0, "Idle")
|
157 |
if image_examples(samples_index.classification, 3):
|
158 |
queue_auto_submit("clsauto")
|
159 |
|
|
|
222 |
with tab_text:
|
223 |
with st.form("rtextform"):
|
224 |
k = st.slider("Shapes to Retrieve", 1, 100, 16, key='rtext')
|
225 |
+
text = st.text_input("Input Text", key="inputrtext")
|
226 |
+
if st.form_submit_button("Run with Text") or auto_submit("rtextauto"):
|
|
|
227 |
prog.progress(0.49, "Computing Embeddings")
|
228 |
device = clip_model.device
|
229 |
tn = clip_prep(
|
230 |
+
text=[text], return_tensors='pt', truncation=True, max_length=76
|
231 |
).to(device)
|
232 |
enc = clip_model.get_text_features(**tn).float().cpu()
|
233 |
prog.progress(0.7, "Running Retrieval")
|
234 |
retrieval_results(retrieval.retrieve(enc, k))
|
235 |
prog.progress(1.0, "Idle")
|
236 |
+
picked_sample = st.selectbox("Examples", ["Select..."] + samples_index.retrieval_texts)
|
237 |
+
text_last_example = st.session_state.get('text_last_example', None)
|
238 |
+
if text_last_example is None:
|
239 |
+
st.session_state.text_last_example = picked_sample
|
240 |
+
elif text_last_example != picked_sample and picked_sample != "Select...":
|
241 |
+
st.session_state.text_last_example = picked_sample
|
242 |
+
sq("inputrtext", picked_sample)
|
243 |
+
queue_auto_submit("rtextauto")
|
244 |
|
245 |
with tab_img:
|
246 |
submit = False
|
|
|
249 |
pic = st.file_uploader("Upload an Image", key='rimageinput')
|
250 |
if st.form_submit_button("Run with Image"):
|
251 |
submit = True
|
252 |
+
results_container = st.container()
|
253 |
sample_got = image_examples(samples_index.iret, 4, 'rimageinput')
|
254 |
if sample_got:
|
255 |
pic = sample_got
|
256 |
if sample_got or submit:
|
257 |
img = Image.open(pic)
|
258 |
+
with results_container:
|
259 |
+
st.image(img)
|
260 |
+
prog.progress(0.49, "Computing Embeddings")
|
261 |
+
device = clip_model.device
|
262 |
+
tn = clip_prep(images=[img], return_tensors="pt").to(device)
|
263 |
+
enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
|
264 |
+
prog.progress(0.7, "Running Retrieval")
|
265 |
+
retrieval_results(retrieval.retrieve(enc, k))
|
266 |
+
prog.progress(1.0, "Idle")
|
267 |
|
268 |
with tab_pc:
|
269 |
with st.form("rpcform"):
|