haritsahm's picture
Update app.py
b88ad11
raw
history blame
2.99 kB
import json
import os
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from monai.bundle import ConfigParser
from utils import page_utils
# Read the bundle's inference configuration into a plain dict so it can be
# patched before it is handed to the MONAI ConfigParser.
with open("configs/inference.json") as f:
    inference_config = json.load(f)

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# * NOTE: device must be hardcoded, config file won't affect the device selection
inference_config["device"] = device

parser = ConfigParser()
parser.read_config(f=inference_config)
parser.read_meta(f="configs/metadata.json")

# Components resolved from the parsed configuration. The "dataloader" entry
# is left unparsed; the original call is kept below for reference.
inference = parser.get_parsed_content("inferer")
# loader = parser.get_parsed_content("dataloader")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")
# Half precision is opt-in through the USE_FP16 environment variable.
# Environment values are always strings, so the usual "off" spellings (and an
# unset or empty variable) must be mapped to False explicitly -- otherwise
# USE_FP16=0 would be truthy. Any other value enables FP16.
use_fp16 = os.environ.get('USE_FP16', '').strip().lower() not in (
    '', '0', 'false', 'no', 'off',
)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved from a GPU still loads on a CPU-only host.
state_dict = torch.load("models/model.pt", map_location=device)
network.load_state_dict(state_dict, strict=True)
network = network.to(device)
network.eval()

# FP16 is only applied when CUDA is available; on CPU the network stays FP32.
if use_fp16 and torch.cuda.is_available():
    network = network.half()
# RGB colour (0-255 per channel) used to render each segmentation label.
label2color = {
    0: (0, 0, 0),
    1: (225, 24, 69),  # RED
    2: (135, 233, 17),  # GREEN
    3: (0, 87, 233),  # BLUE
    4: (242, 202, 25),  # YELLOW
    5: (137, 49, 239),  # PURPLE
}

# PNG files from sample_data/, offered as clickable examples in the UI.
example_files = list(Path("sample_data").glob("*.png"))


def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as an RGB image using ``label2color``.

    Args:
        mask: 2-D array of labels (anything ``np.asarray`` can convert).
            Every value present must be a key of ``label2color``.

    Returns:
        A float array of shape ``(H, W, 3)`` with channel values scaled to
        ``[0, 1]``.

    Raises:
        ValueError: if ``mask`` is not 2-D.
        KeyError: if ``mask`` contains a label missing from ``label2color``.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-D, got shape {mask.shape}")

    # Look each distinct label up once, then broadcast the colours back onto
    # the pixels through the inverse index instead of looping over pixels.
    labels, inverse = np.unique(mask, return_inverse=True)
    palette = np.array(
        [label2color[label] for label in labels.tolist()], dtype=np.float64
    ).reshape(-1, 3) / 255

    # ravel(): the shape of `inverse` differs between NumPy versions.
    return palette[inverse.ravel()].reshape(mask.shape[0], mask.shape[1], 3)
def query_image(img):
    """Run the segmentation pipeline on one image and return an overlay.

    Args:
        img: the value Gradio passes for the image input; the interface is
            configured with ``type="filepath"``, so this is a file path that
            the ``preprocess`` transform is expected to load.

    Returns:
        A NumPy array: the 50/50 blend of the processed input image and the
        coloured type map, with its two leading axes swapped.
    """
    fp16_active = use_fp16 and torch.cuda.is_available()

    sample = preprocess({"image": img})
    sample['image'] = sample['image'].to(device)
    if fp16_active:
        sample['image'] = sample['image'].half()

    # The inferer is fed a batch of one, so add a leading batch axis.
    with torch.no_grad():
        pred = inference(sample['image'].unsqueeze(dim=0), network)

    # Strip the batch axis from every prediction output again.
    sample["pred"] = pred
    for head in list(sample["pred"]):
        sample["pred"][head] = sample["pred"][head].squeeze(dim=0)

    sample = postprocess(sample)
    overlay = visualize_instance_seg_mask(sample["type_map"].squeeze())

    # Combine image
    # NOTE(review): permute(1, 2, 0) assumes a channel-first 3-D image -- confirm.
    picture = sample["image"].permute(1, 2, 0).cpu().numpy()
    blended = picture * 0.5 + overlay * 0.5

    # Solve rotating problem: a left-right flip followed by a 90-degree
    # rotation, which together swap the first two axes.
    blended = np.fliplr(blended)
    return np.rot90(blended, k=1)
# Load the HTML snippet that is shown as the description of the demo page.
with open('index.html', encoding='utf-8') as f:
    html_content = f.read()

# Default Gradio theme recoloured with the Kalbe hue; primary buttons get a
# darker fill, a lighter hover fill and white text.
kalbe_theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)

# Single image in (passed to query_image as a file path), single image out.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    theme=kalbe_theme,
    description=html_content,
    examples=example_files,
)

# Enable the request queue with a maximum size of 10, then start the app.
demo.queue(max_size=10).launch()