"""Gradio demo: interactive image segmentation with SAM2.

SAM2 is a promptable segmenter driven by *geometric* prompts (points, boxes,
masks); it has no text-prompt input. The prompt box therefore takes one or
more foreground points in pixel coordinates, e.g. ``120,80`` or
``120,80; 200,150``.
"""

from typing import Optional

import numpy as np
import torch
from PIL import Image
import gradio as gr
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Load the SAM2 model once at startup; every request reuses it.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")


def _parse_points(prompt: str) -> np.ndarray:
    """Parse ``"x,y; x,y"`` (``;`` or newline separated) into an (N, 2) float32 array.

    Raises:
        gr.Error: if a point is malformed or no point was given; Gradio shows
            the message to the user instead of a stack trace.
    """
    points = []
    for chunk in prompt.replace("\n", ";").split(";"):
        chunk = chunk.strip()
        if not chunk:
            continue
        try:
            x_str, y_str = chunk.split(",")
            points.append((float(x_str), float(y_str)))
        except ValueError as err:
            raise gr.Error(
                f"Could not parse point {chunk!r}; expected 'x,y' in pixels."
            ) from err
    if not points:
        raise gr.Error("Enter at least one point as 'x,y' (pixel coordinates).")
    return np.asarray(points, dtype=np.float32)


def generate_mask(image: Optional[Image.Image], prompt: str) -> np.ndarray:
    """Segment the object under the given point prompt(s) with SAM2.

    Args:
        image: The uploaded image (PIL, as configured on the input component),
            or ``None`` if nothing was uploaded.
        prompt: Foreground points as ``"x,y; x,y"`` in pixel coordinates.

    Returns:
        The highest-scoring mask as a uint8 array with values 0 or 255.

    Raises:
        gr.Error: if no image was uploaded or the prompt cannot be parsed.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    point_coords = _parse_points(prompt or "")
    # Label 1 marks every point as foreground.
    point_labels = np.ones(len(point_coords), dtype=np.int32)
    # A single point is ambiguous (part vs. whole object), so ask for several
    # candidates; with multiple points, one mask is the recommended setting.
    multimask = len(point_coords) == 1

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        predictor.set_image(image)
        masks, scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=multimask,
        )

    # Candidates are not ordered by quality; pick the best-scoring one.
    best = masks[int(np.argmax(scores))]
    return (best > 0).astype(np.uint8) * 255


# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Segmentation using SAM2")
    # Input: Upload an image
    image_input = gr.Image(label="Upload Image", type="pil")
    # Input: point prompt(s) for segmentation (SAM2 does not accept text)
    prompt_input = gr.Textbox(
        label="Point prompt (x,y pixel coordinates)",
        placeholder="e.g. 120,80  or  120,80; 200,150",
    )
    # Output: Display the mask generated by the SAM2 model
    output_mask = gr.Image(label="Generated Mask")
    # Button to trigger mask generation
    generate_button = gr.Button("Generate Mask")
    # Link button click with the segmentation function
    generate_button.click(
        fn=generate_mask,
        inputs=[image_input, prompt_input],
        outputs=output_mask,
    )

if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, not on import.
    demo.launch()