from typing import Dict, List, Any, Optional, Tuple, Literal
# from transformers import pipeline
import torch
from PIL import Image  # vision inputs arrive as PIL images
from torchvision import transforms
# einops, xformers, accelerate, triton and protobuf are runtime dependencies
# of the remote CogVLM code; they do not need to be imported here
from transformers import AutoModelForCausalLM, LlamaTokenizer, PretrainedConfig

LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1
# CogVLM config shipped alongside this handler; provides the vision tower's
# image_size and patch_size used below
config = PretrainedConfig.from_json_file('config.json')

class EndpointHandler():
    def __init__(self, path=""):
        self.model = AutoModelForCausalLM.from_pretrained(
            'THUDM/cogvlm-chat-hf',
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            #   cache_dir='/tmp'
        ).to('cuda').eval()  # __call__ moves tensors to CUDA, so the model must live there too
        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
        # create inference pipeline
        # self.pipeline = pipeline("text-generation", model="THUDM/cogvlm-chat-hf", trust_remote_code=True)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Args:
            data (:obj:`dict`):
                The request payload, of the form
                {"inputs": {"query": <question string>, "image": <PIL image,
                or a base64 string as handled below>}}.
        Return:
            :obj:`str`: The answer CogVLM generates for the query about the image.
        """

        def _history_to_prompt(signal_type, history, query):
            if signal_type == 'base':
                return query
            elif signal_type == 'vqa':
                answer_format = 'Short answer:'
            elif signal_type == 'chat':
                answer_format = 'Answer:'
            else:
                assert False, f"Unknown signal type {signal_type}"

            prompt = ''
            for old_query, response in history:
                prompt += 'Question: {} {} {}\n'.format(old_query, answer_format, response)
            prompt += 'Question: {} {}'.format(query, answer_format)
            return prompt
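
        # For the 'vqa' template, a history of [("What is this?", "A cat")]
        # and query "What color is it?" yields:
        #   Question: What is this? Short answer: A cat
        #   Question: What color is it? Short answer: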

        def build_conversation_input_ids(
                tokenizer: "PreTrainedTokenizer",
                *,
                query: str,
                history: Optional[List[Tuple[str, str]]] = None,
                images: Optional[List["PIL.Image"]] = None,
                template_version: Optional[Literal["base", "chat", "vqa"]] = None,
                config=config
        ):
            image_size: int = config.vision_config['image_size']
            patch_size: int = config.vision_config['patch_size']
            template_version = template_version or config.template_version
            assert images is None or len(images) <= 1, "multiple images are not supported yet"
            history = history or []
            text = _history_to_prompt(template_version, history, query)

            input_ids = [tokenizer.bos_token_id]
            token_type_ids = [LANGUAGE_TOKEN_TYPE]
            if images is not None and len(images) == 1:
                # vision
                transform = transforms.Compose(
                    [
                        transforms.Resize(
                            (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
                        ),
                        transforms.ToTensor(),
                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                    ]
                )
                images = [transform(images[0])]
                # reserve placeholder positions in the token sequence for the
                # image features: one per patch, plus 2 for CogVLM's image
                # boundary (BOI/EOI) positions
                vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
                input_ids += [tokenizer.pad_token_id] * vision_token_num
                token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
            text_ids = tokenizer.encode(text, add_special_tokens=False)

            input_ids += text_ids
            token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
            attention_mask = [1] * len(input_ids)

            return {
                'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'images': images,
            }
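
        # With the values I assume cogvlm-chat-hf's config ships
        # (image_size=490, patch_size=14), vision_token_num is
        # 35 * 35 + 2 = 1227, so input_ids begins with <bos> followed by
        # 1227 pad placeholders before the text tokens.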

        inputs = data.pop("inputs", data)
        query = inputs.pop("query", None)
        image = inputs.pop("image", None)
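
        # NOTE (assumption, not in the original handler): endpoint clients may
        # send the image as a base64-encoded string rather than a PIL image;
        # decode it in that case.
        if isinstance(image, str):
            import base64
            from io import BytesIO
            image = Image.open(BytesIO(base64.b64decode(image))).convert('RGB')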

        gen_kwargs = {"max_length": 2048, "do_sample": False}

        inputs = build_conversation_input_ids(self.tokenizer, query=query, history=[], images=[image],
                                              template_version='vqa')

        # batch of size 1: add the batch dimension, move everything to the GPU,
        # and cast the image to the model dtype
        inputs = {
            'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
            'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
        }

        # pass inputs with all kwargs in data
        # prediction = self.pipeline(inputs)

        outputs = self.model.generate(**inputs, **gen_kwargs)
        # keep only the newly generated tokens, then decode them
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return prediction
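

# Minimal local smoke test (a sketch, not part of the endpoint contract):
# 'example.jpg' is a placeholder path, and a CUDA GPU large enough for the
# bfloat16 CogVLM weights is assumed.
if __name__ == "__main__":
    handler = EndpointHandler()
    sample = Image.open('example.jpg').convert('RGB')
    print(handler({"inputs": {"query": "What is shown in this image?", "image": sample}}))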