Spaces:
Running
on
Zero
Running
on
Zero
# Markdown rendered at the top of the demo page.
# A raw string is used so the backslashes of the LaTeX delimiters stay literal
# without relying on Python's deprecated handling of unknown escape sequences.
ABOUT = r"""
# TB-OCR Preview 0.1 Unofficial Demo
This is an unofficial demo of [yifeihu/TB-OCR-preview-0.1](https://huggingface.co/yifeihu/TB-OCR-preview-0.1).
Overview of TB-OCR:
> TB-OCR-preview (Text Block OCR), created by [Yifei Hu](https://x.com/hu_yifei), is an end-to-end OCR model handling text, math latex, and markdown formats all at once. The model takes a block of text as the input and returns clean markdown output. Headers are marked with `##`. Math expressions are guaranteed to be wrapped in brackets `\( inline math \) \[ display math \]` for easier parsing. This model does not require line-detection or math formula detection.
(From the [model card](https://huggingface.co/yifeihu/TB-OCR-preview-0.1))
"""
# Check out https://huggingface.co/microsoft/Phi-3.5-vision-instruct for more details.
import torch | |
from transformers import AutoModelForCausalLM, AutoProcessor | |
from PIL import Image | |
import requests | |
# Hub repository that hosts the OCR checkpoint and its processor.
model_id = "yifeihu/TB-OCR-preview-0.1"

# Use the GPU when torch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if DEVICE.type == "cpu":
    # Tell visitors up front that inference without a GPU is slow.
    ABOUT += "\n\n### This demo is running on CPU\n\nThis demo is running on CPU, it will be very slow. Consider duplicating it or running it locally to skip the queue and for faster response times."
# Load the weights straight onto DEVICE. trust_remote_code=True lets
# transformers execute the modelling code shipped in the model repository.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,
    device_map=DEVICE,
    # Optional settings, currently disabled:
    #   _attn_implementation='flash_attention_2'
    #   load_in_4bit=True  (load the model in 4-bit mode to save memory)
)

# The processor pairs the tokenizer with the image preprocessing; it is
# loaded from the same repository with num_crops=16.
processor = AutoProcessor.from_pretrained(
    model_id,
    num_crops=16,
    trust_remote_code=True,
)
def phi_ocr(image_url):
    """Run the OCR model on one image and return the recognized markdown.

    Args:
        image_url: Location of the image to read. The value is handed to
            ``PIL.Image.open``; the demo UI supplies a local file path. A
            falsy value (``None`` or ``""``, i.e. no image selected) yields
            a short notice instead of an exception.

    Returns:
        The decoded model output, truncated at the first ``<image_end>``
        marker if one is present.
    """
    if not image_url:
        # Pressing "OCR" before choosing an image passes None to the handler.
        return "Please upload an image first."

    question = "Convert the text to markdown format."

    # copy() decodes the pixels into an independent image, so the underlying
    # file can be closed as soon as the with-block ends.
    with Image.open(image_url) as source:
        image = source.copy()

    prompt_message = [{
        'role': 'user',
        'content': f'<|image_1|>\n{question}',
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(prompt, [image], return_tensors="pt").to(DEVICE)

    # Greedy decoding. No temperature is passed: it is only consulted when
    # sampling, and combining it with do_sample=False makes transformers warn.
    generation_args = {
        "max_new_tokens": 1024,
        "do_sample": False,
    }
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )

    # generate() returns prompt + completion; keep only the new tokens.
    prompt_length = inputs['input_ids'].shape[1]
    generate_ids = generate_ids[:, prompt_length:]

    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    response = response.split("<image_end>")[0]  # remove the image_end token
    return response
import gradio as gr | |
# Page layout: the description on top, then one row with the image input and
# trigger button in the first column and the markdown result in the second.
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="filepath", label="Input image")
            btn = gr.Button("OCR")
        with gr.Column():
            out = gr.Markdown()
    # The image component hands phi_ocr a file path (type="filepath").
    btn.click(fn=phi_ocr, inputs=img, outputs=out)

# Enable request queuing, then start the web server.
demo.queue()
demo.launch()