# Hugging Face Space: fantos / nuking — app.py
# Author: arxivgpt kim — "Update app.py" (commit 63e5794, verified)
# NOTE: the Space page reported "Runtime error" at the time this was captured.
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
import os
import requests
from moviepy.editor import VideoFileClip
from moviepy.audio.AudioClip import AudioClip
def search_pexels_images(query):
    """Search the Pexels photo API and return a list of image URLs.

    Prefers the 'large2x' rendition of each photo, falling back to
    'large' and then 'original'. Returns an empty list on any network,
    HTTP, or decoding failure so the Gradio gallery degrades gracefully.
    """
    api_key = os.getenv("API_KEY")  # Pexels API key from the Space secrets
    try:
        # params= lets requests URL-encode the user query safely
        # (the original f-string broke on spaces/special characters),
        # and a timeout prevents the UI from hanging on a stalled request.
        response = requests.get(
            "https://api.pexels.com/v1/search",
            params={"query": query, "per_page": 80},
            headers={"Authorization": api_key},
            timeout=10,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError):
        return []

    images_urls = []
    for photo in data.get('photos', []):
        # photo.get avoids the KeyError the original hit when 'src' was
        # missing (its elif branches indexed photo['src'] unguarded).
        src = photo.get('src', {})
        for resolution in ('large2x', 'large', 'original'):
            if resolution in src:
                images_urls.append(src[resolution])
                break
    return images_urls
def show_search_results(query):
    """Gradio callback: forward the search box text to the Pexels search
    and hand the resulting URL list straight to the gallery component."""
    return search_pexels_images(query)
# --- Model setup: BRIA RMBG-1.4 background-matting network -----------------
# Weights are fetched from the Hugging Face Hub at import time; the model is
# moved to GPU when one is available, otherwise loaded onto CPU.
net=BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net=net.cuda()
else:
    # map_location forces CPU deserialization when no CUDA device exists
    net.load_state_dict(torch.load(model_path,map_location="cpu"))
# Inference-only: disable dropout/batch-norm training behavior
net.eval()
def resize_image(image, model_input_size=(1024, 1024)):
    """Convert *image* to RGB and resize it to the model's input size.

    Args:
        image: a PIL image in any mode.
        model_input_size: (width, height) target; defaults to the
            1024x1024 input RMBG-1.4 expects.

    Returns:
        A new RGB PIL image of ``model_input_size``.
    """
    image = image.convert('RGB')
    # Image.Resampling.BILINEAR: the bare Image.BILINEAR constant is
    # deprecated, and the rest of this file already uses Image.Resampling.
    image = image.resize(model_input_size, Image.Resampling.BILINEAR)
    return image
def process(image):
    """Remove the background from *image* using the RMBG-1.4 model.

    Args:
        image: either a numpy array (Gradio's default image payload) or a
            PIL image.

    Returns:
        An RGBA PIL image the size of the input, with the predicted
        foreground opaque and the background transparent.
    """
    # Normalize the input to a PIL image.
    if isinstance(image, np.ndarray):
        orig_image = Image.fromarray(image)
    else:
        orig_image = image
    w, h = orig_image.size
    image = resize_image(orig_image)

    # HWC uint8 -> 1x3xHxW float tensor in [0, 1], then mean/std normalize.
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference — no_grad avoids building an autograd graph (the original
    # tracked gradients on every request, wasting memory for no benefit).
    with torch.no_grad():
        result = net(im_tensor)

    # Post-process: upsample the mask back to the original resolution
    # (PIL size is (w, h); interpolate wants (h, w)).
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    # Min-max normalize; guard the constant-mask case, where the original
    # divided by zero and produced a NaN mask.
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        result = torch.zeros_like(result)

    # Mask tensor -> grayscale PIL image (.detach() replaces deprecated .data).
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # Composite: paste the original over a transparent canvas, using the
    # predicted mask as the alpha channel.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
def calculate_position(org_size, add_size, position):
    """Compute the (x, y) offset that places an overlay of ``add_size``
    inside a background of ``org_size`` according to the Korean position
    label (row: 상단/중앙/하단, column: 좌측/가운데/우측).

    Returns None for an unrecognized label, matching the original's
    fall-through behavior.
    """
    bg_w, bg_h = org_size
    fg_w, fg_h = add_size
    left, mid_x, right = 0, (bg_w - fg_w) // 2, bg_w - fg_w
    top, mid_y, bottom = 0, (bg_h - fg_h) // 2, bg_h - fg_h
    offsets = {
        "상단 좌측": (left, top),
        "상단 가운데": (mid_x, top),
        "상단 우측": (right, top),
        "중앙 좌측": (left, mid_y),
        "중앙 가운데": (mid_x, mid_y),
        "중앙 우측": (right, mid_y),
        "하단 좌측": (left, bottom),
        "하단 가운데": (mid_x, bottom),
        "하단 우측": (right, bottom),
    }
    return offsets.get(position)
def merge(org_image, add_image, scale, position, display_size):
    """Overlay *add_image* on *org_image* and resize the result for display.

    Args:
        org_image: background PIL image (RGBA from the Gradio component).
        add_image: foreground PIL image (RGBA, used as its own alpha mask).
        scale: foreground scale in percent (10-200 from the slider).
        position: Korean position label understood by calculate_position.
        display_size: free-text "WIDTHxHEIGHT" string from the textbox.

    Returns:
        The merged RGBA PIL image resized to the requested display size.
    """
    # Parse the user-typed size tolerantly: the original crashed on
    # harmless input like "1024 x 768" or "1024X768".
    display_width, display_height = (
        int(part.strip()) for part in display_size.lower().split('x')
    )

    # Scale the foreground.
    scale_percentage = scale / 100.0
    new_size = (int(add_image.width * scale_percentage), int(add_image.height * scale_percentage))
    add_image = add_image.resize(new_size, Image.Resampling.LANCZOS)

    # Composite the foreground at the requested anchor, using its own
    # alpha channel as the paste mask.
    offset = calculate_position(org_image.size, add_image.size, position)
    merged_image = Image.new("RGBA", org_image.size)
    merged_image.paste(org_image, (0, 0))
    merged_image.paste(add_image, offset, add_image)

    # Resize to the requested display dimensions.
    final_image = merged_image.resize((display_width, display_height), Image.Resampling.LANCZOS)
    return final_image
# --- Gradio UI: three tabs sharing the functions defined above -------------
with gr.Blocks() as demo:
    # Tab 1: background removal via the RMBG-1.4 model (process()).
    with gr.Tab("Background Removal"):
        with gr.Column():
            gr.Markdown("누끼따기의 왕 '누킹'(Nuking)")
            gr.HTML('''
                <p style="margin-bottom: 10px; font-size: 94%">
                This is a demo for BRIA RMBG 1.4 that using
                <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
                </p>
            ''')
            input_image = gr.Image(type="pil")
            output_image = gr.Image()
            process_button = gr.Button("Remove Background")
            process_button.click(fn=process, inputs=input_image, outputs=output_image)

    # Tab 2: composite a foreground (e.g. a cut-out) over a background.
    with gr.Tab("Merge"):
        with gr.Column():
            org_image = gr.Image(label="Background", type='pil', image_mode='RGBA', height=400)  # height tweak (example)
            add_image = gr.Image(label="Foreground", type='pil', image_mode='RGBA', height=400)  # height tweak (example)
            scale = gr.Slider(minimum=10, maximum=200, step=1, value=100, label="Scale of Foreground Image (%)")
            # Choices are the Korean position labels calculate_position expects.
            position = gr.Radio(choices=["중앙 가운데", "상단 좌측", "상단 가운데", "상단 우측", "중앙 좌측", "중앙 우측", "하단 좌측", "하단 가운데", "하단 우측"], value="중앙 가운데", label="Position of Foreground Image")
            display_size = gr.Textbox(value="1024x768", label="Display Size (Width x Height)")
            btn_merge = gr.Button("Merge Images")
            result_merge = gr.Image()
            btn_merge.click(
                fn=merge,
                inputs=[org_image, add_image, scale, position, display_size],
                outputs=result_merge,
            )

    # Tab 3: free image search against the Pexels API.
    with gr.TabItem("Image Search"):
        with gr.Column():
            gr.Markdown("### FREE Image Search")
            search_query = gr.Textbox(label="사진 검색")
            search_btn = gr.Button("검색")
            images_output = gr.Gallery(label="검색 결과 이미지")
            search_btn.click(
                fn=show_search_results,
                inputs=search_query,
                outputs=images_output
            )

demo.launch()