add progress bar

- app.py +108 -30
- app_text.py +9 -1
app.py
CHANGED
@@ -77,7 +77,9 @@ def compute_ncut(
     min_dist=0.1,
     sampling_method="fps",
     metric="cosine",
+    progess_start=0.4,
 ):
+    progress = gr.Progress()
     logging_str = ""
 
     num_nodes = np.prod(features.shape[:-1])
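This commit threads a Gradio progress bar through compute_ncut and the handlers below, with progess_start marking where each stage's slice of the bar begins. For reference, a minimal self-contained sketch of the reporting pattern, using Gradio's documented keyword-default form (the commit instead instantiates gr.Progress() inside the function body); slow_fn and its sleep loop are placeholders, not the app's code:

```python
import time
import gradio as gr

# Minimal sketch of gr.Progress reporting: call the tracker with a float in
# [0, 1] plus a short description; Gradio renders it over the output component.
def slow_fn(text, progress=gr.Progress()):
    progress(0.0, desc="Starting")
    for i in range(10):
        time.sleep(0.1)  # stand-in for real work
        progress((i + 1) / 10, desc=f"Step {i + 1}/10")
    return text.upper()

demo = gr.Interface(slow_fn, gr.Textbox(), gr.Textbox())
demo.launch()
```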
@@ -88,6 +90,7 @@ def compute_ncut(
         logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
 
     start = time.time()
+    progress(progess_start+0.0, desc="NCut")
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
         num_sample=num_sample_ncut,
@@ -102,6 +105,7 @@ def compute_ncut(
     logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
 
     start = time.time()
+    progress(progess_start+0.01, desc="spectral-tSNE")
     _, rgb = eigenvector_to_rgb(
         eigvecs,
         method=embedding_method,
@@ -249,15 +253,34 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
     blended = (1 - opacity1) * image + opacity2 * heatmap
     return blended.astype(np.uint8)
 
-def make_cluster_plot(eigvecs, images, h=64, w=64):
+def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
+    progress = gr.Progress()
+    progress(progess_start, desc="Finding Clusters by FPS")
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    eigvecs = eigvecs.to(device)
     from ncut_pytorch.ncut_pytorch import farthest_point_sampling
     magnitude = torch.norm(eigvecs, dim=-1)
-    p = 0.
+    p = 0.8
     top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
-    num_samples =
+    num_samples = 300
+    if num_samples > top_p_idx.shape[0]:
+        num_samples = top_p_idx.shape[0]
     fps_idx = farthest_point_sampling(eigvecs[top_p_idx], num_samples)
     fps_idx = top_p_idx[fps_idx]
 
+    # fps round 2 on the heatmap
+    left = eigvecs[fps_idx, :].clone()
+    right = eigvecs.clone()
+    left = F.normalize(left, dim=-1)
+    right = F.normalize(right, dim=-1)
+    heatmap = left @ right.T
+    heatmap = F.normalize(heatmap, dim=-1)
+    num_samples = 80
+    if num_samples > fps_idx.shape[0]:
+        num_samples = fps_idx.shape[0]
+    r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
+    fps_idx = fps_idx[r2_fps_idx]
+
     # downsample to 256x256
     images = F.interpolate(images, (256, 256), mode="bilinear")
     images = images.cpu().numpy()
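make_cluster_plot now samples cluster exemplars with farthest point sampling twice: 300 candidates from the top-magnitude eigenvectors, then 80 finalists chosen by a second FPS pass over the normalized similarity heatmap. For readers unfamiliar with the routine imported from ncut_pytorch, here is a self-contained sketch of greedy FPS, assuming Euclidean distance; the library's actual implementation may differ:

```python
import torch

# Greedy farthest point sampling: repeatedly pick the point whose distance to
# the nearest already-selected point is largest, yielding well-spread samples.
def farthest_point_sampling_sketch(features: torch.Tensor, num_samples: int) -> torch.Tensor:
    n = features.shape[0]
    num_samples = min(num_samples, n)
    selected = torch.zeros(num_samples, dtype=torch.long)
    nearest = torch.full((n,), float("inf"))
    selected[0] = torch.randint(n, (1,))  # arbitrary seed point
    for i in range(1, num_samples):
        d = torch.cdist(features, features[selected[i - 1]].unsqueeze(0)).squeeze(1)
        nearest = torch.minimum(nearest, d)   # distance to closest selected point
        selected[i] = nearest.argmax()        # farthest point wins
    return selected

idx = farthest_point_sampling_sketch(torch.randn(1000, 16), 32)
```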
@@ -269,29 +292,57 @@ def make_cluster_plot(eigvecs, images, h=64, w=64):
     # sort the fps_idx by the mean of the heatmap
     fps_heatmaps = {}
     sort_values = []
+    top3_image_idx = {}
     for _, idx in enumerate(fps_idx):
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        eigvecs = eigvecs.to(device)
         heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
+
+        # def top_percentile(tensor, p=0.8, max_size=10000):
+        #     tensor = tensor.clone().flatten()
+        #     if tensor.shape[0] > max_size:
+        #         tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
+        #     return tensor.quantile(p)
+        # top_p = top_percentile(heatmap, p=0.5)
+        top_p = 0.5
+
         heatmap = heatmap.reshape(-1, h, w)
-        mask = (heatmap >
+        mask = (heatmap > top_p).float()
+        # take top 3 masks only
+        mask_sort_values = mask.mean((1, 2))
+        mask_sort_idx = torch.argsort(mask_sort_values, descending=True)
+        mask = mask[mask_sort_idx[:3]]
         sort_values.append(mask.mean().item())
-        fps_heatmaps[idx.item()] = heatmap.cpu()
+        # fps_heatmaps[idx.item()] = heatmap.cpu()
+        fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:3]].cpu()
+        top3_image_idx[idx.item()] = mask_sort_idx[:3]
+    # do the sorting
+    _sort_idx = torch.tensor(sort_values).argsort(descending=True)
+    fps_idx = fps_idx[_sort_idx]
+    # reverse the fps_idx
+    # fps_idx = fps_idx.flip(0)
+    # discard the big clusters
+    fps_idx = fps_idx[10:]
+    # shuffle the fps_idx
+    fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
 
     fig_images = []
     i_cluster = 0
-
+    num_plots = 10
+    plot_step_float = (1.0 - progess_start) / num_plots
+    for i_fig in range(num_plots):
+        progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
         fig, axs = plt.subplots(3, 5, figsize=(15, 9))
         for ax in axs.flatten():
             ax.axis("off")
         for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
             heatmap = fps_heatmaps[idx.item()]
-            mask = (heatmap > 0.1).float()
-            sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
+            # mask = (heatmap > 0.1).float()
+            # sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
             size = (images.shape[1], images.shape[2])
             heatmap = apply_reds_colormap(heatmap, size)
-            for i, image_idx in enumerate(sorted_image_idxs[:3]):
-                _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
+            # for i, image_idx in enumerate(sorted_image_idxs[:3]):
+            for i, image_idx in enumerate(top3_image_idx[idx.item()]):
+                # _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
+                _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
             axs[i, j].imshow(_heatmap)
             if i == 0:
                 axs[i, j].set_title(f"cluster {i_cluster+1}", fontsize=24)
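The rewritten loop scores each sampled node by how much image area its thresholded similarity heatmap covers, then keeps only the three best-covered images per cluster (top3_image_idx). A condensed sketch of that per-cluster step, with shapes assumed from the surrounding code (eigvecs flattened to (num_images * h * w, d)):

```python
import torch
import torch.nn.functional as F

# For one FPS-selected node `idx`: similarity heatmap over all patches,
# binarized at a fixed threshold, then the 3 images with the largest coverage.
def top3_cluster_heatmaps(eigvecs: torch.Tensor, idx: int, h: int, w: int, top_p: float = 0.5):
    heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
    heatmap = heatmap.reshape(-1, h, w)        # (num_images, h, w)
    mask = (heatmap > top_p).float()           # binary mask per image
    coverage = mask.mean(dim=(1, 2))           # covered pixel fraction per image
    top3 = torch.argsort(coverage, descending=True)[:3]
    return heatmap[top3], top3
```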
@@ -348,6 +399,9 @@ def ncut_run(
     lisa_prompt2="",
     lisa_prompt3="",
 ):
+    progress = gr.Progress()
+    progress(0.2, desc="Feature Extraction")
+
     logging_str = ""
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
@@ -396,12 +450,16 @@ def ncut_run(
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
 
+    progress(0.4, desc="NCut")
+
     if recursion:
         rgbs = []
         recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
         inp = features
+        progress_start = 0.4
         for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
             logging_str += f"Recursion #{i+1}\n"
+            progress_start += + 0.1 * i
             rgb, _logging_str, eigvecs = compute_ncut(
                 inp,
                 num_eig=n_eigs,
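The recursion branch advances progress_start before each level and passes it into compute_ncut, so each level reports into its own slice of one shared bar. An equivalent, more general formulation of that plumbing (illustration only, not the app's code):

```python
# Carve a parent progress callback into sub-ranges: `progress` is any
# callable accepting (fraction, desc=...), e.g. a gr.Progress instance.
def sub_progress(progress, start: float, width: float):
    def report(frac: float, desc: str = ""):
        progress(start + frac * width, desc=desc)
    return report

# e.g. recursion level i reports into [0.4 + 0.1*i, 0.5 + 0.1*i]:
# report = sub_progress(progress, 0.4 + 0.1 * i, 0.1)
```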
@@ -417,6 +475,7 @@ def ncut_run(
                 min_dist=min_dist,
                 sampling_method=sampling_method,
                 metric="cosine" if i == 0 else recursion_metric,
+                progess_start=progress_start,
             )
             logging_str += _logging_str
 
@@ -424,6 +483,7 @@ def ncut_run(
         if "AlignedThreeModelAttnNodes" == model_name:
             # dirty patch for the alignedcut paper
             start = time.time()
+            progress(progress_start + 0.09, desc=f"Plotting Recursion {i+1}")
             pil_images = []
             for i_image in range(rgb.shape[0]):
                 _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
@@ -442,6 +502,8 @@ def ncut_run(
     if old_school_ncut: # individual images
         logging_str += "Running NCut for each image independently\n"
         rgb = []
+        progress_start = 0.4
+        step_float = 0.6 / features.shape[0]
         for i_image in range(features.shape[0]):
             logging_str += f"Image #{i_image+1}\n"
             feature = features[i_image]
@@ -459,6 +521,7 @@ def ncut_run(
                 n_neighbors=n_neighbors,
                 min_dist=min_dist,
                 sampling_method=sampling_method,
+                progess_start=progress_start+step_float*i_image,
             )
             logging_str += _logging_str
             rgb.append(_rgb[0])
@@ -486,6 +549,7 @@ def ncut_run(
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
         start = time.time()
+        progress(0.6, desc="Plotting")
        pil_images = []
         for i_image in range(rgb.shape[0]):
             _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
@@ -506,15 +570,18 @@ def ncut_run(
 
     if not video_output:
         start = time.time()
+        progress_start = 0.6
+        progress(progress_start, desc="Plotting Clusters")
         h, w = features.shape[1], features.shape[2]
         if torch.cuda.is_available():
             images = images.cuda()
         _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
-        cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
+        cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start)
         logging_str += f"plot time: {time.time() - start:.2f}s\n"
 
 
     if video_output:
+        progress(0.8, desc="Saving Video")
         video_path = get_random_path()
         video_cache.add_video(video_path)
         pil_images_to_video(to_pil_images(rgb), video_path)
@@ -526,26 +593,26 @@
 
 def _ncut_run(*args, **kwargs):
     n_ret = kwargs.pop("n_ret", 1)
-
-
-
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
-
+        ret = ncut_run(*args, **kwargs)
 
-
-
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
-
-
-
-
-
-
-
-
-    ret = ncut_run(*args, **kwargs)
-    ret = list(ret)[:n_ret] + [ret[-1]]
-    return ret
+        ret = list(ret)[:n_ret] + [ret[-1]]
+        return ret
+    except Exception as e:
+        gr.Error(str(e))
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        return *(None for _ in range(n_ret)), "Error: " + str(e)
+
+    # ret = ncut_run(*args, **kwargs)
+    # ret = list(ret)[:n_ret] + [ret[-1]]
+    # return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=20)
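_ncut_run now clears the CUDA allocator cache before and after the run and converts exceptions into a returned error string rather than a crashed Space. One caveat: gr.Error is an exception class, so constructing it without raising displays nothing; the returned "Error: ..." string is what reaches the logging textbox. A generic sketch of the same guard (guarded_run is hypothetical):

```python
import torch

# Run `fn`, bracketing it with CUDA cache cleanup, and degrade to an error
# string on failure so the UI keeps working.
def guarded_run(fn, *args, n_ret=1, **kwargs):
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()          # release cached blocks up front
        ret = fn(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()          # and again after the heavy work
        return list(ret)[:n_ret] + [ret[-1]]  # trim outputs, keep the log string
    except Exception as e:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return (*(None for _ in range(n_ret)), f"Error: {e}")
```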
@@ -744,10 +811,15 @@ def run_fn(
     n_ret=1,
 ):
 
+    progress=gr.Progress()
+    progress(0, desc="Starting")
+
+
     if images is None:
         gr.Warning("No images selected.")
         return *(None for _ in range(n_ret)), "No images selected."
 
+    progress(0.05, desc="Processing Images")
     video_output = False
     if isinstance(images, str):
         images = extract_video_frames(images, max_frames=max_frames)
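run_fn now emits coarse milestones (0 Starting, 0.05 Processing Images, 0.1 Downloading Model below). When a handler already loops with tqdm internally, Gradio can mirror those bars instead of hand-placed calls; a sketch of that variant with a hypothetical handler:

```python
import gradio as gr

# track_tqdm=True makes Gradio reflect any tqdm loop run inside the handler.
def run(images, progress=gr.Progress(track_tqdm=True)):
    progress(0.05, desc="Processing Images")
    # ... feature extraction with tqdm loops would be mirrored in the UI ...
    return images
```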
@@ -767,6 +839,7 @@ def run_fn(
     images = [transform_image(image, resolution=resolution, stablediffusion=stablediffusion) for image in images]
     images = torch.stack(images)
 
+    progress(0.1, desc="Downloading Model")
 
     if is_lisa:
         import subprocess
@@ -976,10 +1049,13 @@ def make_dataset_images_section(advanced=False, is_random=False):
     def load_dataset_images(is_advanced, dataset_name, num_images=10,
                             is_filter=True, filter_by_class_text="0,1,2",
                             is_random=False, seed=1):
+        progress = gr.Progress()
+        progress(0, desc="Loading Images")
         if is_advanced == "Basic":
             gr.Info("Loaded images from Ego-Exo4D")
             return default_images
         try:
+            progress(0.5, desc="Downloading Dataset")
             dataset = load_dataset(dataset_name, trust_remote_code=True)
             key = list(dataset.keys())[0]
             dataset = dataset[key]
@@ -990,6 +1066,7 @@ def make_dataset_images_section(advanced=False, is_random=False):
             num_images = len(dataset)
 
         if is_filter:
+            progress(0.8, desc="Filtering Images")
             classes = [int(i) for i in filter_by_class_text.split(",")]
             labels = np.array(dataset['label'])
             unique_labels = np.unique(labels)
@@ -1193,6 +1270,7 @@ with demo:
     with gr.Column(scale=5, min_width=200):
         input_gallery, submit_button, clear_images_button = make_input_images_section()
         dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
+        num_images_slider.value = 30
         logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
 
     with gr.Column(scale=5, min_width=200):
app_text.py
CHANGED
@@ -150,6 +150,7 @@ def ncut_run(
     min_dist=0.1,
     sampling_method="fps",
 ):
+    progress = gr.Progress()
     logging_str = ""
     if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
         # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
@@ -163,6 +164,7 @@ def ncut_run(
 
     node_type = node_type.split(":")[0].strip()
 
+    progress(0.5, desc="Feature Extraction")
     model = model.to("cuda" if torch.cuda.is_available() else "cpu")
 
     start = time.time()
@@ -180,6 +182,7 @@ def ncut_run(
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
 
+    progress(0.6, desc="NCUT & spectral-tSNE")
     rgb, _logging_str, _ = compute_ncut(
         features,
         num_eig=num_eig,
@@ -197,6 +200,7 @@ def ncut_run(
     logging_str += _logging_str
 
     start = time.time()
+    progress(0.8, desc="Plotting")
     title = f"{model_name}, Layer {layer}, {node_type}"
     fig = make_plot(token_texts, rgb, title=title)
     logging_str += f"Plotting time: {time.time() - start:.2f}s\n"
@@ -223,6 +227,8 @@ else:
         return _ncut_run(*args, **kwargs)
 
 def real_run(model_name, text, layer, node_type, num_eig, affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method, num_sample_tsne, knn_tsne, perplexity, n_neighbors, min_dist, sampling_method):
+    progress = gr.Progress()
+    progress(0.1, desc="Downloading model")
     model = TEXT_MODEL_DICT[model_name]()
     return __ncut_run(model, text, model_name, layer, num_eig, node_type,
                       affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method,
@@ -251,7 +257,9 @@ def make_demo():
             clear_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
         with gr.Column(scale=5, min_width=200):
             gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
-
+            model_list = list(TEXT_MODEL_DICT.keys())
+            model_list = [model for model in model_list if model != "meta-llama/Meta-Llama-3-8B"]
+            model_name = gr.Dropdown(model_list, label="Model", value="meta-llama/Meta-Llama-3.1-8B")
             layer = gr.Slider(1, 32, step=1, value=32, label="Layer")
             node_type = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Node Type", value="block: sum of residual")
             num_eig = gr.Slider(minimum=1, maximum=1000, step=1, value=100, label="Number of Eigenvectors")
|