"""Gradio demo for StyleMatte.

Loads the ``stylematte.pth`` checkpoint into a ``StyleMatte`` model on CPU and
serves a web UI. The UI takes one image and shows two outputs: the mask
produced by ``inference_img``, and that mask multiplied with the input image.
"""
import gradio as gr
from test import inference_img
from models import *
import numpy as np
# Import torch explicitly instead of relying on `from models import *`
# happening to re-export it. This comes after the star import so the name is
# guaranteed to be the real torch module.
import torch

device = 'cpu'
model = StyleMatte()
model = model.to(device)
checkpoint = "stylematte.pth"
# map_location remaps tensors stored for another device (e.g. CUDA) onto
# `device`, so the checkpoint loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)
model.eval()


def predict(inp):
    """Run StyleMatte on one image and return ``[mask, fg]``.

    Args:
        inp: The image from the Gradio input component (``type="numpy"``).
            It is ``None`` when the user submits without an image.

    Returns:
        A two-element list ``[mask, fg]``. ``mask`` is whatever
        ``inference_img`` returns. ``fg`` is ``mask * inp`` converted to a
        ``uint8`` NumPy array. If ``inp`` is ``None``, returns
        ``[None, None]`` so both output panels are cleared.
    """
    # Gradio hands us None when Submit is pressed with an empty input; the
    # model cannot process that, so clear the outputs instead of crashing.
    if inp is None:
        return [None, None]
    print("***********Inference****************")
    # No gradients are needed for inference. This also ensures the result can
    # be converted with `.numpy()` below, which refuses grad-tracking tensors.
    with torch.no_grad():
        mask = inference_img(model, inp)
    inp_np = np.array(inp)
    # `.permute(1, 2, 0).numpy()` implies the product is a channel-first torch
    # tensor, reordered here to channel-last for display.
    # NOTE(review): presumably `mask` is a torch tensor that broadcasts
    # against the HxWxC input array -- confirm against inference_img.
    fg = np.uint8((mask * inp_np).permute(1, 2, 0).numpy())
    print("***********Inference finish****************")
    print("***********FG****************", fg.shape, mask.shape)
    return [mask, fg]


print("MODEL LOADED")
print("************************************")
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="numpy"), gr.Image(type="numpy")],
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")
iface.launch()