LKCell / app.py
xiazhi1
initial commit
aea73e2
raw
history blame
No virus
4.67 kB
import gradio as gr
import os, requests
import numpy as np
import torch
import cv2
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViTParser,InferenceCellViT
from cell_segmentation.inference.inference_cellvit_experiment_monuseg import InferenceCellViTMoNuSegParser,MoNuSegInference
## Run mode: "local" or "remote". In remote mode the checkpoint and the
## example images are fetched from the Hugging Face Hub at startup.
RUN_MODE = "remote"


def _download_if_missing(url, dest=None, chunk_size=1 << 20, timeout=60):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    ``dest`` defaults to the last path component of ``url`` in the current
    working directory (the same file name ``wget`` would use). Skipping an
    existing file avoids re-fetching the checkpoint on every restart; plain
    ``wget`` would instead create ``name.1``, ``name.2``, ... copies.

    The payload is streamed into ``dest + ".part"`` and renamed to ``dest``
    only after it has been written completely. An interrupted download
    therefore never leaves a truncated file that a later run would skip.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Target path; defaults to ``os.path.basename(url)``.
        chunk_size: Bytes per streamed chunk.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Returns:
        The path the file is available at.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    if dest is None:
        dest = os.path.basename(url)
    if os.path.exists(dest):
        return dest
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly here instead of later, when a missing asset is first used.
        response.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in response.iter_content(chunk_size=chunk_size):
                fh.write(chunk)
    os.replace(tmp_path, dest)
    return dest


_HUB_BASE_URL = "https://huggingface.co/xiazhi/LKCell-demo/resolve/main/"

if RUN_MODE != "local":
    _download_if_missing(_HUB_BASE_URL + "model_best.pth")
    ## examples
    for _example_name in ("1.png", "2.png", "3.png", "4.png"):
        _download_if_missing(_HUB_BASE_URL + _example_name)
## step 1: set up the models
# The PanNuke model below is moved onto this device before serving requests.
device = "cpu"

## PanNuke: build the inference wrapper from the parsed options, then restore
## the trained weights from the checkpoint stored under its run directory.
pannuke_parser = InferenceCellViTParser()
pannuke_configurations = pannuke_parser.parse_arguments()
pannuke_inf = InferenceCellViT(
    **{
        option: pannuke_configurations[option]
        for option in ("run_dir", "checkpoint_name", "gpu", "magnification")
    }
)
pannuke_checkpoint_path = pannuke_inf.run_dir / pannuke_inf.checkpoint_name
# Load onto the CPU first so the checkpoint does not require the device it was saved from.
pannuke_checkpoint = torch.load(pannuke_checkpoint_path, map_location="cpu")
# The checkpoint records which architecture it belongs to under "arch".
pannuke_model = pannuke_inf.get_model(model_type=pannuke_checkpoint["arch"])
pannuke_model.load_state_dict(pannuke_checkpoint["model_state_dict"])
# Place on the target device and switch to inference (eval) mode.
pannuke_model.to(device)
pannuke_model.eval()
## MoNuSeg: the inference object is built straight from the parsed options.
## Its model is later used as `monuseg_inf.model`.
monuseg_parser = InferenceCellViTMoNuSegParser()
monuseg_configurations = monuseg_parser.parse_arguments()
# MoNuSegInference keyword -> key in the parsed configuration that supplies it.
_monuseg_option_for_kwarg = {
    "model_path": "model",
    "dataset_path": "dataset",
    "outdir": "outdir",
    "gpu": "gpu",
    "patching": "patching",
    "magnification": "magnification",
    "overlap": "overlap",
}
monuseg_inf = MoNuSegInference(
    **{
        kwarg: monuseg_configurations[option]
        for kwarg, option in _monuseg_option_for_kwarg.items()
    }
)
def click_process(image_input, type_dataset):
    """Run nuclei segmentation on one image and return the rendered prediction.

    Args:
        image_input: Image from the ``gr.Image(type="numpy")`` input, or
            ``None`` when Submit is pressed before an image is provided.
        type_dataset: ``"pannuke"`` runs the PanNuke model; any other value
            (the UI offers ``"monuseg"``) runs the MoNuSeg model.

    Returns:
        The image read back from ``pred_img.png`` after inference, converted
        from OpenCV's BGR channel order to RGB for display.

    Raises:
        gr.Error: If no input image was given, or if no prediction image
            could be read after inference. Gradio shows the message in the UI.
    """
    if image_input is None:
        raise gr.Error("Please upload an image or select one of the examples first.")

    pred_path = "pred_img.png"
    # The inference helpers leave their visualisation in this fixed file, which
    # is read back below. Remove any copy from an earlier request first, so a
    # run that fails to write it cannot silently return a stale result.
    try:
        os.remove(pred_path)
    except FileNotFoundError:
        pass

    if type_dataset == "pannuke":
        pannuke_inf.run_single_image_inference(pannuke_model, image_input)
    else:
        monuseg_inf.run_single_image_inference(monuseg_inf.model, image_input)

    image_output = cv2.imread(pred_path)
    # cv2.imread returns None (it does not raise) when the file is missing or unreadable.
    if image_output is None:
        raise gr.Error("Inference finished but no prediction image was written.")
    return cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
demo = gr.Blocks(title="LkCell")
with demo:
    # Header: paper title and repository link.
    gr.Markdown(value="""
**Gradio demo for LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels**. Check our [Github Repo](https://github.com/ziwei-cui/LKCellv1) 😛.
""")
    with gr.Row():
        # Left column: input image and the dataset it comes from.
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True, height=480)
            with gr.Row():
                Type_dataset = gr.Radio(
                    choices=["pannuke", "monuseg"],
                    label=" input image's dataset type",
                    value="pannuke",
                )
        # Right column: segmentation result and the action buttons.
        with gr.Column():
            with gr.Row():
                image_output = gr.Image(type="numpy", label="Output", height=480)
            with gr.Row():
                Button_run = gr.Button("🚀 Submit (发送) ")
                clear_button = gr.ClearButton(
                    components=[Image_input, Type_dataset, image_output],
                    value="🧹 Clear (清除)",
                )
    Button_run.click(fn=click_process, inputs=[Image_input, Type_dataset], outputs=[image_output])
    ## guideline
    gr.Markdown(value="""
🔔**Guideline**
1. Upload your image or select one from the examples.
2. Set up the arguments: "Type_dataset".
3. Run the Submit button to get the output.
""")
    # Example images may be absent, e.g. when running locally without having
    # downloaded them. Offer only the ones on disk, so missing files cannot
    # break the examples table.
    example_rows = [
        [image_name, dataset_name]
        for image_name, dataset_name in (
            ("1.png", "pannuke"),
            ("2.png", "pannuke"),
            ("3.png", "monuseg"),
            ("4.png", "monuseg"),
        )
        if os.path.exists(image_name)
    ]
    if example_rows:
        gr.Examples(
            examples=example_rows,
            inputs=[Image_input, Type_dataset],
            outputs=[image_output],
            label="Examples",
        )
    gr.HTML(value="""
<p style="text-align:center; color:orange"> <a href='https://github.com/ziwei-cui/LKCellv1' target='_blank'>Github Repo</a></p>
""")
    gr.Markdown(value="""
Template is adapted from [Here](https://huggingface.co/spaces/menghanxia/disco)
""")
# Local runs bind to the loopback interface on a fixed port; otherwise
# Gradio's default launch settings are used.
launch_options = (
    {"server_name": "127.0.0.1", "server_port": 8003} if RUN_MODE == "local" else {}
)
demo.launch(**launch_options)