theaTRON / src /process.py
mikonvergence's picture
main src files
aca81a2
raw
history blame contribute delete
852 Bytes
import gradio as gr
import cv2
from PIL import Image
import numpy as np
import torch
from .detection import *
from .masking import *
from .synthesis import *
def forward(image_cam, image_upload, prompt="", n_prompt=None, num_steps=20, seed=0, original_resolution=False):
    """Run the detect -> segment -> mask -> synthesise pipeline on one image.

    Args:
        image_cam: Image from the camera input, or None. When not None it
            takes priority over ``image_upload``.
        image_upload: Image used when ``image_cam`` is None.
        prompt: Forwarded unchanged to ``synthesis``.
        n_prompt: Forwarded unchanged to ``synthesis``
            (presumably a negative prompt -- confirm against ``synthesis``).
        num_steps: Forwarded to ``synthesis`` as ``num_steps``.
        seed: Forwarded to ``synthesis`` as ``seed``.
        original_resolution: When False (default) the image is rescaled to a
            height of 512 pixels, keeping the aspect ratio, before processing.

    Returns:
        Whatever ``synthesis`` returns for the (possibly resized) image and
        the mask built from the detected faces and hair.

    Raises:
        ValueError: If both ``image_cam`` and ``image_upload`` are None.
    """
    # The camera capture wins over the uploaded file when both are present.
    image = image_upload if image_cam is None else image_cam
    if image is None:
        # Fail early with a clear message instead of an AttributeError on
        # ``image.size`` (or inside the detector) further down.
        raise ValueError("No input image: provide a camera capture or an upload.")

    if not original_resolution:
        target_height = 512
        # Assumes a PIL-style image: ``size`` is (width, height) and
        # ``resize`` takes a (width, height) tuple.
        w, h = image.size
        # The height is set to the target directly: int(h * (512 / h)) can
        # truncate to 511 because of floating-point rounding (e.g. h == 49).
        # The width uses a single division for the same reason and is kept
        # at >= 1 so extremely narrow inputs do not produce an empty image.
        new_size = max(1, int(w * target_height / h)), target_height
        image = image.resize(new_size)

    # detect face
    dets = detect_face(image)

    # segment hair and face
    faces, hairs = process_face(dets)

    # build mask
    mask = build_mask_multi(image, faces, hairs)

    # synthesise
    new_image = synthesis(image, mask, prompt, n_prompt, num_steps=num_steps, seed=seed)

    return new_image