RxnIM / app.py
CYF200127's picture
Update app.py
b291e50 verified
import os
import gradio as gr
import json
from rxnim import RXNIM
from getReaction import generate_combined_image
import torch
from rxn.reaction import Reaction
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import Draw
# Directory that list_prompt_files_with_names scans for prompt files (*.txt).
PROMPT_DIR = "prompts/"
# Checkpoint path handed to the Reaction constructor below.
ckpt_path = "./rxn/model/model.ckpt"
# Reaction model instance, created once at import time on the CPU device.
model = Reaction(ckpt_path, device=torch.device('cpu'))
# Mapping from prompt file name to a user-friendly display name.
PROMPT_NAMES = {
    "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
}
# Image shown in the "Schematic Diagram" panel of the UI.
example_diagram = "examples/exp.png"
# Placeholder image shown in the RDKit panel while there is no reaction to draw.
rdkit_image = "examples/image.webp"
def list_prompt_files_with_names(prompt_dir=None, prompt_names=None):
    """
    List the ``.txt`` prompt files of a directory under user-friendly names.

    A file that has an entry in ``prompt_names`` is listed under that name;
    every other file gets the default name ``"Task: <file name without .txt>"``.

    Args:
        prompt_dir: Directory to scan. Defaults to the module-level
            ``PROMPT_DIR``.
        prompt_names: Mapping of file name to friendly name. Defaults to the
            module-level ``PROMPT_NAMES``.

    Returns:
        A ``{friendly_name: filename}`` dict whose entries follow sorted
        file-name order.

    Raises:
        OSError: If ``prompt_dir`` cannot be listed (propagated from
            ``os.listdir``).
    """
    if prompt_dir is None:
        prompt_dir = PROMPT_DIR
    if prompt_names is None:
        prompt_names = PROMPT_NAMES
    # os.listdir returns entries in arbitrary order; sorting keeps the order of
    # the resulting choices (and the winner among duplicate friendly names)
    # stable from run to run.
    return {
        prompt_names.get(name, f"Task: {os.path.splitext(name)[0]}"): name
        for name in sorted(os.listdir(prompt_dir))
        if name.endswith(".txt")
    }
def _field_text(entry, *keys):
    """Return the first non-null value among ``keys`` in ``entry`` as a string, else "Unknown"."""
    for key in keys:
        value = entry.get(key)
        if value is not None:
            return str(value)
    return "Unknown"


def parse_reactions(output_json):
    """
    Parse reaction data given as a JSON string into display and SMILES output.

    Args:
        output_json: JSON text of an object with an optional ``"reactions"``
            list. Each reaction may carry ``"reaction_id"``, ``"reactants"``,
            ``"conditions"`` and ``"products"``; entries are dicts read via
            their ``"smiles"``, ``"text"`` and ``"role"`` keys.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple:
        ``detailed_output`` is a list with one colour-coded HTML snippet per
        reaction; ``smiles_output`` is a list with one plain
        ``reactants>>products`` reaction SMILES string per reaction.

    Raises:
        json.JSONDecodeError: If ``output_json`` is not valid JSON.
    """
    # Function-scope import so the block does not depend on the file header.
    from html import escape

    def colored(color, text):
        return f"<span style='color:{color}'>{text}</span>"

    def as_html(values):
        # The values end up as element text (never inside an attribute), so
        # quotes can stay as they are; only &, < and > must be escaped.
        return [escape(value, quote=False) for value in values]

    reactions_data = json.loads(output_json)
    detailed_output = []
    smiles_output = []

    # ``or []`` also covers keys that are present but null in the JSON.
    for reaction in reactions_data.get("reactions") or []:
        reaction_id = escape(str(reaction.get("reaction_id", "Unknown ID")), quote=False)
        reactants = [_field_text(r, "smiles") for r in reaction.get("reactants") or []]
        products = [_field_text(p, "smiles") for p in reaction.get("products") or []]
        # A condition is shown by its SMILES when available, otherwise by its text.
        conditions = [
            f"{_field_text(c, 'smiles', 'text')}[{_field_text(c, 'role')}]"
            for c in reaction.get("conditions") or []
        ]

        reactants_html = as_html(reactants)
        products_html = as_html(products)
        conditions_html = as_html(conditions)

        # Full reaction line, rendered entirely in black.
        black_products = ".".join(colored("black", p) for p in products_html)
        black_conditions = ", ".join(colored("black", c) for c in conditions_html)
        full_reaction = colored(
            "black",
            f"{'.'.join(reactants_html)}>>{black_products} | {black_conditions}",
        )

        # Detailed, colour-coded description of the reaction.
        red_conditions = ", ".join(colored("red", c) for c in conditions_html)
        orange_products = ", ".join(colored("orange", p) for p in products_html)
        reaction_output = (
            f"<b>Reaction: </b> {reaction_id}<br>"
            f" Reactants: {colored('blue', ', '.join(reactants_html))}<br>"
            f" Conditions: {red_conditions}<br>"
            f" Products: {orange_products}<br>"
            f" <b>Full Reaction:</b> {full_reaction}<br>"
            "<br>"
        )
        detailed_output.append(reaction_output)

        # The machine-readable reaction SMILES stays unescaped plain text.
        smiles_output.append(f"{'.'.join(reactants)}>>{'.'.join(products)}")

    return detailed_output, smiles_output
def process_chem_image(image, selected_task):
    """
    Run the reaction-parsing pipeline on an uploaded image.

    Args:
        image: The uploaded image; it must provide ``save(path)`` (the UI
            passes a PIL image). ``None`` when nothing was uploaded.
        selected_task: Friendly task name, a key of ``prompts_with_names``.

    Returns:
        A 5-tuple matching the outputs bound to the "Run" button:
        the detailed reactions joined into one HTML string, the list of
        reaction SMILES, the result of ``generate_combined_image``, the
        schematic diagram path and the path of the written JSON file.

    Raises:
        gr.Error: If no image was uploaded or no known task was selected.
    """
    # Function-scope import so the block does not depend on the file header.
    import tempfile

    # Report missing inputs to the user instead of failing with an
    # AttributeError / KeyError further down.
    if image is None:
        raise gr.Error("Please upload a reaction image first.")
    if selected_task not in prompts_with_names:
        raise gr.Error("Please select a predefined task.")

    chem_mllm = RXNIM()

    # Translate the friendly task name into the actual prompt file path.
    prompt_path = os.path.join(PROMPT_DIR, prompts_with_names[selected_task])

    # Each request gets its own scratch directory so that concurrent requests
    # cannot overwrite each other's input image or JSON result. The directory
    # is intentionally left in place: the returned paths must stay readable
    # after this function returns.
    work_dir = tempfile.mkdtemp(prefix="rxnim_")
    image_path = os.path.join(work_dir, "temp_image.png")
    image.save(image_path)

    # Run RXNIM and turn its JSON result into structured output.
    rxnim_result = chem_mllm.process(image_path, prompt_path)
    detailed_reactions, smiles_output = parse_reactions(rxnim_result)

    # Run the Reaction model and build the combined visualization.
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    combined_image_path = generate_combined_image(predictions, image_path)

    # Write the pretty-printed JSON result offered for download.
    json_file_path = os.path.join(work_dir, "output.json")
    with open(json_file_path, "w", encoding="utf-8") as json_file:
        json.dump(json.loads(rxnim_result), json_file, indent=4)

    return (
        "\n\n".join(detailed_reactions),
        smiles_output,
        combined_image_path,
        example_diagram,
        json_file_path,
    )
# {friendly_name: filename} mapping, built once at import time; used for the
# task radio choices and for the prompt lookup in process_chem_image.
prompts_with_names = list_prompt_files_with_names()
# Rows for gr.Examples: [image path, friendly task name].
examples = [
    ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
]
def _split_reaction_smiles(raw):
    """
    Recover the individual reaction SMILES from the hidden textbox value.

    The "Run" handler sends a Python list to a Textbox, so the value arrives
    here as the text form of that list (e.g. ``"['A>>B', 'C>>D']"``). It is
    parsed back with ``ast.literal_eval`` rather than by stripping bracket and
    quote characters, because brackets are legitimate SMILES characters
    (``[Na+]``) and must not be removed. Plain comma-separated text is
    accepted as a fallback.

    Returns:
        A list of non-empty SMILES strings (empty when there is nothing to show).
    """
    # Function-scope import so the block does not depend on the file header.
    import ast

    if not raw:
        return []
    if isinstance(raw, (list, tuple)):
        items = list(raw)
    else:
        text = str(raw).strip()
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError):
            parsed = None
        if isinstance(parsed, (list, tuple)):
            items = list(parsed)
        elif isinstance(parsed, str):
            items = [parsed]
        else:
            # Not a Python literal: treat it as comma-separated text and drop
            # only surrounding whitespace and quote characters.
            items = [part.strip().strip("'\"") for part in text.split(",")]
    cleaned = (str(item).strip() for item in items)
    return [smiles for smiles in cleaned if smiles]


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
<center> <h1>Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model</h1></center>
Upload a reaction image and select a predefined task prompt.
""")

    # Upper part: input area and result panels
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):  # left column
            image_input = gr.Image(type="pil", label="Upload Reaction Image")
            task_radio = gr.Radio(
                choices=list(prompts_with_names.keys()),
                label="Select a predefined task",
            )
            with gr.Row():  # Clear and Run buttons share one row
                clear_button = gr.Button("Clear")
                process_button = gr.Button("Run", elem_id="submit-btn")
            gr.Markdown("### Reaction Image Parsing Output")
            reaction_output = gr.HTML(label="Reaction outputs")

        with gr.Column(scale=1):
            gr.Markdown("### Reaction Extraction Output")
            visualization_output = gr.Image(label="Visualization Output")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")

        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Data Output")
            # Hidden textbox that carries the reaction SMILES to show_split.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            # Lower part: per-reaction SMILES, RDKit drawings and JSON download.
            # gr.render re-runs show_split whenever smiles_output changes.
            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Render one SMILES textbox and one RDKit drawing per reaction."""
                smiles_list = _split_reaction_smiles(inputs)
                if not smiles_list:
                    # Nothing to show yet: render the placeholder pair.
                    return (
                        gr.Textbox(label="SMILES of Reaction i"),
                        gr.Image(value=rdkit_image, label="RDKit Image of Reaction i"),
                    )

                components = []
                for i, smiles in enumerate(smiles_list):
                    components.append(
                        gr.Textbox(
                            value=smiles,
                            label=f"SMILES of Reaction {i + 1} ",
                            show_copy_button=True,
                            interactive=False,
                        )
                    )
                    # One unparsable reaction must not abort the whole render;
                    # its SMILES stays visible, only the drawing is skipped.
                    # NOTE(review): the input is reaction SMILES but is parsed
                    # with the SMARTS parser, as before; confirm whether
                    # useSmiles=True is wanted here.
                    try:
                        reaction = rdChemReactions.ReactionFromSmarts(smiles)
                        img = Draw.ReactionToImage(reaction) if reaction else None
                    except (ValueError, RuntimeError):
                        img = None
                    if img is not None:
                        components.append(
                            gr.Image(value=img, label=f"RDKit Image of Reaction {i + 1} ")
                        )
                return components

            download_json = gr.File(label="Download JSON File")

    # Examples section
    gr.Examples(
        examples=examples,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output],
    )

    # Event bindings
    clear_button.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[
            image_input,
            task_radio,
            reaction_output,
            smiles_output,
            visualization_output,
        ],
    )
    process_button.click(
        process_chem_image,
        inputs=[image_input, task_radio],
        outputs=[
            reaction_output,
            smiles_output,
            visualization_output,
            schematic_diagram,
            download_json,
        ],
    )
# Custom CSS for the element with id "submit-btn" (the "Run" button),
# assigned to the Blocks instance before it is launched.
demo.css = """
#submit-btn {
background-color: #FF914D;
color: white;
font-weight: bold;
}
"""
# Launch the Gradio app.
demo.launch()