from typing import Dict, List, Any, Optional, Tuple, Literal

import torch
import PIL.Image
from torchvision import transforms
from transformers import AutoModelForCausalLM, LlamaTokenizer, PretrainedConfig
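
# NOTE: triton and protobuf are runtime dependencies of the CogVLM remote code
# loaded via trust_remote_code; install them (e.g. through requirements.txt)
# rather than importing them here. The protobuf package is imported as
# `google.protobuf`, so a bare `import protobuf` fails.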

# Token type ids CogVLM uses to tell language tokens and vision tokens apart.
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1
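
# Assumes a copy of the CogVLM config.json sits next to this handler; it
# supplies the vision tower's image and patch sizes used below.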
config = PretrainedConfig.from_json_file('config.json')


class EndpointHandler:
    def __init__(self, path=""):
        # Weights are fetched from the Hub; `path` is kept for the Inference
        # Endpoints handler interface.
        self.model = AutoModelForCausalLM.from_pretrained(
            'THUDM/cogvlm-chat-hf',
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to('cuda').eval()  # inputs are moved to CUDA below, so the model must be there too
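        # CogVLM pairs its weights with the Vicuna v1.5 tokenizer.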
        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Args:
            data (:obj:`dict`):
                The request payload. Expects an ``inputs`` dict carrying a
                ``query`` string and an ``image`` (a PIL image).
        Return:
            :obj:`str`: The answer generated by the model for the
            query/image pair.
        """

        def _history_to_prompt(signal_type, history, query):
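            # Flatten the (question, answer) history into CogVLM's
            # question/answer prompt format for the chosen template.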
            if signal_type == 'base':
                return query
            elif signal_type == 'vqa':
                answer_format = 'Short answer:'
            elif signal_type == 'chat':
                answer_format = 'Answer:'
            else:
                raise ValueError(f"Unknown signal type {signal_type}")

            prompt = ''
            for old_query, response in history:
                prompt += f'Question: {old_query} {answer_format} {response}\n'
            prompt += f'Question: {query} {answer_format}'
            return prompt

        def build_conversation_input_ids(
            tokenizer: "PreTrainedTokenizer",
            *,
            query: str,
            history: Optional[List[Tuple[str, str]]] = None,
            images: Optional[List["PIL.Image.Image"]] = None,
            template_version: Optional[Literal["base", "chat", "vqa"]] = None,
            config=config,
        ):
            image_size: int = config.vision_config['image_size']
            patch_size: int = config.vision_config['patch_size']
            template_version = template_version or config.template_version
            assert images is None or len(images) <= 1, "multiple images are not supported yet"
            history = history or []
            text = _history_to_prompt(template_version, history, query)

            input_ids = [tokenizer.bos_token_id]
            token_type_ids = [LANGUAGE_TOKEN_TYPE]
            if images is not None and len(images) == 1:
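                # Resize to the vision tower's input resolution and normalize
                # with CLIP's channel statistics.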
                transform = transforms.Compose(
                    [
                        transforms.Resize(
                            (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
                        ),
                        transforms.ToTensor(),
                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                    ]
                )
                images = [transform(images[0])]
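
                # Reserve one pad-token placeholder per image patch; the extra
                # two slots are for CogVLM's additional vision tokens.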
                vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
                input_ids += [tokenizer.pad_token_id] * vision_token_num
                token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
            text_ids = tokenizer.encode(text, add_special_tokens=False)

            input_ids += text_ids
            token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
            attention_mask = [1] * len(input_ids)

            return {
                'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'images': images,
            }

        inputs = data.pop("inputs", data)
        query = inputs.pop("query", "")
        image = inputs.pop("image", None)
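        # NOTE (assumption): `image` is expected to already be a PIL image.
        # If the payload carries base64-encoded bytes instead (typical for
        # JSON requests), decode it first, e.g.
        # PIL.Image.open(io.BytesIO(base64.b64decode(image))).convert('RGB').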

        gen_kwargs = {"max_length": 2048, "do_sample": False}  # greedy decoding

        inputs = build_conversation_input_ids(
            self.tokenizer,
            query=query,
            history=[],
            images=[image] if image is not None else None,
            template_version='vqa',
        )
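
        # Add the batch dimension and move everything onto the GPU; the model
        # expects `images` as a list over the batch, each a list of image tensors.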
        inputs = {
            'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
            'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]] if inputs['images'] is not None else None,
        }
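
        # Generate without tracking gradients, then drop the prompt tokens so
        # only the newly generated answer is decoded.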
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            prediction = self.tokenizer.decode(outputs[0])

        return prediction