# NOTE(review): removed scrape artifact ("Spaces: / Runtime error / Runtime error")
# that was prepended by the page capture — it is not part of this module.
import os | |
from PIL import Image | |
from io import BytesIO | |
import base64 | |
import re | |
import argparse | |
import torch | |
from transformers import AutoTokenizer | |
from sat.model.mixins import CachedAutoregressiveMixin | |
from sat.quantization.kernels import quantize | |
import hashlib | |
from .visualglm import VisualGLMModel | |
def get_infer_setting(gpu_device=0, quant=None):
    """Load the VisualGLM-6B model and its tokenizer for inference.

    Args:
        gpu_device: CUDA device index exposed via ``CUDA_VISIBLE_DEVICES``.
        quant: quantization bit-width; one of ``None`` (fp16), ``4`` or ``8``.

    Returns:
        ``(model, tokenizer)`` — the model in eval mode on the GPU, and the
        ChatGLM-6B tokenizer.

    Raises:
        ValueError: if ``quant`` is not one of ``None``, ``4``, ``8``.
    """
    # Validate BEFORE the expensive checkpoint load; the original used
    # `assert` afterwards, which is stripped under `python -O` and wastes
    # a full model load on bad input.
    if quant not in (None, 4, 8):
        raise ValueError(f"quant must be None, 4 or 8, got {quant!r}")
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_device)
    args = argparse.Namespace(
        fp16=True,
        skip_init=True,
        # Quantized weights are prepared on CPU first, then moved to GPU.
        device='cuda' if quant is None else 'cpu',
    )
    model, args = VisualGLMModel.from_pretrained('visualglm-6b', args)
    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    if quant is not None:
        quantize(model.transformer, quant)
    model.eval()
    model = model.cuda()
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    return model, tokenizer
def is_chinese(text):
    """Return True if *text* contains at least one common Chinese character.

    Only the CJK Unified Ideographs range (U+4E00..U+9FA5) is checked;
    CJK punctuation and extension blocks do not count.
    """
    # Raw string for the regex per convention; returning an explicit bool
    # (instead of Match-or-None) matches the predicate name while keeping
    # truthiness identical for existing callers.  `re` caches compiled
    # patterns, so calling re.search directly is not a recompile per call.
    return re.search(r'[\u4e00-\u9fa5]+', text) is not None
def generate_input(input_text, input_image_prompt, history=None, input_para=None, image_is_encoded=True):
    """Bundle a text query, an image and the chat history into the input dict.

    Args:
        input_text: the user's text query.
        input_image_prompt: a base64-encoded image string when
            ``image_is_encoded`` is True, otherwise an already-decoded image
            object passed through unchanged.
        history: chat-history list; defaults to a fresh empty list.
        input_para: optional generation kwargs, passed through unchanged.
        image_is_encoded: whether ``input_image_prompt`` must be
            base64-decoded and opened with PIL.

    Returns:
        dict with keys ``'input_query'``, ``'input_image'``, ``'history'``,
        ``'gen_kwargs'``.
    """
    # `history=None` sentinel avoids the shared-mutable-default pitfall the
    # original `history=[]` had: every call now gets its own list.
    if history is None:
        history = []
    if not image_is_encoded:
        image = input_image_prompt
    else:
        decoded_image = base64.b64decode(input_image_prompt)
        image = Image.open(BytesIO(decoded_image))
    return {
        'input_query': input_text,
        'input_image': image,
        'history': history,
        'gen_kwargs': input_para,
    }
def process_image(image_encoded):
    """Decode a base64 image, cache it under ``./examples/``, return its path.

    The file name is the SHA-256 hex digest of the raw pixel bytes, so the
    same image is written to disk only once.

    Args:
        image_encoded: base64-encoded image data.

    Returns:
        Absolute path of the cached PNG file.
    """
    decoded_image = base64.b64decode(image_encoded)
    image = Image.open(BytesIO(decoded_image))
    # Hash the decoded pixels (not the base64 text) so re-encodings of the
    # same image map to the same cache entry.
    image_hash = hashlib.sha256(image.tobytes()).hexdigest()
    # Ensure the cache directory exists; without this the original raised
    # FileNotFoundError on image.save() when ./examples/ was missing.
    os.makedirs('./examples', exist_ok=True)
    image_path = f'./examples/{image_hash}.png'
    if not os.path.isfile(image_path):
        image.save(image_path)
    return os.path.abspath(image_path)