# minDALLE / app.py — Hugging Face Space by valhalla.
# Last commit: "Update app.py" (file size 3.23 kB; Hub security scan: no virus detected).
import html
import os
import sys

import clip
import numpy as np
import streamlit as st
from PIL import Image

from dalle.models import Dalle
from dalle.utils.utils import clip_score, download
# Pretrained minDALL-E 1.3B checkpoint archive. The second-to-last URL path
# segment looks like an MD5 digest, but nothing in this file verifies it.
url = "https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz"
root = os.path.expanduser("~/.cache/minDALLE")
filename = os.path.basename(url)        # "1.3B.tar.gz"
pathname = filename[:-len('.tar.gz')]   # "1.3B" -- expected name of the extracted directory
result_path = os.path.join(root, pathname)

# Only fetch the archive when the extracted checkpoint is not cached yet.
# NOTE(review): assumes `download` returns the path of the extracted
# checkpoint directory -- confirm against dalle.utils.utils.download.
if not os.path.exists(result_path):
    result_path = download(url, root)

device = "cpu"
# Load the DALL-E model from the local checkpoint path prepared above, and the
# CLIP model (plus its preprocessing transform) used to re-rank samples.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so both models are reloaded each time; consider caching them
# (e.g. st.cache_resource on Streamlit >= 1.18).
model = Dalle.from_pretrained(result_path)
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
def sample(prompt, num_candidates=3):
    """Generate images for ``prompt`` and return them in CLIP-ranked order.

    Draws ``num_candidates`` samples from the module-level minDALL-E ``model``,
    re-orders them with ``clip_score`` (using ``model_clip``/``preprocess_clip``)
    and converts each one to a PIL image.

    Args:
        prompt: Text prompt to condition generation on.
        num_candidates: Number of images to sample. Defaults to 3, which
            matches the three-column grid the UI renders.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per sampled candidate,
        ordered by the index array returned from ``clip_score``.
    """
    # Sampling. The result is moved to host memory as a NumPy array because
    # the transpose / fancy indexing below operate on arrays.
    # NOTE(review): assumes model.sampling returns a torch tensor with values
    # in [0, 1] -- TODO confirm.
    images = (
        model.sampling(
            prompt=prompt,
            top_k=256,
            top_p=None,
            softmax_temperature=1.0,
            num_candidates=num_candidates,
            device=device,
        )
        .cpu()
        .numpy()
    )
    # Move axis 1 to the end (presumably channels-first -> channels-last,
    # the layout PIL expects).
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP re-ranking: `rank` is used as an index array to reorder the images.
    rank = clip_score(
        prompt=prompt,
        images=images,
        model_clip=model_clip,
        preprocess_clip=preprocess_clip,
        device=device,
    )
    images = images[rank]

    # Convert each float image to 8-bit. Clipping first keeps values slightly
    # outside [0, 1] from wrapping around in the uint8 cast. (The original loop
    # built these images but never appended them, always returning [].)
    return [
        Image.fromarray((np.clip(img, 0.0, 1.0) * 255).astype(np.uint8))
        for img in images
    ]
st.subheader("Generate images from text")
prompt = st.text_input("What do you want to see?")

# NOTE(review): DEBUG is not read anywhere in this file.
DEBUG = False

# Skip the (expensive) generation for empty or whitespace-only prompts.
if prompt.strip():
    # Placeholder holding a loading banner while sampling runs; it is cleared
    # once the images are on screen.
    container = st.empty()
    # `prompt` is user input rendered as raw HTML (unsafe_allow_html=True),
    # so escape it to keep markup/script injection out of the page.
    safe_prompt = html.escape(prompt)
    container.markdown(
        f"""
<style> p {{ margin:0 }} div {{ margin:0 }} </style>
<div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
<div class="stAlert">
<div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
<div class="st-b7">
<div class="css-whx05o e13vu3m50">
<div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
Generating predictions for: <b>{safe_prompt}</b>
<small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
</div>
</div>
</div>
</div>
</div>
</div>
""",
        unsafe_allow_html=True,
    )

    print(f"Getting selections: {prompt}")
    selected = sample(prompt)

    margin = 0.1  # for better position of zoom in arrow
    n_columns = 3
    # Image columns alternate with narrow spacer columns: [img, gap, img, gap, img];
    # images therefore go into the even-indexed columns.
    cols = st.columns([1] + [margin, 1] * (n_columns - 1))
    for i, img in enumerate(selected):
        cols[(i % n_columns) * 2].image(img)

    # Generation is done: remove the "please stand by" banner.
    container.empty()

    st.button("Again!", key="again_button")