Spaces:
Runtime error
Runtime error
import gradio as gr | |
from test import inference_img | |
from models import * | |
import numpy as np | |
# Target device, used both for placing the model and for remapping the
# checkpoint's tensors while loading.
device = 'cpu'

# Build the network and move it onto the target device.
model = StyleMatte()
model = model.to(device)

# Checkpoint path is relative, so it resolves against the working directory.
checkpoint = "stylematte.pth"
# NOTE(review): `torch` is not imported explicitly in the visible imports;
# presumably it arrives through `from models import *` - confirm.
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)

# Presumably switches the network to inference behaviour before serving
# (standard `eval()` semantics) - confirm against StyleMatte.
model.eval()
def predict(inp):
    """Run the model on ``inp`` and return the mask and the masked image.

    Args:
        inp: Input image. It is passed to ``inference_img`` as-is and also
            converted with ``np.array`` for the masking step.

    Returns:
        A two-element list ``[mask, fg]``: ``mask`` is whatever
        ``inference_img`` returns, and ``fg`` is a ``uint8`` NumPy array
        built from ``mask * image`` after reordering axes with
        ``permute(1, 2, 0)``.
    """
    print("***********Inference****************")
    pred_mask = inference_img(model, inp)

    image_arr = np.array(inp)
    # NOTE(review): `.permute(...)` and `.numpy()` imply the product is a
    # torch tensor, presumably channel-first - confirm against inference_img.
    masked = pred_mask * image_arr
    masked_last_axis = masked.permute(1, 2, 0)
    foreground = np.uint8(masked_last_axis.numpy())

    print("***********Inference finish****************")
    print("***********FG****************", foreground.shape, pred_mask.shape)
    return [pred_mask, foreground]
print("MODEL LOADED")
print("************************************")

# One NumPy-typed image goes in; two NumPy-typed images come out, matching
# the two-element list returned by `predict`.
input_image = gr.Image(type="numpy")
output_images = [gr.Image(type="numpy"), gr.Image(type="numpy")]

iface = gr.Interface(
    fn=predict,
    inputs=input_image,
    outputs=output_images,
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")

# Start serving the interface.
iface.launch()