import os
import json

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor


def build_sam2(cfg, checkpoints):
    return build_sam2_video_predictor(cfg, checkpoints)


def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


# Add point prompts to a frame.
# ann_frame_idx: the frame index we interact with
# ann_obj_id: a unique id for each object we interact with (any integer works)
def add_new_points(predictor, inference_state, ann_frame_idx, ann_obj_id, points, labels):
    _, out_obj_ids, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )
    return out_obj_ids, out_mask_logits


# Collect the segmentation results for every frame.
def all_frames_masks(predictor, inference_state):
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments


def resize_mask_to_img(masks, target_width, target_height):
    frame_mask = []
    origin_size = masks[0][1].shape  # the 1 here is the object id
    # Each frame maps to its segmentation results, and a frame may contain
    # masks for several objects.
    for frame, objects_mask in masks.items():
        object_masks = list(objects_mask.values())
        if not object_masks:
            # The current frame contains no object.
            frame_mask.append(np.ones(origin_size, dtype=bool))
        else:
            # Take the union of all object masks in the current frame.
            union_mask = object_masks[0]
            for mask in object_masks[1:]:
                union_mask = np.logical_or(union_mask, mask)
            frame_mask.append(union_mask)
    resized_mask = []
    for mask in frame_mask:
        mask_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255)
        resized_mask_image = mask_image.resize((target_width, target_height), Image.NEAREST)
        resized_mask.append(np.array(resized_mask_image) > 0)
    return resized_mask


def save_mask(output_folder, mask):
    # Convert to a PIL Image ('L' means grayscale mode).
    binary_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255, 'L')
    os.makedirs(output_folder, exist_ok=True)  # make sure the output folder exists
    new_file_path = os.path.join(output_folder, "binary_mask.jpg")
    # Save the new image.
    binary_image.save(new_file_path)
    print(f"saved mask to {new_file_path}.")


# Run SAM2 to get the segmentation results for all frames.
def get_masks_from_sam2(dataset_name, scene_name, img_shape, h, w, target_ind):
    # Load the model (a raw string keeps the Windows path free of escape mangling).
    sam2_checkpoint = r"D:\XMU\mac\hujie\3D\DUST3RwithSAM2\dust3rWithSam2\SAM2\checkpoints\sam2_hiera_large.pt"
    model_cfg = "sam2_hiera_l.yaml"
    predictor = build_sam2(model_cfg, sam2_checkpoint)
    # Directory containing the video frames.
    video_dir = os.path.join("data", dataset_name, scene_name, "images_8")
    # Collect the frame images.
    frame_names = [
        p for p in sorted(os.listdir(video_dir))
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
    ]
    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)

    # Add point prompts to a single frame.
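    # A sketch of the prompts.json layout this loader assumes, inferred from
    # the parsing below (the scene name and coordinates are hypothetical):
    #   {
    #       "fern": {
    #           "points": [[200.0, 150.0], [320.0, 90.0]],  # (x, y) pixel clicks
    #           "labels": [1, 0]                            # 1 = positive, 0 = negative
    #       }
    #   }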
    # Read prompts.json.
    json_dir = os.path.join("data", dataset_name, "prompts.json")
    with open(json_dir, 'r') as file:
        data = json.load(file)
    # Parse the prompts.
    prompts = data[scene_name]
    points = np.array(prompts['points'], dtype=np.float32)
    labels = np.array(prompts['labels'], dtype=np.int32)
    out_obj_ids, out_mask_logits = add_new_points(predictor, inference_state, 0, 1, points, labels)

    # Run SAM2 to get the segmentation results for all frames.
    video_segments = all_frames_masks(predictor, inference_state)

    # Render the results every few frames.
    vis_frame_stride = 3
    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        plt.figure(figsize=(6, 4))
        plt.title(f"Frame {out_frame_idx}")
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
        if out_frame_idx == 0:
            # Show the point prompts on the annotated frame.
            show_points(points, labels, plt.gca())
        plt.axis('off')  # optional: hide the axes
        plt.show()

    # Save the SAM2 mask of the view at target_ind as the ground-truth mask,
    # used later to compute IoU and Acc.
    mask_dir = os.path.join("data", dataset_name, "masks", scene_name)
    save_mask(mask_dir, video_segments[target_ind][1])

    # Resize the SAM2 masks to the size DUST3R expects.
    resize_mask = resize_mask_to_img(video_segments, w, h)
    return resize_mask


def array_to_tensor_masks(masks_list):
    # Stack the list into a single ndarray of shape (n, H, W).
    masks_array = np.stack(masks_list)
    # Reshape it to (n, H*W).
    masks_array = masks_array.reshape(masks_array.shape[0], -1)
    # Convert to a bool tensor.
    masks_tensor = torch.tensor(masks_array, dtype=torch.bool)
    return masks_tensor
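

# A minimal usage sketch, assuming a hypothetical "llff"/"fern" layout under
# data/ with a matching entry in prompts.json; the resolution and view index
# below are placeholders, not values taken from this repo.
if __name__ == "__main__":
    h, w = 378, 504  # hypothetical DUST3R working resolution (height, width)
    masks = get_masks_from_sam2(
        dataset_name="llff",  # hypothetical dataset folder under data/
        scene_name="fern",    # hypothetical scene keyed in prompts.json
        img_shape=(h, w),
        h=h,
        w=w,
        target_ind=0,         # view whose mask is saved as the ground truth
    )
    # Flatten the per-frame masks into an (n, H*W) bool tensor, e.g. for
    # selecting per-pixel entries downstream.
    masks_tensor = array_to_tensor_masks(masks)
    print(masks_tensor.shape, masks_tensor.dtype)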