TB-OCR / app.py
mrfakename's picture
Update app.py
142c37c verified
raw
history blame
2.89 kB
# Markdown shown at the top of the Gradio page.
# Raw string on purpose: the text contains literal \( \) \[ \] sequences,
# which are invalid escape sequences in a normal string literal (Python 3.12+
# emits a SyntaxWarning for them). The resulting value is unchanged.
ABOUT = r"""
# TB-OCR Preview 0.1 Unofficial Demo
This is an unofficial demo of [yifeihu/TB-OCR-preview-0.1](https://huggingface.co/yifeihu/TB-OCR-preview-0.1).
Overview of TB-OCR:
> TB-OCR-preview (Text Block OCR), created by [Yifei Hu](https://x.com/hu_yifei), is an end-to-end OCR model handling text, math latex, and markdown formats all at once. The model takes a block of text as the input and returns clean markdown output. Headers are marked with `##`. Math expressions are guaranteed to be wrapped in brackets `\( inline math \) \[ display math \]` for easier parsing. This model does not require line-detection or math formula detection.
(From the [model card](https://huggingface.co/yifeihu/TB-OCR-preview-0.1))
"""
# For details on the processor and chat-prompt format used below, see the Phi-3.5-vision-instruct model card: https://huggingface.co/microsoft/Phi-3.5-vision-instruct
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import requests
model_id = "yifeihu/TB-OCR-preview-0.1"

# Prefer the GPU when torch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cpu":
    # Tell visitors up front that inference will be slow on this hardware.
    ABOUT += "\n\n### This demo is running on CPU\n\nThis demo is running on CPU, it will be very slow. Consider duplicating it or running it locally to skip the queue and for faster response times."

# Load the weights directly onto DEVICE. torch_dtype="auto" lets transformers
# pick the dtype recorded in the checkpoint's config; trust_remote_code allows
# the custom model code hosted in the repo to run.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,
    device_map=DEVICE,
    # Optional settings, disabled here:
    #   _attn_implementation="flash_attention_2"  # alternative attention kernel
    #   load_in_4bit=True  # load model in 4-bit mode to save memory
)

# NOTE(review): num_crops presumably sets how many image tiles the
# Phi-3.5-vision processor produces per image — confirm against the model card.
processor = AutoProcessor.from_pretrained(
    model_id,
    num_crops=16,
    trust_remote_code=True,
)
def phi_ocr(image_url):
    """Run TB-OCR on a single image file and return the recognised markdown.

    Args:
        image_url: Local path of the uploaded image. The Gradio input uses
            ``type="filepath"``, so this is a path string, or ``None`` when the
            button is clicked without an upload.

    Returns:
        The decoded model output, cut at the first ``<image_end>`` marker, or
        a short notice asking for an image when none was provided.
    """
    if not image_url:
        # Without this guard Image.open(None) raises and the UI only shows a
        # generic error.
        return "Please upload an image first."

    question = "Convert the text to markdown format."
    # Image.open is lazy and keeps the file open until the pixels are read;
    # convert() forces the load, so the with-block can release the handle.
    # It also hands the processor a consistent 3-channel RGB image.
    with Image.open(image_url) as raw_image:
        image = raw_image.convert("RGB")

    prompt_message = [{
        'role': 'user',
        'content': f'<|image_1|>\n{question}',
    }]
    # Render the chat template to plain text; the processor then tokenises it
    # together with the image.
    prompt = processor.tokenizer.apply_chat_template(prompt_message, tokenize=False, add_generation_prompt=True)
    inputs = processor(prompt, [image], return_tensors="pt").to(DEVICE)

    # Greedy decoding. No temperature is set: it only applies when
    # do_sample=True, and transformers warns when it is given for greedy runs.
    generation_args = {
        "max_new_tokens": 1024,
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )
    # The returned ids start with the prompt; keep only the newly generated part.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    response = response.split("<image_end>")[0]  # remove the image_end token
    return response
import gradio as gr
# Page layout: description on top, then image upload + button on the left and
# the rendered OCR result on the right.
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Input image", type="filepath")
            btn = gr.Button("OCR")
        with gr.Column():
            # TB-OCR wraps math in \( ... \) (inline) and \[ ... \] (display).
            # gr.Markdown only renders $$ ... $$ by default, so register the
            # model's delimiters too and keep $$ for anything else.
            out = gr.Markdown(
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "\\[", "right": "\\]", "display": True},
                    {"left": "\\(", "right": "\\)", "display": False},
                ],
            )
    btn.click(phi_ocr, inputs=img, outputs=out)
demo.queue().launch()