# A100 Zero GPU
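# Gradio streaming-chat demo for the Meteor vision-language model
# (Meteor-Mamba rationale model + Meteor-MLM) on a Hugging Face ZeroGPU Space.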
import spaces

import time
import torch
import gradio as gr
from PIL import Image
from utils.utils import *
from threading import Thread
import torch.nn.functional as F
from accelerate import Accelerator
from meteor.load_mmamba import load_mmamba
from meteor.load_meteor import load_meteor
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor

# flash attention: installed at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA extension during pip install
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# accel
accel = Accelerator()

# loading meteor model
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=4)

# freeze model
freeze_model(mmamba)
freeze_model(meteor)


def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):

    # Meteor Mamba
    mmamba_inputs = mmamba.eval_process(inputs=inputs, tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
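    # extract CLIP image features once; they are reused by the Meteor model below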
    if 'image' in mmamba_inputs.keys():
        clip_features = meteor.clip_features(mmamba_inputs['image'])
        mmamba_inputs.update({"image_features": clip_features})
    mmamba_outputs = mmamba(**mmamba_inputs)
    
    # Meteor
    meteor_inputs = meteor.eval_process(inputs=inputs, data='demo', tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
    if 'image' in mmamba_inputs.keys():
        meteor_inputs.update({"image_features": clip_features})
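    # pass the traversal-of-rationale (tor) features from Mamba into Meteor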
    meteor_inputs.update({"tor_features": mmamba_outputs.tor_features})

    generation_kwargs = meteor_inputs
    generation_kwargs.update({
        'streamer': streamer,
        'do_sample': True,
        'max_new_tokens': new_max_token,
        'top_p': top_p,
        'temperature': temperature,
        'use_cache': True,
    })
    return meteor.generate(**generation_kwargs)

@spaces.GPU
def bot_streaming(message, history, temperature, new_max_token, top_p):

    try:
        # ZeroGPU attaches the device inside @spaces.GPU calls, so move the weights here
        for param in mmamba.parameters():
            param.data = param.to(accel.device)
        for param in meteor.parameters():
            param.data = param.to(accel.device)

        # 490x490 input with 14x14 patches -> (490/14)**2 = 1225 image tokens
        image_token_number = int((490/14)**2)
        if len(message['files']) != 0:
            # load the image, convert to RGB, and bicubic-resize to 490x490
            image = F.interpolate(pil_to_tensor(Image.open(message['files'][0]).convert("RGB")).unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            inputs = [{'image': image, 'question': message['text']}]
        else:
            inputs = [{'question': message['text']}]

        # Meteor generation
        with torch.inference_mode():
            # streamer yields decoded text incrementally as tokens are generated
            streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)

            # run generation in a background thread so the streamer can be read here
            thread = Thread(target=threading_function,
                            kwargs=dict(inputs=inputs,
                                        image_token_number=image_token_number,
                                        streamer=streamer,
                                        device=accel.device,
                                        temperature=temperature,
                                        new_max_token=new_max_token,
                                        top_p=top_p))
            thread.start()

            # accumulate the streamed text; the loop ends when generation finishes
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
            thread.join()

        # keep only the assistant reply; drop anything from the first '[U' special-token marker on
        response = generated_text.split('assistant\n')[-1].split('[U')[0].strip()
    
    except Exception:
        response = "The input may be in an unsupported format (e.g., pdf, video, audio). Only a single image is supported in this version."

    # server-side log of each request
    text = message['text']
    files = message['files']
    print(f'Text: {text}')
    print(f'MM Files: {files}')


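    # stream the response back to the UI one character at a time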
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.015)
        yield buffer

demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="☄️Meteor",
                        description="Meteor is an efficient 7B-size Large Language and Vision Model built with the help of traversal of rationales.\n"
                                    "Its inference speed depends heavily on being assigned a non-scheduled GPU, so once all GPUs are busy, inference may stall indefinitely.",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()
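
# On a Hugging Face Space this script runs automatically; locally, run it with
# Python and open the local URL that Gradio prints.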