# image_to_text/handler.py
from typing import Any, Dict, List, Literal, Optional, Tuple

import torch
import PIL.Image
import triton  # noqa: F401 -- imported only so the dependency ships with the endpoint image
import google.protobuf  # noqa: F401 -- the protobuf package is imported as `google.protobuf`, not `protobuf`
from torchvision import transforms
from transformers import AutoModelForCausalLM, LlamaTokenizer, PretrainedConfig

# Token-type ids CogVLM uses to tell text tokens apart from image-patch tokens.
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1

# Model configuration shipped next to this handler; provides vision_config and template_version.
config = PretrainedConfig.from_json_file('config.json')


class EndpointHandler:
    def __init__(self, path=""):
        # Load CogVLM in bfloat16; the weights are moved to the GPU that the
        # request tensors are sent to in __call__.
        self.model = AutoModelForCausalLM.from_pretrained(
            'THUDM/cogvlm-chat-hf',
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            # cache_dir='/tmp'
        ).to('cuda').eval()
        # CogVLM reuses the Vicuna tokenizer.
        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Args:
            data (:obj:`dict`):
                The request payload. Its "inputs" entry is expected to contain:
                - "query": the question or instruction to ask about the image.
                - "image": the PIL image the question refers to.
        Return:
            A :obj:`str`: the answer decoded from the model's generation.
        """
        def _history_to_prompt(signal_type, history, query):
            if signal_type == 'base':
                return query
            elif signal_type == 'vqa':
                answer_format = 'Short answer:'
            elif signal_type == 'chat':
                answer_format = 'Answer:'
            else:
                assert False, f"Unknown signal type {signal_type}"
            prompt = ''
            for i, (old_query, response) in enumerate(history):
                prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
            prompt += 'Question: {} {}'.format(query, answer_format)
            return prompt
        def build_conversation_input_ids(
            tokenizer: "PreTrainedTokenizer",
            *,
            query: str,
            history: Optional[List[Tuple[str, str]]] = None,
            images: Optional[List["PIL.Image"]] = None,
            template_version: Optional[Literal["base", "chat", "vqa"]] = None,
            config=config
        ):
            image_size: int = config.vision_config['image_size']
            patch_size: int = config.vision_config['patch_size']
            template_version = template_version or config.template_version
            assert images is None or len(images) <= 1, "multiple images are not supported yet"
            history = history or []
            text = _history_to_prompt(template_version, history, query)
            input_ids = [tokenizer.bos_token_id]
            token_type_ids = [LANGUAGE_TOKEN_TYPE]
            if images is not None and len(images) == 1:
                # vision: resize to the model's input resolution and normalize with CLIP statistics
                transform = transforms.Compose(
                    [
                        transforms.Resize(
                            (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
                        ),
                        transforms.ToTensor(),
                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                    ]
                )
                images = [transform(images[0])]
                # language: reserve one placeholder token per image patch, plus two boundary tokens
                vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
                input_ids += [tokenizer.pad_token_id] * vision_token_num
                token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
            text_ids = tokenizer.encode(text, add_special_tokens=False)
            input_ids += text_ids
            token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
            attention_mask = [1] * len(input_ids)
            return {
                'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'images': images,
            }
        # Unpack the request payload.
        inputs = data.pop("inputs", data)
        query = inputs.pop("query")
        image = inputs.pop("image")
        gen_kwargs = {"max_length": 2048, "do_sample": False}
        # Build the multimodal prompt, add a batch dimension and move everything to the GPU.
        inputs = build_conversation_input_ids(self.tokenizer, query=query, history=[], images=[image],
                                              template_version='vqa')
        inputs = {
            'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
            'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
        }
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            # Keep only the newly generated tokens, dropping the prompt.
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
        prediction = self.tokenizer.decode(outputs[0])
        return prediction
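

# A minimal local smoke test, assuming a CUDA GPU is available and that an
# `example.jpg` sits next to this file; the file name and the query are
# illustrative assumptions only. The deployed endpoint receives the same
# payload shape ({"inputs": {"query": ..., "image": ...}}) via HTTP instead.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = PIL.Image.open("example.jpg").convert("RGB")
    payload = {"inputs": {"query": "Describe this image.", "image": image}}
    print(handler(payload))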