# A100 Zero GPU
import spaces
import time
import torch
import gradio as gr
from PIL import Image
from utils.utils import *
from threading import Thread
import torch.nn.functional as F
from accelerate import Accelerator
from meteor.load_mmamba import load_mmamba
from meteor.load_meteor import load_meteor
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor
# flash attention: installed at runtime on the Space; the env flag skips the
# local CUDA build so install does not compile from source
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# accel
accel = Accelerator()
# loading the two-stage Meteor model: the Mamba-based rationale embedder and
# the Meteor backbone (loaded in 4-bit quantization)
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=4)
# freeze model
freeze_model(mmamba)
freeze_model(meteor)
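# freeze_model comes from utils.utils (star-imported above). A minimal sketch
# of what such a helper typically does — an assumption, not the repo's actual
# implementation:
#
#   def freeze_model(model):
#       for param in model.parameters():
#           param.requires_grad = False  # inference only, no gradient updates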
# previous length
previous_length = 0
def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):
    # Meteor Mamba: embed the inputs and produce traversal-of-rationale (tor) features
    mmamba_inputs = mmamba.eval_process(inputs=inputs, tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
    if 'image' in mmamba_inputs.keys():
        clip_features = meteor.clip_features(mmamba_inputs['image'])
        mmamba_inputs.update({"image_features": clip_features})
    mmamba_outputs = mmamba(**mmamba_inputs)

    # Meteor: condition the language model on the CLIP image features and tor features
    meteor_inputs = meteor.eval_process(inputs=inputs, data='demo', tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
    if 'image' in mmamba_inputs.keys():
        meteor_inputs.update({"image_features": clip_features})
    meteor_inputs.update({"tor_features": mmamba_outputs.tor_features})

    # generation settings; tokens are pushed into the streamer as they are decoded
    generation_kwargs = meteor_inputs
    generation_kwargs.update({'streamer': streamer})
    generation_kwargs.update({'do_sample': True})
    generation_kwargs.update({'max_new_tokens': new_max_token})
    generation_kwargs.update({'top_p': top_p})
    generation_kwargs.update({'temperature': temperature})
    generation_kwargs.update({'use_cache': True})
    return meteor.generate(**generation_kwargs)
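# A quick way to exercise the pipeline outside Gradio — a sketch only; the
# question text and sampling values are placeholders, mirroring the call made
# in bot_streaming below:
#
#   streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
#   Thread(target=threading_function, kwargs=dict(
#       inputs=[{'question': 'Describe this image.'}],
#       image_token_number=int((490 / 14) ** 2),
#       streamer=streamer, device=accel.device,
#       temperature=0.9, new_max_token=128, top_p=0.95)).start()
#   print(''.join(streamer))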
@spaces.GPU
def bot_streaming(message, history, temperature, new_max_token, top_p):
    try:
        # move all parameters onto the GPU assigned for this request
        for param in mmamba.parameters():
            param.data = param.to(accel.device)
        for param in meteor.parameters():
            param.data = param.to(accel.device)

        # prompt type -> input prompt
        # 490x490 input with 14x14 vision patches -> (490/14)**2 = 35**2 = 1225 image tokens
        image_token_number = int((490 / 14) ** 2)
        if len(message['files']) != 0:
            # Image load: resize to 490x490 with bicubic interpolation
            image = F.interpolate(pil_to_tensor(Image.open(message['files'][0]).convert("RGB")).unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            inputs = [{'image': image, 'question': message['text']}]
        else:
            inputs = [{'question': message['text']}]

        # Meteor generation
        with torch.inference_mode():
            # streamer yields decoded text chunks as they are generated
            streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)

            # run generation in a background thread so the streamer can be consumed here
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   image_token_number=image_token_number,
                                                                   streamer=streamer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()

            # collect the streamed text
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text

        # Text decoding: keep only the assistant's answer, drop trailing special tokens
        response = generated_text.split('assistant\n')[-1].split('[U')[0].strip()
    except Exception:
        response = "Unsupported input format (e.g., pdf, video, sound). Only a single image is supported in this version."

    # private log print
    text = message['text']
    files = message['files']
    print(f'Text: {text}')
    print(f'MM Files: {files}')

    # stream the response back character by character for a typing effect
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.015)
        yield buffer
demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="☄️Meteor",
                        description="Meteor is an efficient 7B-size Large Language and Vision Model built with the help of traversal of rationales.\n"
                                    "Its inference speed depends heavily on being assigned a non-scheduled GPU, so when all GPUs are busy, inference may take indefinitely long.",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()
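# A minimal local smoke test, assuming the Space's dependencies are installed
# (outside a ZeroGPU Space the @spaces.GPU decorator is effectively a no-op);
# Gradio serves on port 7860 by default:
#
#   $ pip install -r requirements.txt
#   $ python app.py   # then open http://localhost:7860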