3DGen-Arena / offline /utils.py
ZhangYuhan's picture
add serve
7c1eee1
raw
history blame
No virus
2.06 kB
import os
import json
import base64
import requests
import numpy as np
import matplotlib.pyplot as plt
# MAX_LEN = 40
# STEP = 2
# x = np.arange(0, MAX_LEN, STEP)
# token_counts = [0] * (MAX_LEN//STEP)
# with open("prompts.json", 'r') as f:
# prompts = json.load(f)
# for prompt in prompts:
# tokens = len(prompt.strip().split(' '))
# token_counts[min(tokens//STEP, MAX_LEN//STEP-1)] += 1
# plt.xticks(x, x+1)
# plt.xlabel("token counts")
# plt.bar(x, token_counts, width=1.3)
# # plt.show()
# plt.savefig("token_counts.png")
## Generate image prompts
# Render every text prompt in prompts.json through the Stability AI v1
# text-to-image endpoint and save the returned images to
# ./images/<prompt index>/v1_txt2img_<sample index>.png.
with open("prompts.json", encoding="utf-8") as f:
    text_prompts = json.load(f)
engine_id = "stable-diffusion-v1-6"
api_host = os.getenv('API_HOST', 'https://api.stability.ai')
# The key must come from the environment. A key literal used to be committed
# here as the getenv() fallback: that leaked the secret and made the check
# below unreachable. The previously committed key should be revoked.
api_key = os.getenv("STABILITY_API_KEY")
if not api_key:
    raise RuntimeError("Missing Stability API key.")
# Prompts with an index below start_idx are skipped so an interrupted run can
# be resumed. The default of 21 reproduces the old hard-coded `idx <= 20` skip.
start_idx = int(os.getenv("START_IDX", "21"))
# requests waits indefinitely unless a timeout is given; fail a single prompt
# instead of hanging the whole batch.
request_timeout = float(os.getenv("REQUEST_TIMEOUT", "300"))
for idx, text in enumerate(text_prompts):
    if idx < start_idx:
        continue
    print(f"Start generate prompt[{idx}]: {text}")
    try:
        response = requests.post(
            f"{api_host}/v1/generation/{engine_id}/text-to-image",
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json={
                "text_prompts": [
                    {
                        "text": text.strip(),
                    }
                ],
                "cfg_scale": 7,
                "height": 1024,
                "width": 1024,
                "samples": 3,
                "steps": 30,
            },
            timeout=request_timeout,
        )
    except requests.RequestException as err:
        # Treat transport errors (timeout, connection reset, ...) like a
        # non-200 reply: report and move on to the next prompt.
        print(f"{idx} Failed!!! {err}")
        continue
    if response.status_code != 200:
        print(f"{idx} Failed!!! {response.text}")
        continue
    print("Finished!")
    data = response.json()
    img_dir = f"./images/{idx}"
    os.makedirs(img_dir, exist_ok=True)
    # Each artifact carries one image as a base64-encoded string.
    for i, image in enumerate(data["artifacts"]):
        img_path = f"{img_dir}/v1_txt2img_{i}.png"
        with open(img_path, "wb") as f:
            f.write(base64.b64decode(image["base64"]))