import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import moviepy.editor as mp
from pydub import AudioSegment
from PIL import Image, ImageColor
import numpy as np
import os
import tempfile
import uuid
import time
from concurrent.futures import ThreadPoolExecutor
import asyncio

torch.set_float32_matmul_precision("medium")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load both BiRefNet models: the full model, and the lite one used by "Fast Mode".
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
)
birefnet_lite.to(device)

# Model preprocessing: resize to a fixed 1024x1024 input, convert to a tensor,
# and normalize with the standard ImageNet channel mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def process_frame(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
    """Replace the background of one video frame (synchronous; safe to run in a worker thread).

    Args:
        frame: Frame array as produced by ``VideoFileClip.iter_frames``.
        bg_type: "Color", "Image" or "Video"; any other value leaves the frame unchanged.
        bg: Background image (PIL image or file path) used when ``bg_type == "Image"``.
        fast_mode: Use BiRefNet_lite instead of the full BiRefNet model.
        bg_frame_index: Index into ``background_frames`` (wrapped with modulo).
        background_frames: Background-video frames, used when ``bg_type == "Video"``.
        color: Background color string used when ``bg_type == "Color"``.

    Returns:
        ``(processed_frame_array, next_bg_frame_index)``. The index is incremented
        only when a background-video frame was consumed.
    """
    pil_image = Image.fromarray(frame)
    if bg_type == "Color":
        processed_image = process(pil_image, color, fast_mode)
    elif bg_type == "Image":
        processed_image = process(pil_image, bg, fast_mode)
    elif bg_type == "Video":
        background_frame = background_frames[bg_frame_index % len(background_frames)]
        bg_frame_index += 1
        processed_image = process(pil_image, Image.fromarray(background_frame), fast_mode)
    else:
        # Unknown background type: keep the original frame.
        processed_image = pil_image
    return np.array(processed_image), bg_frame_index


async def process_frame_async(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
    """Async wrapper around ``process_frame``.

    The blocking work is offloaded to a thread so the event loop is not stalled.
    Returns the same ``(array, next_bg_frame_index)`` tuple as ``process_frame``.
    """
    return await asyncio.to_thread(
        process_frame, frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color
    )


def _load_background_frames(bg_video, duration, fps, video_handling):
    """Decode the background video into frames spanning at least ``duration`` seconds.

    If the background is shorter than ``duration``, it is either slowed down to
    exactly ``duration`` ("slow_down") or repeated until long enough ("loop").
    """
    source = mp.VideoFileClip(bg_video)
    try:
        clip = source
        if source.duration < duration:
            if video_handling == "slow_down":
                # speedx's new duration is clip.duration / factor, so slowing down
                # must target the final duration rather than pass a factor > 1.
                clip = source.fx(mp.vfx.speedx, final_duration=duration)
            else:  # video_handling == "loop"
                clip = mp.concatenate_videoclips(
                    [source] * int(duration / source.duration + 1)
                )
        return list(clip.iter_frames(fps=fps))
    finally:
        # All frames are already decoded into memory, so the reader can be released.
        source.close()


@spaces.GPU
async def fn(
    vid,
    bg_type="Color",
    bg_image=None,
    bg_video=None,
    color="#00FF00",
    fps=0,
    video_handling="slow_down",
    fast_mode=True,
    max_workers=6,
):
    """Replace the background of every frame of ``vid`` and stream progress to the UI.

    Async generator yielding ``(stream_image, out_video, status_text)`` tuples that
    match the ``[stream_image, out_video, time_textbox]`` outputs. The last yield
    carries the final frame and the path of the written ``.mp4``.

    Args:
        vid: Path of the input video.
        bg_type: "Color", "Image" or "Video".
        bg_image: Background image file path (for "Image").
        bg_video: Background video file path (for "Video").
        color: Background color (for "Color").
        fps: Output frame rate; 0 inherits the input video's fps.
        video_handling: "slow_down" or "loop" when the background video is shorter.
        fast_mode: Use BiRefNet_lite.
        max_workers: Number of frames processed concurrently.

    Raises:
        gr.Error: If a required input is missing or the video has no frames.
    """
    start_time = time.time()  # Start the timer

    if not vid:
        raise gr.Error("Please upload an input video.")
    if bg_type == "Image" and not bg_image:
        raise gr.Error("Please upload a background image.")
    if bg_type == "Video" and not bg_video:
        raise gr.Error("Please upload a background video.")

    video = mp.VideoFileClip(vid)
    executor = None
    processed_video = None
    try:
        if not fps:
            fps = video.fps
        audio = video.audio
        frames = list(video.iter_frames(fps=fps))
        if not frames:
            raise gr.Error("The input video contains no frames.")

        yield gr.update(visible=True), gr.update(visible=False), f"Processing started... Elapsed time: 0 seconds"

        background = bg_image
        background_frames = None
        if bg_type == "Image":
            # Decode the background image once instead of once per frame.
            with Image.open(bg_image) as im:
                background = im.convert("RGBA")
        elif bg_type == "Video":
            background_frames = _load_background_frames(bg_video, video.duration, fps, video_handling)

        loop = asyncio.get_running_loop()
        # Honour the user's parallelism setting (the slider may deliver a float).
        executor = ThreadPoolExecutor(max_workers=max(1, int(max_workers)))
        # Submit every frame up front; the pool bounds how many run at once.
        # Frame i uses background frame i (wrapped inside process_frame).
        tasks = [
            loop.run_in_executor(
                executor, process_frame, frame, bg_type, background, fast_mode,
                index, background_frames, color,
            )
            for index, frame in enumerate(frames)
        ]

        processed_frames = []
        # Await in submission order so the output keeps the input frame order.
        for task in tasks:
            result, _ = await task
            processed_frames.append(result)
            elapsed_time = time.time() - start_time
            yield result, None, f"Processing frame {len(processed_frames)}... Elapsed time: {elapsed_time:.2f} seconds"

        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        processed_video = processed_video.set_audio(audio)
        # Only reserve a unique name here; write after the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_filepath = temp_file.name
        processed_video.write_videofile(temp_filepath, codec="libx264")

        elapsed_time = time.time() - start_time
        yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
        yield processed_frames[-1], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
    finally:
        if executor is not None:
            # Don't block the event loop; drop queued frames if the run was aborted.
            executor.shutdown(wait=False, cancel_futures=True)
        if processed_video is not None:
            processed_video.close()
        video.close()


def process(image, bg, fast_mode=False):
    """Composite the foreground of ``image`` over ``bg`` using a BiRefNet mask.

    Args:
        image: PIL image whose foreground is kept.
        bg: A color string ("#rgb", "#rrggbb", "rgb(...)" or "rgba(...)", made
            fully opaque), a PIL image, or a path/file object for ``Image.open``.
        fast_mode: Use BiRefNet_lite instead of BiRefNet.

    Returns:
        The composited PIL image (foreground where the predicted mask is high).
    """
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to(device)
    model = birefnet_lite if fast_mode else birefnet
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)

    if isinstance(bg, str) and bg.startswith(("#", "rgb(", "rgba(")):
        color_rgb = ImageColor.getrgb(bg)[:3]
        background = Image.new("RGBA", image_size, color_rgb + (255,))
    elif isinstance(bg, Image.Image):
        background = bg.convert("RGBA").resize(image_size)
    else:
        background = Image.open(bg).convert("RGBA").resize(image_size)

    return Image.composite(image, background, mask)


with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Video Background Remover & Changer\n### You can replace image background with any color, image or video.\nNOTE: As this Space is running on ZERO GPU it has limit. It can handle approx 200 frames at once. So, if you have a big video than use small chunks or Duplicate this space.")

    with gr.Row():
        in_video = gr.Video(label="Input Video", interactive=True)
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")

    submit_button = gr.Button("Change Background", interactive=True)

    with gr.Row():
        fps_slider = gr.Slider(
            minimum=0,
            maximum=60,
            step=1,
            value=0,
            label="Output FPS (0 will inherit the original fps value)",
            interactive=True,
        )
        bg_type = gr.Radio(["Color", "Image", "Video"], label="Background Type", value="Color", interactive=True)
        color_picker = gr.ColorPicker(label="Background Color", value="#00FF00", visible=True, interactive=True)
        bg_image = gr.Image(label="Background Image", type="filepath", visible=False, interactive=True)
        bg_video = gr.Video(label="Background Video", visible=False, interactive=True)
        with gr.Column(visible=False) as video_handling_options:
            video_handling_radio = gr.Radio(["slow_down", "loop"], label="Video Handling", value="slow_down", interactive=True)
        fast_mode_checkbox = gr.Checkbox(label="Fast Mode (Use BiRefNet_lite)", value=True, interactive=True)
        max_workers_slider = gr.Slider(
            minimum=1,
            maximum=32,
            step=1,
            value=6,
            label="Max Workers",
            info="Determines how many frames to process in parallel",
            interactive=True,
        )

    time_textbox = gr.Textbox(label="Time Elapsed", interactive=False)

    def update_visibility(bg_type):
        """Show only the inputs relevant to the selected background type.

        Returns visibility updates for (color_picker, bg_image, bg_video,
        video_handling_options).
        """
        return (
            gr.update(visible=bg_type == "Color"),
            gr.update(visible=bg_type == "Image"),
            gr.update(visible=bg_type == "Video"),
            gr.update(visible=bg_type == "Video"),
        )

    bg_type.change(
        update_visibility,
        inputs=bg_type,
        outputs=[color_picker, bg_image, bg_video, video_handling_options],
    )

    examples = gr.Examples(
        [
            ["rickroll-2sec.mp4", "Video", None, "background.mp4"],
            ["rickroll-2sec.mp4", "Image", "images.webp", None],
            ["rickroll-2sec.mp4", "Color", None, None],
        ],
        inputs=[in_video, bg_type, bg_image, bg_video],
        outputs=[stream_image, out_video, time_textbox],
        fn=fn,
        cache_examples=True,
        cache_mode="eager",
    )

    submit_button.click(
        fn,
        inputs=[in_video, bg_type, bg_image, bg_video, color_picker, fps_slider, video_handling_radio, fast_mode_checkbox, max_workers_slider],
        outputs=[stream_image, out_video, time_textbox],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)