import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "Nechba/Coin-Generative-Recognition"
TITLE = "<h1>🚀 Coin Generative Recognition</h1>"
DESCRIPTION = f"""
A Space for Vision/Multimodal
✨ Tips: Send messages or upload multiple IMAGES at a time.
✨ Tips: Please increase MAX LENGTH when dealing with files.
🤙 Supported Format: png, jpg, webp
🙇♂️ May be rebuilding from time to time.
"""
CSS = """
h1 {
    text-align: center;
    display: block;
}
img {
    max-width: 100%;   /* Make sure images are not wider than their container */
    height: auto;      /* Maintain aspect ratio */
    max-height: 300px; /* Limit the height of images */
}
"""
# Load the model with its custom (trust_remote_code) implementation on GPU 0.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(0)
# Alternative loading path, in bfloat16:
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     torch_dtype=torch.bfloat16,
#     low_cpu_mem_usage=True,
#     trust_remote_code=True,
# ).to(0)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model.eval()
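# eval() disables dropout and other training-time behavior; gradients are
# additionally disabled per request inside stream_chat via torch.no_grad().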
def merge_images(paths):
    """Concatenate the images at `paths` left-to-right onto a single RGB canvas."""
    images = [Image.open(path).convert("RGB") for path in paths]
    widths, heights = zip(*(im.size for im in images))
    total_width = sum(widths)
    max_height = max(heights)
    new_im = Image.new("RGB", (total_width, max_height))
    x_offset = 0
    for im in images:
        new_im.paste(im, (x_offset, 0))
        x_offset += im.width
    return new_im
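# merge_images example: a 100x80 image plus a 60x120 image produce a 160x120
# canvas; shorter images are top-aligned, leaving black padding underneath.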
def mode_load(paths):
    # Accept only image files; merge them into a single composite image.
    if all(path.lower().endswith(("png", "jpg", "jpeg", "webp")) for path in paths):
        content = merge_images(paths)
        choice = "image"
        return choice, content
    else:
        raise gr.Error("Unsupported file types. Please upload only images.")
@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
    conversation = []
    if message["files"] and len(message["files"]) > 1:
        # Multiple files: merge them into one composite image before prompting.
        choice, contents = mode_load(message["files"])
        conversation.append({"role": "user", "image": contents, "content": message["text"]})
    elif message["files"] and len(message["files"]) == 1:
        # A single file: open it directly.
        content = Image.open(message["files"][-1]).convert("RGB")
        choice = "image"
        conversation.append({"role": "user", "image": content, "content": message["text"]})
    else:
        raise gr.Error("Please upload one or more images.")
    input_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        max_length=max_length,
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[151329, 151336, 151338],
    )
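    # NOTE: the eos_token_id list above hardcodes GLM-4-style stop-token IDs;
    # if the underlying checkpoint uses a different vocabulary, these should be
    # taken from the tokenizer/model config instead (an assumption, not verified here).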
    gen_kwargs = {**input_ids, **generate_kwargs}

    def generate():
        # torch.no_grad() must be entered inside the worker thread; the
        # context does not carry over from the caller's thread.
        with torch.no_grad():
            model.generate(**gen_kwargs)

    thread = Thread(target=generate)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
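
# A minimal offline smoke test (hypothetical sketch; assumes the checkpoint's
# custom chat template accepts the same {"role", "image", "content"} message
# shape used in stream_chat above):
# img = Image.open("./135_front.jpg").convert("RGB")
# conv = [{"role": "user", "image": img, "content": "Which country minted this coin?"}]
# ids = tokenizer.apply_chat_template(conv, tokenize=True, add_generation_prompt=True,
#                                     return_tensors="pt", return_dict=True).to(model.device)
# with torch.no_grad():
#     out = model.generate(**ids, max_length=512)
# print(tokenizer.decode(out[0], skip_special_tokens=True))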
chatbot = gr.Chatbot(label="Chatbox", height=600, placeholder=DESCRIPTION)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    placeholder="Enter message or upload images...",
    show_label=False,
    file_count="multiple",
)
EXAMPLES = [
    [{"text": "Give me Country,Denomination and year as json format.", "files": ["./135_back.jpg", "./135_front.jpg"]}],
    [{"text": "Give me Country,Denomination and year as json format.", "files": ["./141_back.jpg", "./141_front.jpg"]}],
]
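# The example files above (135_*.jpg, 141_*.jpg) are assumed to be coin
# front/back photos shipped in the Space repository root.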
with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
    )
    gr.Examples(EXAMPLES, [chat_input])
if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)