"""Gradio demo for the U-2-Net ``u2net_portrait`` model.

The network weights are loaded once at import time. Each request runs the
model on the uploaded image, saves the inverted, min-max normalised ``d1``
output as a PNG under ``prediction_dir`` and returns it as a PIL image.
"""
import glob
import os
import sys
import tempfile

import gradio as gr

# Anchor the bundled U-2-Net checkout to this file's directory, the same way
# the weights path below is resolved, so imports do not depend on the cwd.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_HERE, 'U-2-Net'))

from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from PIL import Image

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET  # full size version 173.6 MB
from model import U2NETP  # small version u2net 4.7 MB


def normPRED(d):
    """Min-max normalise tensor ``d`` to the range [0, 1].

    A constant tensor (max == min) would divide by zero and yield NaNs, so it
    is mapped to all zeros instead.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span


def save_output(image_name, pred, d_dir):
    """Resize ``pred`` to the size of ``image_name`` and save it as a PNG.

    Args:
        image_name: path of the source image. It is re-read only to get its
            height and width, and its basename without the final extension
            names the output file.
        pred: prediction tensor. It is squeezed and multiplied by 255, so
            values are presumably in [0, 1]; confirm with the caller.
        d_dir: directory the PNG is written to.

    Returns:
        Path of the written ``.png`` file.
    """
    predict_np = pred.squeeze().cpu().data.numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    image = io.imread(image_name)
    # PIL's resize takes (width, height); the skimage array is (rows, cols, ...).
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # splitext strips only the final extension ("a.b.jpg" -> "a.b") and also
    # handles names that have no extension at all.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    out_path = os.path.join(d_dir, stem + '.png')
    imo.save(out_path)
    return out_path


# --------- 1. get image path and name ---------
model_name = 'u2net_portrait'  # u2netp
image_dir = 'portrait_im'
prediction_dir = 'portrait_results'
os.makedirs(prediction_dir, exist_ok=True)

model_dir = os.path.join(_HERE, 'U-2-Net/saved_models/u2net_portrait/u2net_portrait.pth')

# --------- 3. model define ---------
# Inference runs on CPU: the weights are mapped to CPU and the model is never moved.
print("...load U2NET---173.6 MB")
net = U2NET(3, 1)
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval()


def _input_to_path(im):
    """Return ``(path, is_temp)`` for a Gradio input value ``im``.

    Accepts a path string, a file-like object exposing ``.name``, or a PIL
    image (what ``gr.inputs.Image(type="pil")`` delivers). A PIL image is
    written to a temporary PNG, and ``is_temp`` tells the caller to delete it.

    Raises:
        ValueError: if ``im`` is None.
    """
    if im is None:
        raise ValueError("No input image was provided.")
    if isinstance(im, str):
        return im, False
    if isinstance(im, Image.Image):
        fd, path = tempfile.mkstemp(suffix='.png')
        os.close(fd)
        # The network is built with 3 input channels (U2NET(3, 1)), so
        # normalise RGBA / greyscale uploads to RGB.
        im.convert('RGB').save(path)
        return path, True
    return im.name, False


def _predict(img_name_list):
    """Run ``net`` over each image path and return the saved output paths."""
    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(512),
                                                                      ToTensorLab(flag=0)]))
    # num_workers=0 loads in-process; a worker process per request buys
    # nothing when handling a single uploaded image.
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=0)

    results = []
    # --------- 4. inference for each image ---------
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
            inputs_test = data_test['image'].type(torch.FloatTensor)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # Invert the first side output, then rescale it to [0, 1].
            pred = normPRED(1.0 - d1[:, 0, :, :])

            # save results to the prediction_dir folder
            results.append(save_output(img_name_list[i_test], pred, prediction_dir))
            del d1, d2, d3, d4, d5, d6, d7
    return results


def process(im):
    """Gradio handler: run U2NET on ``im`` and return the prediction image.

    Args:
        im: PIL image from the Gradio input, a file-like object with
            ``.name``, or a path / glob pattern string.

    Returns:
        PIL image opened from the first saved prediction PNG.

    Raises:
        ValueError: if no input was given or no image file was found.
    """
    path, is_temp = _input_to_path(im)
    try:
        # An existing literal path is used as-is, so names containing glob
        # metacharacters ("[", "*", "?") are not treated as patterns.
        img_name_list = [path] if os.path.isfile(path) else glob.glob(path)
        print("Number of images: ", len(img_name_list))
        if not img_name_list:
            raise ValueError("No image found at %r" % (path,))
        results = _predict(img_name_list)
    finally:
        if is_temp:
            os.remove(path)
    print(results)
    return Image.open(results[0])


title = "U-2-Net"
description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
article = ""

gr.Interface(
    process,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[],
    allow_flagging=False,
    allow_screenshot=False
).launch(enable_queue=True, cache_examples=True)