"""Batch-generate images for the prompts in ``prompts.json`` via the Stability AI REST API.

For each prompt whose index is >= ``START_IDX``, request images from the
``/v1/generation/<engine>/text-to-image`` endpoint and write each returned
artifact to ``./images/<idx>/v1_txt2img_<i>.png``. Failed prompts are reported
and skipped so one bad request does not abort the whole batch.

Configuration (environment variables):
    STABILITY_API_KEY  required; sent as a Bearer token.
    API_HOST           optional; defaults to https://api.stability.ai.
    START_IDX          optional; first prompt index to generate (default 21,
                       matching the previously hard-coded resume point).
"""
import os
import json
import base64
import requests
import numpy as np  # only referenced by the commented-out histogram snippet below
import matplotlib.pyplot as plt  # only referenced by the commented-out histogram snippet below

# Token-count histogram of the prompts (kept for reference; not executed).
# MAX_LEN = 40
# STEP = 2
# x = np.arange(0, MAX_LEN, STEP)
# token_counts = [0] * (MAX_LEN//STEP)
# with open("prompts.json", 'r') as f:
#     prompts = json.load(f)
# for prompt in prompts:
#     tokens = len(prompt.strip().split(' '))
#     token_counts[min(tokens//STEP, MAX_LEN//STEP-1)] += 1
# plt.xticks(x, x+1)
# plt.xlabel("token counts")
# plt.bar(x, token_counts, width=1.3)
# # plt.show()
# plt.savefig("token_counts.png")

## Generate images from the text prompts
with open("prompts.json", encoding="utf-8") as f:
    text_prompts = json.load(f)

engine_id = "stable-diffusion-v1-6"
api_host = os.getenv('API_HOST', 'https://api.stability.ai')
# The key must come from the environment only. A literal fallback here leaks the
# secret into version control and makes the None check below unreachable.
# NOTE(review): the key previously committed as a default should be treated as
# compromised and rotated.
api_key = os.getenv("STABILITY_API_KEY")
if api_key is None:
    raise Exception("Missing Stability API key.")

# Resume point: prompts with a lower index are skipped (previously hard-coded as
# `idx <= 20`, i.e. start at 21).
START_IDX = int(os.getenv("START_IDX", "21"))
# Seconds to wait on the API; without a timeout requests.post can block forever
# on a stalled connection.
REQUEST_TIMEOUT = 300

for idx, text in enumerate(text_prompts):
    if idx < START_IDX:
        continue
    print(f"Start generate prompt[{idx}]: {text}")
    try:
        response = requests.post(
            f"{api_host}/v1/generation/{engine_id}/text-to-image",
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json={
                "text_prompts": [
                    {
                        "text": text.strip(),
                    }
                ],
                "cfg_scale": 7,
                "height": 1024,
                "width": 1024,
                "samples": 3,
                "steps": 30,
            },
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException as err:
        # Network-level failure (timeout, connection reset, ...): skip this
        # prompt, the same way non-200 responses are handled below.
        print(f"{idx} Failed!!! {err}")
        continue
    if response.status_code != 200:
        # Best-effort batch: report and move on instead of aborting the run.
        print(f"{idx} Failed!!! {str(response.text)}")
        continue
    print("Finished!")

    data = response.json()
    # Each artifact carries one base64-encoded PNG image.
    for i, image in enumerate(data["artifacts"]):
        img_path = f"./images/{idx}/v1_txt2img_{i}.png"
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        with open(img_path, "wb") as f:
            f.write(base64.b64decode(image["base64"]))