Reformat the overview by adding HTML and CSS files, update app.py, and add a utils folder for the color scheme
9acc5c1
import json | |
import os | |
from pathlib import Path | |
import gradio as gr | |
import numpy as np | |
import torch | |
from monai.bundle import ConfigParser | |
from utils import page_utils | |
# Load the bundle's inference configuration. JSON text is UTF-8 (RFC 8259), so
# decode it explicitly rather than relying on the platform's locale encoding.
with open("configs/inference.json", encoding="utf-8") as f:
    inference_config = json.load(f)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# * NOTE: device must be hardcoded, config file won't affect the device selection
inference_config["device"] = device
# Build the MONAI bundle parser from the in-memory inference config (which
# carries the runtime-selected device) plus the bundle's metadata file.
parser = ConfigParser()
parser.read_config(f=inference_config)
parser.read_meta(f="configs/metadata.json")

# Instantiate the components this app uses, in exactly this section order.
# The bundle's "dataloader" section is intentionally not instantiated.
inference, network, preprocess, postprocess = (
    parser.get_parsed_content(section)
    for section in ("inferer", "network_def", "preprocessing", "postprocessing")
)
# Optional half-precision inference, toggled by the USE_FP16 environment
# variable. Environment values are always strings, so a plain truthiness test
# would treat "0" or "false" as *enabled*. Explicit false-like values (and an
# unset/empty variable) disable fp16; any other non-empty value enables it.
_FP16_DISABLED_VALUES = {"", "0", "false", "no", "off"}
use_fp16 = os.environ.get("USE_FP16", "").strip().lower() not in _FP16_DISABLED_VALUES

# Load the weights onto the CPU first (the network has not been moved yet), so
# a checkpoint saved from a GPU also loads on a CPU-only host; the model is
# moved to `device` afterwards.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)
network = network.to(device)
network.eval()
# fp16 is only applied when CUDA is available; on CPU the model keeps the
# dtype it was loaded with.
if use_fp16 and torch.cuda.is_available():
    network = network.half()
# RGB colour (0-255 per channel) drawn for each class label of the predicted
# type map; label 0 is rendered black.
label2color = {
    0: (0, 0, 0),  # BLACK
    1: (225, 24, 69),  # RED
    2: (135, 233, 17),  # GREEN
    3: (0, 87, 233),  # BLUE
    4: (242, 202, 25),  # YELLOW
    5: (137, 49, 239),  # PURPLE
}
# Example inputs offered by the demo. Path.glob yields entries in arbitrary
# filesystem order, so sort them to keep the example order stable across hosts.
example_files = sorted(Path("sample_data").glob("*.png"))
def visualize_instance_seg_mask(mask, palette=None):
    """Render a 2-D label mask as an RGB float image scaled by 1/255.

    Every pixel's label is looked up in ``palette`` and the resulting RGB
    triple is divided by 255, so 0-255 colours map into ``[0, 1]``.

    Args:
        mask: 2-D array-like of class labels (anything ``np.asarray``
            accepts, e.g. a NumPy array).
        palette: optional mapping from label to an ``(r, g, b)`` tuple in
            0-255. Defaults to the module-level ``label2color``.

    Returns:
        ``float64`` NumPy array of shape ``(H, W, 3)``.

    Raises:
        ValueError: if ``mask`` is not 2-D.
        KeyError: if ``mask`` contains a label missing from ``palette``.
    """
    if palette is None:
        palette = label2color
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D label mask, got shape {mask.shape}")

    image = np.zeros((*mask.shape, 3))
    # Paint one whole label region at a time with a boolean mask instead of
    # visiting every pixel from Python (the old O(H*W) interpreter loop).
    for label in np.unique(mask).tolist():
        if label not in palette:
            raise KeyError(label)
        image[mask == label] = palette[label]
    return image / 255
def query_image(img):
    """Run the bundle on one image and return a colour-coded overlay.

    Args:
        img: value produced by the Gradio ``Image`` input (configured with
            ``type="filepath"``); it is handed to ``preprocess`` under the
            ``"image"`` key.

    Returns:
        NumPy array: the postprocessed ``"image"`` (treated as channel-first
        and permuted to channel-last) blended 50/50 with the colourised
        ``"type_map"`` prediction, with its first two axes swapped to fix the
        displayed orientation.
    """
    sample = preprocess({"image": img})
    model_input = sample["image"].to(device)
    if use_fp16 and torch.cuda.is_available():
        model_input = model_input.half()
    sample["image"] = model_input

    # The inferer is called with a leading batch axis of size 1 ...
    with torch.no_grad():
        sample["pred"] = inference(model_input.unsqueeze(dim=0), network)
    # ... which is stripped again from every entry of the prediction dict
    # (updated in place) before postprocessing.
    pred_dict = sample["pred"]
    for key in list(pred_dict):
        pred_dict[key] = pred_dict[key].squeeze(dim=0)
    sample = postprocess(sample)

    colour_mask = visualize_instance_seg_mask(sample["type_map"].squeeze())
    # 50/50 blend of the channel-last input image and the colour mask.
    image_hwc = sample["image"].permute(1, 2, 0).cpu().numpy()
    overlay = image_hwc * 0.5 + colour_mask * 0.5
    # Swap the first two axes -- exactly equivalent to mirroring left-right and
    # then rotating 90 degrees. NOTE(review): presumably the image loader
    # yields (W, H)-ordered arrays; confirm against the preprocessing config.
    return np.swapaxes(overlay, 0, 1)
# Load the HTML overview (index.html) used as the Gradio interface description.
html_content = Path("index.html").read_text(encoding="utf-8")
# Default Gradio theme with both hues set to page_utils.KALBE_THEME_COLOR and
# primary buttons filled with primary shade 600 (500 on hover), white text.
kalbe_theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)

# Single-image demo: the uploaded file's path goes to query_image and the
# returned array is displayed; the sample_data PNGs are offered as examples.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    theme=kalbe_theme,
    description=html_content,
    examples=example_files,
)

# Enable request queueing and start the server.
# NOTE(review): `concurrency_count` is a Gradio 3.x `queue()` argument (number
# of queue workers); Gradio 4 replaced it with `default_concurrency_limit` --
# confirm the pinned Gradio version before upgrading.
demo.queue(concurrency_count=20).launch()