eliphatfs committed on
Commit 22e326e
1 Parent(s): 7c9515a
Files changed (1)
  1. app.py +55 -50
app.py CHANGED
@@ -69,11 +69,10 @@ def sq(kc, vc):
 
 
 def reset_3d_shape_input(key):
-    objaid_key = key + "_objaid"
+    # this is not working due to streamlit problems, don't use it
     model_key = key + "_model"
     npy_key = key + "_npy"
     swap_key = key + "_swap"
-    sq(objaid_key, "")
     sq(model_key, None)
     sq(npy_key, None)
     sq(swap_key, "Y is up (for most Objaverse shapes)")
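The `sq(kc, vc)` helper used for these resets is defined earlier in app.py; only its signature is visible in the hunk header. From its call sites it reads like a thin wrapper that writes a value into Streamlit's session state. A minimal sketch under that assumption:

```python
import streamlit as st

def sq(kc, vc):
    # Hypothetical reconstruction: store a value under a widget key in
    # session state so the next rerun picks it up as the widget's value.
    st.session_state[kc] = vc
```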
@@ -121,43 +120,40 @@ def image_examples(samples, ncols, return_key=None):
     return trigger
 
 
-def text_examples(samples):
-    return st.selectbox("Or pick an example", samples)
-
-
 def demo_classification():
-    load_data = misc_utils.input_3d_shape('cls')
-    cats = st.text_input("Custom Categories (64 max, separated with comma)")
-    cats = [a.strip() for a in cats.split(',')]
-    if len(cats) > 64:
-        st.error('Maximum 64 custom categories supported in the demo')
-        return
-    lvis_run = st.button("Run Classification on LVIS Categories")
-    custom_run = st.button("Run Classification on Custom Categories")
-    if lvis_run or auto_submit("clsauto"):
-        pc = load_data(prog)
-        col2 = misc_utils.render_pc(pc)
-        prog.progress(0.5, "Running Classification")
-        pred = classification.pred_lvis_sims(model_g14, pc)
-        with col2:
-            for i, (cat, sim) in zip(range(5), pred.items()):
-                st.text(cat)
-                st.caption("Similarity %.4f" % sim)
-        prog.progress(1.0, "Idle")
-    if custom_run:
-        pc = load_data(prog)
-        col2 = misc_utils.render_pc(pc)
-        prog.progress(0.5, "Computing Category Embeddings")
-        device = clip_model.device
-        tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76).to(device)
-        feats = clip_model.get_text_features(**tn).float().cpu()
-        prog.progress(0.5, "Running Classification")
-        pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
-        with col2:
-            for i, (cat, sim) in zip(range(5), pred.items()):
-                st.text(cat)
-                st.caption("Similarity %.4f" % sim)
-        prog.progress(1.0, "Idle")
+    with st.form("clsform"):
+        load_data = misc_utils.input_3d_shape('cls')
+        cats = st.text_input("Custom Categories (64 max, separated with comma)")
+        cats = [a.strip() for a in cats.split(',')]
+        if len(cats) > 64:
+            st.error('Maximum 64 custom categories supported in the demo')
+            return
+        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
+        custom_run = st.form_submit_button("Run Classification on Custom Categories")
+        if lvis_run or auto_submit("clsauto"):
+            pc = load_data(prog)
+            col2 = misc_utils.render_pc(pc)
+            prog.progress(0.5, "Running Classification")
+            pred = classification.pred_lvis_sims(model_g14, pc)
+            with col2:
+                for i, (cat, sim) in zip(range(5), pred.items()):
+                    st.text(cat)
+                    st.caption("Similarity %.4f" % sim)
+            prog.progress(1.0, "Idle")
+        if custom_run:
+            pc = load_data(prog)
+            col2 = misc_utils.render_pc(pc)
+            prog.progress(0.5, "Computing Category Embeddings")
+            device = clip_model.device
+            tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76).to(device)
+            feats = clip_model.get_text_features(**tn).float().cpu()
+            prog.progress(0.5, "Running Classification")
+            pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
+            with col2:
+                for i, (cat, sim) in zip(range(5), pred.items()):
+                    st.text(cat)
+                    st.caption("Similarity %.4f" % sim)
+            prog.progress(1.0, "Idle")
     if image_examples(samples_index.classification, 3):
         queue_auto_submit("clsauto")
 
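The core of this hunk is the move from `st.button` to a single `st.form`: inside a form, widget interactions do not trigger a Streamlit rerun until one of the form's submit buttons is pressed, and one form may expose several `st.form_submit_button`s, as the LVIS and custom-category runs do here. A standalone sketch of the pattern (not the app's actual code):

```python
import streamlit as st

# Minimal illustration of the st.form pattern adopted above: nothing
# inside the form causes a rerun until a submit button is pressed, and
# a single form can carry multiple submit buttons.
with st.form("sketch_form"):
    raw = st.text_input("Custom categories, comma-separated")
    run_a = st.form_submit_button("Run A")
    run_b = st.form_submit_button("Run B")

if run_a or run_b:
    cats = [c.strip() for c in raw.split(",") if c.strip()]
    st.write("Would run with categories:", cats)
```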
 
@@ -226,18 +222,25 @@ def demo_retrieval():
     with tab_text:
         with st.form("rtextform"):
             k = st.slider("Shapes to Retrieve", 1, 100, 16, key='rtext')
-            text = st.text_input("Input Text")
-            picked_sample = text_examples(samples_index.retrieval_texts)
-            if st.form_submit_button("Run with Text"):
+            text = st.text_input("Input Text", key="inputrtext")
+            if st.form_submit_button("Run with Text") or auto_submit("rtextauto"):
                 prog.progress(0.49, "Computing Embeddings")
                 device = clip_model.device
                 tn = clip_prep(
-                    text=[text or picked_sample], return_tensors='pt', truncation=True, max_length=76
+                    text=[text], return_tensors='pt', truncation=True, max_length=76
                 ).to(device)
                 enc = clip_model.get_text_features(**tn).float().cpu()
                 prog.progress(0.7, "Running Retrieval")
                 retrieval_results(retrieval.retrieve(enc, k))
                 prog.progress(1.0, "Idle")
+        picked_sample = st.selectbox("Examples", ["Select..."] + samples_index.retrieval_texts)
+        text_last_example = st.session_state.get('text_last_example', None)
+        if text_last_example is None:
+            st.session_state.text_last_example = picked_sample
+        elif text_last_example != picked_sample and picked_sample != "Select...":
+            st.session_state.text_last_example = picked_sample
+            sq("inputrtext", picked_sample)
+            queue_auto_submit("rtextauto")
 
     with tab_img:
         submit = False
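The example picker appended to the text tab depends on `sq`, `auto_submit`, and `queue_auto_submit`, which live elsewhere in app.py and are not part of this diff. One plausible reading, offered as an assumption rather than the real implementation: picking an example writes the sample into the text input's session-state key, arms a one-shot flag, and forces a rerun in which `auto_submit` reports the flag as if the submit button had been pressed.

```python
import streamlit as st

# Hypothetical reconstruction of the auto-submit pair used above;
# the real helpers in this repository may differ.
def queue_auto_submit(flag: str) -> None:
    st.session_state[flag] = True      # arm the one-shot flag
    st.experimental_rerun()            # restart the script so the flag is seen

def auto_submit(flag: str) -> bool:
    if st.session_state.get(flag):
        st.session_state[flag] = False  # fire exactly once
        return True
    return False
```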
@@ -246,19 +249,21 @@ def demo_retrieval():
             pic = st.file_uploader("Upload an Image", key='rimageinput')
             if st.form_submit_button("Run with Image"):
                 submit = True
+        results_container = st.container()
         sample_got = image_examples(samples_index.iret, 4, 'rimageinput')
         if sample_got:
             pic = sample_got
         if sample_got or submit:
             img = Image.open(pic)
-            st.image(img)
-            prog.progress(0.49, "Computing Embeddings")
-            device = clip_model.device
-            tn = clip_prep(images=[img], return_tensors="pt").to(device)
-            enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
-            prog.progress(0.7, "Running Retrieval")
-            retrieval_results(retrieval.retrieve(enc, k))
-            prog.progress(1.0, "Idle")
+            with results_container:
+                st.image(img)
+                prog.progress(0.49, "Computing Embeddings")
+                device = clip_model.device
+                tn = clip_prep(images=[img], return_tensors="pt").to(device)
+                enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
+                prog.progress(0.7, "Running Retrieval")
+                retrieval_results(retrieval.retrieve(enc, k))
+                prog.progress(1.0, "Idle")
 
     with tab_pc:
         with st.form("rpcform"):
 
  with st.form("rpcform"):