Spaces:
Runtime error
Runtime error
File size: 4,638 Bytes
a7b8ada 5a4d9e1 a7b8ada 8962d34 a7b8ada 4a16ef8 8962d34 a7b8ada 8962d34 a7b8ada 8962d34 a7b8ada 8962d34 a7b8ada b71e116 4a16ef8 d7dce5e 4a16ef8 b71e116 d7dce5e b71e116 4a16ef8 8962d34 4a16ef8 8962d34 4a16ef8 616e7e7 4a16ef8 616e7e7 4a16ef8 616e7e7 4a16ef8 a7b8ada 4a16ef8 5a4d9e1 4a16ef8 a7b8ada 4a16ef8 8962d34 4a16ef8 8962d34 4a16ef8 8962d34 4a16ef8 8962d34 4a16ef8 616e7e7 b71e116 4a16ef8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import random

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from models.tag2text import tag2text_caption, ram
# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Side length in pixels that images are resized to. The same value is handed
# to both model constructors below.
image_size = 384

# Preprocessing pipeline applied to every input image:
# resize -> convert to tensor -> per-channel normalization.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

# ---- Tag2Text model ----
pretrained = 'tag2text_swin_14m.pth'
model_tag2text = tag2text_caption(
    pretrained=pretrained,
    image_size=image_size,
    vit='swin_b',
)
model_tag2text.eval()  # switch to evaluation mode before serving
model_tag2text = model_tag2text.to(device)

# ---- RAM model ----
# `pretrained` is deliberately rebound; after this point it names the RAM
# checkpoint file.
pretrained = 'ram_swin_large_14m.pth'
model_ram = ram(
    pretrained=pretrained,
    image_size=image_size,
    vit='swin_l',
)
model_ram.eval()  # switch to evaluation mode before serving
model_ram = model_ram.to(device)
def inference(raw_image, model_n, input_tag):
    """Run the selected model on one image and return its tags and caption.

    Args:
        raw_image: Image object exposing ``resize``; the UI is configured to
            supply a PIL image.
        model_n: Model selector. ``'Recognize Anything Model'`` picks the RAM
            model; any other value picks the Tag2Text model.
        input_tag: Optional comma-separated tags used to guide the Tag2Text
            caption. ``None``, ``''``, ``'none'`` and ``'None'`` all mean
            "no user tags". Ignored by the RAM branch.

    Returns:
        A 3-tuple ``(tags, chinese_tags, caption)`` holding the first element
        of each model output. The RAM branch returns ``'none'`` as the
        caption; the Tag2Text branch returns ``'none'`` as the Chinese tags.
    """
    # Kept from the original: the image is resized here before `transform`
    # (which also resizes to the same target size) is applied.
    raw_image = raw_image.resize((image_size, image_size))
    image = transform(raw_image).unsqueeze(0).to(device)

    if model_n == 'Recognize Anything Model':
        # Inference only: disable gradient tracking, matching the Tag2Text
        # branch below.
        with torch.no_grad():
            tags, tags_chinese = model_ram.generate_tag(image)
        return tags[0], tags_chinese[0], 'none'

    model = model_tag2text
    model.threshold = 0.68

    # Treat a missing value the same as the explicit "no tag" spellings so
    # that a None textbox value does not crash on `.replace`.
    if input_tag in (None, '', 'none', 'None'):
        input_tag_list = None
    else:
        # The user separates tags with commas; the model receives them
        # joined with ' | ' as a single list entry.
        input_tag_list = [input_tag.replace(',', ' | ')]

    with torch.no_grad():
        caption, tag_predict = model.generate(
            image,
            tag_input=input_tag_list,
            max_length=50,
            return_tag_predict=True,
        )
        if input_tag_list is None:
            tag_1 = tag_predict
        else:
            # The caption above was guided by the user's tags; run a second,
            # unguided pass so the reported tags are the model's own.
            _, tag_1 = model.generate(
                image,
                tag_input=None,
                max_length=50,
                return_tag_predict=True,
            )

    return tag_1[0], 'none', caption[0]
# Input components handed to gr.Interface together with `inference`:
# an image, the model selector, and the optional user tags.
# NOTE(review): `gr.inputs` / `gr.outputs` look like the legacy Gradio
# component namespaces -- confirm they exist in the Gradio version this demo
# is pinned to.
inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Radio(
        label="Select Model",
        default="Recognize Anything Model",
        type="value",
        choices=['Recognize Anything Model', "Tag2Text Model"],
    ),
    gr.inputs.Textbox(
        label="User Specified Tags (Optional, Enter with commas, Currently only Tag2Text is supported)",
        lines=2,
    ),
]

# Three text boxes, one per element of the tuple returned by `inference`.
outputs = [
    gr.outputs.Textbox(label=box_label)
    for box_label in (
        "Tags",
        "标签",
        "Caption (currently only Tag2Text is supported)",
    )
]
# Page heading, rendered as HTML so the font size can be enlarged.
title = "<font size='10'> Recognize Anything Model</font>"

# Introductory HTML passed to gr.Interface as `description`.
description = "Welcome to the Recognize Anything Model (RAM) and Tag2Text Model demo! <li><b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese outputs of the image tags</b>!</li><li><b>Tag2Text Model:</b> Upload your image to get the <b>tags</b> and <b>caption</b> of the image. Optional: You can also input specified tags to get the corresponding caption.</li> "

# HTML with project links, passed to gr.Interface as `article`. The Tag2Text
# href previously carried a doubled scheme ("https://https://..."), which
# produced a dead link; it now points at https://tag2text.github.io/.
article = (
    "<p style='text-align: center'>RAM and Tag2Text is training on open-source datasets, and we are persisting in refining and iterating upon it.<br/>"
    "<a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a> | "
    "<a href='https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | "
    "<a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a></p>"
)
# Assemble the web demo: `inference` is the callback, wired to the `inputs`
# and `outputs` component lists and decorated with the page text defined
# above. Each example row supplies one value per input component, in order:
# image path, model choice, user tags.
demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=[
        ['images/demo1.jpg', "Recognize Anything Model", "none"],
        ['images/demo2.jpg', "Recognize Anything Model", "none"],
        ['images/demo4.jpg', "Recognize Anything Model", "none"],
        ['images/demo4.jpg', "Tag2Text Model", "power line"],
        ['images/demo4.jpg', "Tag2Text Model", "track, train"],
    ],
)

# Start serving the demo with Gradio's `enable_queue` launch option turned on.
# (A stray trailing "|" after this call made the statement a syntax error and
# has been removed.)
demo.launch(enable_queue=True)