Spaces:
Runtime error
Runtime error
from huggingface_hub import hf_hub_url, cached_download | |
import streamlit as st | |
import io | |
import gc | |
######################################################################################################## | |
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM | |
######################################################################################################## | |
MODEL_REPO = 'BlinkDL/clip-guided-binary-autoencoder' | |
import torch, types | |
import numpy as np | |
from PIL import Image | |
import torch.nn as nn | |
from torch.nn import functional as F | |
import torchvision as vision | |
import torchvision.transforms as transforms | |
from torchvision.transforms import functional as VF | |
# Device string shared by every model and tensor below: the GPU when PyTorch
# can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class ToBinary(torch.autograd.Function):
    """Round to the nearest integer with a pass-through gradient.

    Forward computes ``floor(x + 0.5)``. Backward returns a copy of the
    incoming gradient unchanged, so the rounding step does not zero out the
    gradient of whatever produced ``x``.

    Use it through ``ToBinary.apply(x)``.
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # ``apply`` supplies the context object as the first argument.
    @staticmethod
    def forward(ctx, x):
        """Return ``floor(x + 0.5)`` element-wise."""
        # no need for noise when we have plenty of data
        return torch.floor(x + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the upstream gradient unchanged (as a fresh tensor)."""
        return grad_output.clone()  # pass-through
class ResBlock(nn.Module):
    """Two residual 3x3-convolution branches applied back to back.

    Each branch expands ``c_x`` channels to ``c_hidden`` and projects back to
    ``c_x``, so the output has the same shape as the input. Only the first
    branch is preceded by batch normalisation.

    Args:
        c_x: Channel count of the input and output.
        c_hidden: Channel count inside each branch.
    """

    def __init__(self, c_x, c_hidden):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        conv_opts = {'kernel_size': 3, 'padding': 1}
        self.B0 = nn.BatchNorm2d(c_x)
        self.C0 = nn.Conv2d(c_x, c_hidden, **conv_opts)
        self.C1 = nn.Conv2d(c_hidden, c_x, **conv_opts)
        self.C2 = nn.Conv2d(c_x, c_hidden, **conv_opts)
        self.C3 = nn.Conv2d(c_hidden, c_x, **conv_opts)

    def forward(self, x):
        """Apply both residual branches to ``x`` and return the result."""
        act = F.mish

        # First branch: norm -> act -> conv -> act -> conv, added to the input.
        hidden = self.C0(act(self.B0(x)))
        x = x + self.C1(act(hidden))

        # Second branch: conv -> act -> conv, added to the running value.
        hidden = self.C2(x)
        return x + self.C3(act(hidden))
class REncoderSmall(nn.Module):
    """Small convolutional encoder: RGB image -> ``args.my_img_bit`` channel map.

    The input is reduced by three ``pixel_unshuffle(…, 2)`` steps, so the
    output is 1/8 of the input resolution and the input height and width must
    be divisible by 8. A shortcut path unshuffles the stem output by 8 in one
    go and is added back in just before the output convolution.

    Args:
        args: Namespace-like object; only ``args.my_img_bit`` (number of
            output channels) is read.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        def conv3x3(c_in, c_out):
            # 3x3 convolution that keeps the spatial size (padding=1).
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        base = 8  # channel count right after the stem convolution
        hidden = 256  # width inside every residual branch after the stem
        c_s1, c_s2, c_s3 = base * 4, base * 16, base * 64

        # Normalisation for the direct 8x-unshuffle shortcut.
        self.Bxx = nn.BatchNorm2d(c_s3)

        # Stem plus a full-resolution residual branch.
        self.CIN = conv3x3(3, base)
        self.Cx0 = conv3x3(base, 32)
        self.Cx1 = conv3x3(32, base)

        # Stage at 1/2 resolution.
        self.B00 = nn.BatchNorm2d(c_s1)
        self.C00 = conv3x3(c_s1, hidden)
        self.C01 = conv3x3(hidden, c_s1)
        self.C02 = conv3x3(c_s1, hidden)
        self.C03 = conv3x3(hidden, c_s1)

        # Stage at 1/4 resolution.
        self.B10 = nn.BatchNorm2d(c_s2)
        self.C10 = conv3x3(c_s2, hidden)
        self.C11 = conv3x3(hidden, c_s2)
        self.C12 = conv3x3(c_s2, hidden)
        self.C13 = conv3x3(hidden, c_s2)

        # Stage at 1/8 resolution.
        self.B20 = nn.BatchNorm2d(c_s3)
        self.C20 = conv3x3(c_s3, hidden)
        self.C21 = conv3x3(hidden, c_s3)
        self.C22 = conv3x3(c_s3, hidden)
        self.C23 = conv3x3(hidden, c_s3)

        self.COUT = conv3x3(c_s3, args.my_img_bit)

    def forward(self, img):
        """Encode ``img`` and return sigmoid activations in ``[0, 1]``."""
        act = F.mish

        feat = self.CIN(img)
        # Shortcut taken straight from the stem output at 1/8 resolution.
        shortcut = self.Bxx(F.pixel_unshuffle(feat, 8))
        feat = feat + self.Cx1(act(self.Cx0(feat)))

        feat = F.pixel_unshuffle(feat, 2)
        feat = feat + self.C01(act(self.C00(act(self.B00(feat)))))
        feat = feat + self.C03(act(self.C02(feat)))

        feat = F.pixel_unshuffle(feat, 2)
        feat = feat + self.C11(act(self.C10(act(self.B10(feat)))))
        feat = feat + self.C13(act(self.C12(feat)))

        feat = F.pixel_unshuffle(feat, 2)
        feat = feat + self.C21(act(self.C20(act(self.B20(feat)))))
        feat = feat + self.C23(act(self.C22(feat)))

        logits = self.COUT(feat + shortcut)
        return torch.sigmoid(logits)
class RDecoderSmall(nn.Module):
    """Small convolutional decoder: ``args.my_img_bit`` channel map -> RGB image.

    Mirrors the small encoder's stages in reverse: three residual stages, each
    followed by ``pixel_shuffle(…, 2)``, so the output is 8x the input
    resolution.

    Args:
        args: Namespace-like object; only ``args.my_img_bit`` (number of
            input channels) is read.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        def conv3x3(c_in, c_out):
            # 3x3 convolution that keeps the spatial size (padding=1).
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        base = 8  # channel count at full output resolution
        hidden = 256  # width inside every residual branch before the tail
        c_s0, c_s1, c_s2 = base * 64, base * 16, base * 4

        self.CIN = conv3x3(args.my_img_bit, c_s0)

        # Stage at the input (1/8) resolution.
        self.B00 = nn.BatchNorm2d(c_s0)
        self.C00 = conv3x3(c_s0, hidden)
        self.C01 = conv3x3(hidden, c_s0)
        self.C02 = conv3x3(c_s0, hidden)
        self.C03 = conv3x3(hidden, c_s0)

        # Stage after the first 2x upsample.
        self.B10 = nn.BatchNorm2d(c_s1)
        self.C10 = conv3x3(c_s1, hidden)
        self.C11 = conv3x3(hidden, c_s1)
        self.C12 = conv3x3(c_s1, hidden)
        self.C13 = conv3x3(hidden, c_s1)

        # Stage after the second 2x upsample.
        self.B20 = nn.BatchNorm2d(c_s2)
        self.C20 = conv3x3(c_s2, hidden)
        self.C21 = conv3x3(hidden, c_s2)
        self.C22 = conv3x3(c_s2, hidden)
        self.C23 = conv3x3(hidden, c_s2)

        # Full-resolution residual branch and RGB projection.
        self.Cx0 = conv3x3(base, 32)
        self.Cx1 = conv3x3(32, base)
        self.COUT = conv3x3(base, 3)

    def forward(self, code):
        """Decode ``code`` and return sigmoid activations in ``[0, 1]``."""
        act = F.mish

        feat = self.CIN(code)
        feat = feat + self.C01(act(self.C00(act(self.B00(feat)))))
        feat = feat + self.C03(act(self.C02(feat)))

        feat = F.pixel_shuffle(feat, 2)
        feat = feat + self.C11(act(self.C10(act(self.B10(feat)))))
        feat = feat + self.C13(act(self.C12(feat)))

        feat = F.pixel_shuffle(feat, 2)
        feat = feat + self.C21(act(self.C20(act(self.B20(feat)))))
        feat = feat + self.C23(act(self.C22(feat)))

        feat = F.pixel_shuffle(feat, 2)
        feat = feat + self.Cx1(act(self.Cx0(feat)))

        rgb_logits = self.COUT(feat)
        return torch.sigmoid(rgb_logits)
class REncoderLarge(nn.Module):
    """ResBlock-based encoder: RGB image -> ``args.my_img_bit`` channel map.

    After a normalised stem and one full-resolution residual branch, the
    features pass through three rounds of ``pixel_unshuffle(…, 2)`` followed
    by a ``ResBlock``, so the output is 1/8 of the input resolution.

    Args:
        args: Namespace-like object; only ``args.my_img_bit`` (number of
            output channels) is read.
        dd: Channel count after the stem convolution.
        ee: Hidden width of the full-resolution residual branch.
        ff: Hidden width passed to every ``ResBlock``.
    """

    def __init__(self, args, dd, ee, ff):
        super().__init__()
        self.args = args

        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        conv_opts = {'kernel_size': 3, 'padding': 1}

        self.CXX = nn.Conv2d(3, dd, **conv_opts)
        self.BXX = nn.BatchNorm2d(dd)
        self.CX0 = nn.Conv2d(dd, ee, **conv_opts)
        self.CX1 = nn.Conv2d(ee, dd, **conv_opts)

        # Each 2x unshuffle multiplies the channel count by 4.
        self.R0 = ResBlock(dd * 4, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 64, ff)

        self.CZZ = nn.Conv2d(dd * 64, args.my_img_bit, **conv_opts)

    def forward(self, x):
        """Encode ``x`` and return sigmoid activations in ``[0, 1]``."""
        act = F.mish

        x = self.BXX(self.CXX(x))
        x = x + self.CX1(act(self.CX0(x)))

        # Downsample by 2, then refine, three times in a row.
        for stage in (self.R0, self.R1, self.R2):
            x = stage(F.pixel_unshuffle(x, 2))

        return torch.sigmoid(self.CZZ(x))
class RDecoderLarge(nn.Module):
    """ResBlock-based decoder: ``args.my_img_bit`` channel map -> RGB image.

    Runs three rounds of ``ResBlock`` followed by ``pixel_shuffle(…, 2)``, so
    the output is 8x the input resolution, then one full-resolution residual
    branch and a projection to 3 channels.

    Args:
        args: Namespace-like object; only ``args.my_img_bit`` (number of
            input channels) is read.
        dd: Channel count at full output resolution.
        ee: Hidden width of the full-resolution residual branch.
        ff: Hidden width passed to every ``ResBlock``.

    The three widths are treated as a set. When any of them is omitted, all
    three are derived from the module-level ``model_prefix`` name, which keeps
    the older ``RDecoderLarge(args)`` call form working.

    Raises:
        ValueError: If the widths are omitted and ``model_prefix`` is missing
            or names neither a ``d16_512`` nor a ``d32_1024`` model.
    """

    def __init__(self, args, dd=None, ee=None, ff=None):
        super().__init__()
        self.args = args

        if dd is None or ee is None or ff is None:
            # Legacy path: infer the widths from the globally selected model.
            prefix = globals().get('model_prefix') or ''
            if 'd16_512' in prefix:
                dd, ee, ff = 16, 64, 512
            elif 'd32_1024' in prefix:
                dd, ee, ff = 32, 128, 1024
            else:
                raise ValueError(
                    'RDecoderLarge needs explicit dd/ee/ff or a recognised '
                    f'model_prefix (got {prefix!r})')

        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        conv_opts = {'kernel_size': 3, 'padding': 1}

        self.CZZ = nn.Conv2d(args.my_img_bit, dd * 64, **conv_opts)
        self.BZZ = nn.BatchNorm2d(dd * 64)

        # Each 2x shuffle divides the channel count by 4.
        self.R0 = ResBlock(dd * 64, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 4, ff)

        self.CX0 = nn.Conv2d(dd, ee, **conv_opts)
        self.CX1 = nn.Conv2d(ee, dd, **conv_opts)
        self.CXX = nn.Conv2d(dd, 3, **conv_opts)

    def forward(self, x):
        """Decode ``x`` and return sigmoid activations in ``[0, 1]``."""
        act = F.mish

        x = self.BZZ(self.CZZ(x))

        # Refine, then upsample by 2, three times in a row.
        for stage in (self.R0, self.R1, self.R2):
            x = F.pixel_shuffle(stage(x), 2)

        x = x + self.CX1(act(self.CX0(x)))
        return torch.sigmoid(self.CXX(x))
def prepare_model(model_prefix):
    """Build the encoder/decoder pair for ``model_prefix`` and load its weights.

    The checkpoint files ``{model_prefix}-E.pth`` and ``{model_prefix}-D.pth``
    are fetched from ``MODEL_REPO`` and loaded into freshly constructed
    networks placed on ``device`` in eval mode.

    Args:
        model_prefix: Checkpoint name prefix. The exact small-model name
            selects the small architecture; otherwise the prefix must contain
            ``d16_512`` or ``d32_1024`` to select the large architecture's
            channel widths.

    Returns:
        Tuple ``(encoder, decoder)``.

    Raises:
        ValueError: If ``model_prefix`` matches none of the known models.
    """
    # Run a collection pass before allocating a new pair of networks.
    gc.collect()

    if model_prefix == 'out-v7c_d8_256-224-13bit-OB32x0.5-745':
        R_ENCODER, R_DECODER = REncoderSmall, RDecoderSmall
    else:
        if 'd16_512' in model_prefix:
            dd, ee, ff = 16, 64, 512
        elif 'd32_1024' in model_prefix:
            dd, ee, ff = 32, 128, 1024
        else:
            # Fail here with a clear message instead of a NameError on the
            # unbound widths when the lambdas below are called.
            raise ValueError(f'Unknown model prefix: {model_prefix!r}')
        R_ENCODER, R_DECODER = ((lambda args: REncoderLarge(args, dd, ee, ff)),
                                (lambda args: RDecoderLarge(args, dd, ee, ff)))

    args = types.SimpleNamespace()
    args.my_img_bit = 13

    encoder = R_ENCODER(args).eval().to(device)
    decoder = R_DECODER(args).eval().to(device)

    # NOTE(review): cached_download is deprecated in newer huggingface_hub
    # releases in favour of hf_hub_download; confirm against the pinned version.
    for network, suffix in ((encoder, 'E'), (decoder, 'D')):
        checkpoint_path = cached_download(
            hf_hub_url(MODEL_REPO, f'{model_prefix}-{suffix}.pth'))
        # map_location lets checkpoints saved from a GPU load on a CPU-only host.
        network.load_state_dict(
            torch.load(checkpoint_path, map_location=device))

    return encoder, decoder
def encode(model_prefix, img):
    """Encode a PIL image into a rounded code array.

    The image is converted to RGB, turned into a float tensor, resized to
    224x224 and given a leading batch axis of size 1 before it is passed
    through the encoder. The encoder output is rounded with ``ToBinary``.

    Args:
        model_prefix: Checkpoint name prefix passed to ``prepare_model``.
        img: Input image; must provide ``convert("RGB")`` (e.g. a PIL image).

    Returns:
        NumPy array holding the rounded encoder output.
    """
    encoder, _ = prepare_model(model_prefix)

    preprocess = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Resize((224, 224)),
    ])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        rgb = img.convert("RGB")
        batch = preprocess(rgb).unsqueeze(0).to(device)
        code = ToBinary.apply(encoder(batch))

    return code.cpu().numpy()
def decode(model_prefix, z):
    """Decode a code array back into a PIL image.

    Args:
        model_prefix: Checkpoint name prefix passed to ``prepare_model``.
        z: Array-like code with a leading batch axis, as produced by
            ``encode``; only the first batch entry is returned.

    Returns:
        PIL image built from the first decoded batch entry.
    """
    _, decoder = prepare_model(model_prefix)
    # Inference only: skip building an autograd graph, matching encode().
    with torch.no_grad():
        decoded = decoder(torch.Tensor(z).to(device))
    return VF.to_pil_image(decoded[0])
# ---------------------------------------------------------------------------
# Streamlit page: pick a model, then either encode an image or decode a code.
# ---------------------------------------------------------------------------
st.title("clip-guided-binary-autoencoder")

# The selected checkpoint name is shared by both tabs.
model_prefix = st.selectbox(
    'The model to use',
    (
        'out-v7c_d8_256-224-13bit-OB32x0.5-745',
        'out-v7d_d16_512-224-13bit-OB32x0.5-2487',
        'out-v7d_d32_1024-224-13bit-OB32x0.5-5560',
    ),
)

encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])

# Encode tab: upload an image, offer the encoded array as a .npy download and
# show what the decoder reconstructs from it.
with encoder_tab:
    col_in, col_out = st.columns(2)
    uploaded_file = col_in.file_uploader('Choose an Image')
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        col_in.image(image, 'Input Image')
        z = encode(model_prefix, image)

        # Serialise the code in NumPy's .npy format into an in-memory buffer.
        with io.BytesIO() as buffer:
            np.save(buffer, z)
            download_name = uploaded_file.name + '.npy'
            col_out.download_button(
                label="Download Encoded Data",
                data=buffer,
                file_name=download_name,
            )

        preview = decode(model_prefix, z)
        col_out.image(preview, 'Output Image preview')

# Decode tab: upload a previously saved .npy code and show the decoded image.
with decoder_tab:
    col_in, col_out = st.columns(2)
    uploaded_file = col_in.file_uploader('Choose an Encoded Data')
    if uploaded_file is not None:
        z = np.load(uploaded_file)
        image = decode(model_prefix, z)
        col_out.image(image, 'Output Image')