# Author: lnyan — commit 628d57a ("Update")
import gradio as gr
import torch
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
T5Tokenizer)
import spaces
import numpy as np
import io
import base64
class HFEmbedder(nn.Module):
    """Frozen Hugging Face text encoder (CLIP or T5) wrapped as an ``nn.Module``.

    A ``version`` string starting with ``"openai"`` selects the CLIP
    tokenizer/model pair, and ``forward`` returns the model's
    ``pooler_output``. Any other ``version`` selects the T5 pair, and
    ``forward`` returns ``last_hidden_state``.

    Args:
        version: Hugging Face model id passed to ``from_pretrained``.
        max_length: Token length every prompt is padded/truncated to.
        **hf_kwargs: Extra keyword arguments forwarded to the model's
            ``from_pretrained`` (e.g. ``torch_dtype``).
    """

    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        if self.is_clip:
            self.output_key = "pooler_output"
            tokenizer_cls, encoder_cls = CLIPTokenizer, CLIPTextModel
        else:
            self.output_key = "last_hidden_state"
            tokenizer_cls, encoder_cls = T5Tokenizer, T5EncoderModel
        self.tokenizer = tokenizer_cls.from_pretrained(version, max_length=max_length)
        encoder = encoder_cls.from_pretrained(version, **hf_kwargs)
        # Inference only: eval mode and requires_grad_(False) freeze every weight.
        self.hf_module = encoder.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        """Encode a batch of prompts.

        Each prompt is tokenized to exactly ``max_length`` tokens (truncated or
        padded). The token ids are moved to the encoder's device, and the
        output selected by ``output_key`` is returned.
        """
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        input_ids = encoded["input_ids"].to(self.hf_module.device)
        # No attention mask is supplied, so padded positions are not masked by this call.
        outputs = self.hf_module(
            input_ids=input_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Load the T5 v1.1 XXL text encoder in bfloat16 and move it to *device*.

    Args:
        device: Device the encoder is moved to (defaults to ``"cuda"``).
        max_length: Token length every prompt is padded/truncated to.

    Returns:
        The frozen T5 ``HFEmbedder``.
    """
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    encoder = HFEmbedder(
        "lnyan/t5-v1_1-xxl-encoder",
        max_length=max_length,
        torch_dtype=torch.bfloat16,
    )
    # Honour the caller's device instead of a hard-coded "cuda".
    return encoder.to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Load the CLIP ViT-L/14 text encoder in bfloat16 and move it to *device*.

    Args:
        device: Device the encoder is moved to (defaults to ``"cuda"``).

    Returns:
        The frozen CLIP ``HFEmbedder`` (prompts padded/truncated to 77 tokens).
    """
    encoder = HFEmbedder(
        "openai/clip-vit-large-patch14",
        max_length=77,
        torch_dtype=torch.bfloat16,
    )
    # Honour the caller's device instead of a hard-coded "cuda".
    return encoder.to(device)
# @spaces.GPU(duration=20)
def load_encoders(*, is_schnell: bool = True, device: str | torch.device = "cuda"):
    """Load the T5 and CLIP text encoders.

    Args:
        is_schnell: When True (the default) the T5 encoder uses a 256-token
            max length; otherwise 512.
        device: Device forwarded to ``load_t5`` and ``load_clip``.

    Returns:
        A ``(t5, clip)`` tuple of encoders.
    """
    t5_max_length = 256 if is_schnell else 512
    t5 = load_t5(device, max_length=t5_max_length)
    clip = load_clip(device)
    return t5, clip
import numpy as np
def b64(txt, vec):
    """Serialize ``(txt, vec)`` with ``torch.save`` and return it as base64 text.

    The result is a UTF-8 ``str``. Decode it with ``base64.b64decode``, then
    ``torch.load`` it to recover the ``(txt, vec)`` tuple.
    """
    with io.BytesIO() as stream:
        torch.save((txt, vec), stream)
        payload = stream.getvalue()
    return base64.b64encode(payload).decode('utf-8')
# Load both encoders once, at import time; convert() reads these module globals.
t5, clip = load_encoders()
@spaces.GPU(duration=10)
def convert(prompt):
    """Encode *prompt* with the T5 and CLIP encoders and return the base64 payload.

    A bare string is wrapped in a one-element list; any other value is passed
    to the encoders unchanged. The return value is ``b64(t5_output, clip_output)``.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    return b64(t5(prompts), clip(prompts))
# Minimal Gradio UI: a prompt textbox and a Convert button, plus an output
# textbox that receives the base64-encoded (t5, clip) payload from convert().
with gr.Blocks() as demo:
    gr.Markdown("""A workaround for flux-flax to fit into 40G VRAM""")
    with gr.Row():
        # Left column: input prompt and the trigger button.
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            convert_btn = gr.Button(value="Convert")
        # Right column: serialized encoder outputs.
        with gr.Column():
            output = gr.Textbox(label="output")
    # api_name="convert" gives this event handler a named API endpoint.
    convert_btn.click(convert, inputs=prompt, outputs=output, api_name="convert")
demo.launch()