|
import gradio as gr |
|
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline |
|
import torch |
|
|
|
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained GIT-base captioning model. The processor loaded from this
# checkpoint (image preprocessing + tokenizer) is reused for BOTH models below.
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)
# Move the weights to `device`: img2cap_com sends its pixel tensors to
# `device`, and generate() fails with a device-mismatch error when the model
# and its inputs live on different devices (e.g. CUDA inputs, CPU weights).
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Fine-tuned variant of GIT-base. It is fed inputs from `processor` above,
# which presumably matches its vocabulary/preprocessing (NOTE(review): confirm
# the fine-tune did not change the processor).
checkpoint2 = "wangjin2000/git-base-finetune"
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)

# English -> Chinese translation pipeline applied to each generated caption.
en_zh_translator = pipeline("translation", model="liam168/trans-opus-mt-en-zh")
|
|
|
def _generate_caption(model, pixel_values, max_length=50):
    """Return the decoded caption that `model` generates for `pixel_values`.

    Only the first sequence of the batch is returned, matching the single
    image passed in by img2cap_com.
    """
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _caption_with_translation(caption):
    """Return `caption` followed, on a new line, by its Chinese translation."""
    # The translation pipeline returns a list with one dict per input text;
    # extract the plain string instead of showing the raw list/dict repr.
    translated = en_zh_translator(caption)[0]["translation_text"]
    return f"{caption}\n{translated}"


def img2cap_com(image):
    """Caption `image` with the pre-trained and the fine-tuned GIT model.

    Args:
        image: The uploaded image as delivered by the Gradio Image input
            (configured with type="pil").

    Returns:
        A pair of strings ``(pretrained_text, finetuned_text)``, each holding
        the English caption followed on a new line by its Chinese translation.
    """
    # Both models share one processor, so preprocess the image once and reuse
    # the same pixel tensor for both generations.
    pixel_values = processor(images=image, return_tensors="pt").to(device).pixel_values

    caption1 = _generate_caption(model1, pixel_values)
    caption2 = _generate_caption(model2, pixel_values)

    return _caption_with_translation(caption1), _caption_with_translation(caption2)
|
|
|
# Single Gradio input: the uploaded image, passed to img2cap_com as a PIL image.
inputs = [gr.Image(type="pil", label="Original Image")]

# One output Textbox per model, in the same order img2cap_com returns them.
_output_labels = (
    "Caption from pre-trained model",
    "Caption from fine-tuned model",
)
outputs = [gr.Textbox(label=text) for text in _output_labels]
|
|
|
# Heading and blurb displayed above the demo.
title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."

# Clickable example rows Image1.png ... Image6.png, one image per row.
# NOTE(review): these are relative paths, presumably files shipped next to
# this script; confirm they exist in the app's working directory.
examples = [[f"Image{index}.png"] for index in range(1, 7)]
|
|
|
# Assemble the demo: one image in, two captions out (pre-trained, fine-tuned).
# NOTE(review): string themes such as "huggingface" are a legacy Gradio option;
# newer Gradio releases may warn and fall back to the default theme. Confirm
# against the installed Gradio version.
demo = gr.Interface(
    img2cap_com,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
)

# Start the (blocking) server only when run as a script, so importing this
# module does not launch it; `demo` is also the name Gradio's reload mode
# looks for by default.
if __name__ == "__main__":
    demo.launch()
|
|