# moondream1 / MyPipe.py
# Uploaded by not-lain in commit 73bb7ab ("add pipeline"); 1.17 kB.
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import Pipeline, AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer
from PIL import Image
from typing import Union
class VQA(Pipeline):
    """Visual question answering pipeline around a moondream1-style model.

    The pipeline takes an image (a file path or an already-loaded
    ``PIL.Image.Image``) plus a ``question`` keyword argument, and returns
    whatever ``self.model.answer_question`` produces for that pair.

    The wrapped model is expected to expose ``encode_image(image)`` and
    ``answer_question(encoded_image, question, tokenizer)``; both are called
    in ``_forward``.
    """

    # Checkpoint the tokenizer is loaded from when the caller supplies none.
    DEFAULT_TOKENIZER = "vikhyatk/moondream1"

    def __init__(self, **kwargs):
        """Initialise the base pipeline and make sure a tokenizer is available.

        Args:
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(**kwargs)
        # Respect a tokenizer passed in by the caller; only fall back to the
        # default checkpoint when the base class ended up without one.
        if getattr(self, "tokenizer", None) is None:
            self.tokenizer = Tokenizer.from_pretrained(self.DEFAULT_TOKENIZER)

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments between the three pipeline stages.

        Only ``question`` is recognised; it is routed to ``_forward``. Any
        other keyword argument is ignored here.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)``
            tuple of dicts. The first and last are always empty.
        """
        forward_params = {}
        if "question" in kwargs:
            forward_params["question"] = kwargs["question"]
        return {}, forward_params, {}

    def preprocess(self, inputs: Union[str, Image.Image]):
        """Turn the pipeline input into an image object.

        Args:
            inputs: Either a string, which is handed to ``Image.open``, or
                any other object (expected to be a ``PIL.Image.Image``),
                which is passed through untouched.

        Returns:
            The opened image for string inputs, otherwise ``inputs`` itself.
        """
        if isinstance(inputs, str):
            return Image.open(inputs)
        return inputs

    def _forward(self, inputs, question):
        """Encode the image and ask the model the question about it.

        Args:
            inputs: The image returned by ``preprocess``.
            question: The question to ask. It has no default, so it must be
                supplied (normally via the ``question=`` keyword routed by
                ``_sanitize_parameters``).

        Returns:
            The value returned by ``self.model.answer_question``.
        """
        enc_image = self.model.encode_image(inputs)
        return self.model.answer_question(enc_image, question, self.tokenizer)

    def postprocess(self, out):
        """Return the model output unchanged."""
        return out