Spaces:
Runtime error
Runtime error
import torch | |
from sat.model.official import ChatGLMModel | |
from sat.model.base_model import BaseMixin | |
from copy import deepcopy | |
import json | |
from .blip2 import BLIP2 | |
from sat.resources.urls import MODEL_URLS | |
# Register a URL under the 'visualglm-6b' key of sat's MODEL_URLS mapping.
# NOTE(review): presumably sat consults this mapping to fetch weights when the
# model is requested by name -- confirm against sat.resources.urls.
MODEL_URLS['visualglm-6b'] = 'https://cloud.tsinghua.edu.cn/f/348b98dffcc940b6a09d/?dl=1' | |
class ImageMixin(BaseMixin):
    """Mixin that splices BLIP2 image embeddings into the word embeddings.

    The mixin owns a ``BLIP2`` module and overrides ``word_embedding_forward``
    so that a run of placeholder tokens in ``input_ids`` is replaced by the
    embeddings BLIP2 produces for the supplied image.
    """

    def __init__(self, args):
        """Store a private copy of ``args`` and build the BLIP2 encoder.

        Args:
            args: Namespace-like configuration. ``args.eva_args`` and
                ``args.qformer_args`` are passed to ``BLIP2``;
                ``args.image_length`` is read later by
                ``word_embedding_forward``.
        """
        super().__init__()
        # Deep copy so later mutation of the caller's args cannot change the
        # image_length this mixin splits on.
        self.args = deepcopy(args)
        self.model = BLIP2(args.eva_args, args.qformer_args)

    def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
        """Embed ``input_ids``, substituting image embeddings for the pads.

        Args:
            input_ids: Token-id tensor; it is split along ``dim=1``
                (assumed to be the sequence axis -- TODO confirm layout).
            output_cross_layer: Unused here; kept for the hook signature.
            **kw_args: Forward keyword arguments. ``image`` (optional) is the
                image input and ``pre_image`` is the sequence index at which
                the image placeholders start. All of ``kw_args`` is forwarded
                to the BLIP2 module.

        Returns:
            Plain word embeddings when no image is supplied or when
            ``pre_image`` lies beyond the current sequence length; otherwise
            the concatenation ``[text before, image embeddings, text after]``
            along ``dim=1``.

        Raises:
            KeyError: If an image is supplied without ``pre_image``.
        """
        # NOTE(review): self.transformer is not assigned in this class;
        # presumably the host model attaches it when the mixin is registered.
        #
        # Check for the image first: a text-only call that omits ``pre_image``
        # should fall back to plain embeddings rather than raise KeyError.
        if kw_args.get("image") is None or kw_args["pre_image"] > input_ids.shape[1]:
            return self.transformer.word_embeddings(input_ids)

        image_emb = self.model(**kw_args)

        # The image goes right after the "Question:<img>" prompt prefix and
        # overrides the ``image_length`` placeholder pad positions.
        image_start = kw_args["pre_image"]
        image_end = image_start + self.args.image_length
        pre_id, _pads, post_id = torch.tensor_split(
            input_ids, [image_start, image_end], dim=1
        )
        pre_txt_emb = self.transformer.word_embeddings(pre_id)
        post_txt_emb = self.transformer.word_embeddings(post_id)
        # Assumes image_emb spans image_length positions on dim=1 so the
        # overall sequence length is preserved -- TODO confirm against BLIP2.
        return torch.cat([pre_txt_emb, image_emb, post_txt_emb], dim=1)
class VisualGLMModel(ChatGLMModel):
    """ChatGLM model extended with an ``ImageMixin`` registered as ``"eva"``."""

    def __init__(self, args, transformer=None, **kwargs):
        """Build the base ChatGLM model and attach the image mixin.

        Args:
            args: Model configuration; must provide ``image_length`` plus the
                ``eva_args`` / ``qformer_args`` consumed by ``ImageMixin``.
            transformer: Optional transformer passed through to the base class.
            **kwargs: Extra keyword arguments passed through to the base class.
        """
        super().__init__(args, transformer=transformer, **kwargs)
        self.image_length = args.image_length
        self.add_mixin("eva", ImageMixin(args))

    # The first parameter is ``cls`` and the method is meant to be invoked on
    # the class itself; without ``@classmethod`` such a call would bind
    # ``parser`` to ``cls`` and fail with a missing-argument TypeError.
    @classmethod
    def add_model_specific_args(cls, parser):
        """Register the VisualGLM command-line options on ``parser``.

        Adds ``--image_length`` (int, default 32) and the JSON-decoded
        ``--eva_args`` / ``--qformer_args`` options (default empty dict) in a
        "VisualGLM" argument group, then delegates to the parent class so its
        own options are registered as well.

        Args:
            parser: An ``argparse``-style parser exposing ``add_argument_group``.

        Returns:
            Whatever the parent class's ``add_model_specific_args`` returns.
        """
        group = parser.add_argument_group('VisualGLM', 'VisualGLM Configurations')
        group.add_argument('--image_length', type=int, default=32)
        group.add_argument('--eva_args', type=json.loads, default={})
        group.add_argument('--qformer_args', type=json.loads, default={})
        return super().add_model_specific_args(parser)