|
import io
import os
import time
from pathlib import Path
from typing import Any, Optional, Union
from urllib.parse import urlparse

import requests
from PIL import Image
|
|
|
# Base URL of the BFL API; endpoint paths such as "/v1/image" and
# "/v1/get_result" are appended to it when building request URLs.
API_ENDPOINT: str = "https://api.bfl.ml"
|
|
|
|
|
class ApiException(Exception):
    """Error raised when the BFL API rejects a request or reports a failure.

    Attributes:
        status_code: HTTP status code of the failing response. 200 is used
            when the HTTP call itself succeeded but the API reported a
            non-success generation status.
        detail: Error detail as returned by the API: ``None``, a plain
            string, or a list of error objects that may carry a ``"msg"`` key.
    """

    def __init__(self, status_code: int, detail: Any = None):
        # Forward the arguments so ``self.args`` is populated. Exceptions are
        # pickled as ``cls(*args)``; with empty args unpickling would fail
        # because ``status_code`` is a required parameter.
        super().__init__(status_code, detail)
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        message = self._format_detail(self.detail)
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"

    @staticmethod
    def _format_detail(detail: Any) -> Optional[str]:
        """Condense ``detail`` into a short human-readable message.

        Must never raise: ``__repr__``/``__str__`` run while tracebacks are
        rendered, and an error there would hide the original API failure.

        Returns:
            ``None`` for no detail, the string itself for string detail, or
            ``"[msg1,msg2,...]"`` for a list of error objects (items lacking a
            ``"msg"`` key are rendered with ``str``).
        """
        if detail is None:
            return None
        if isinstance(detail, str):
            return detail
        if isinstance(detail, dict):
            # A single error object instead of a list of them.
            detail = [detail]
        if not isinstance(detail, (list, tuple)):
            return str(detail)
        parts = [
            str(item["msg"]) if isinstance(item, dict) and "msg" in item else str(item)
            for item in detail
        ]
        return "[" + ",".join(parts) + "]"
|
|
|
|
|
class ImageRequest:
    """Manage one image-generation request against the BFL API.

    Typical flow: construct the object (which submits the job unless
    ``launch=False``), then read ``url``, ``bytes`` or ``image``, or call
    ``save()``. Each of these waits for the generation on first access and
    caches what it fetched on the instance.
    """

    # Per-call timeout in seconds handed to ``requests`` (it bounds the
    # connect and each socket read, not the total transfer). Without a
    # timeout a stalled connection blocks forever. Set to ``None`` to disable.
    timeout: Optional[float] = 60.0

    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: Optional[int] = None,
        validate: bool = True,
        launch: bool = True,
        api_key: Optional[str] = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixels (multiple of 32, 256-1440)
            height: Height of the image in pixels (multiple of 32, 256-1440)
            name: Name of the model (only "flux.1-pro" passes validation)
            num_steps: Number of network evaluations (1-50)
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed; omitted from the request when None
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key; falls back to the BFL_API_KEY environment
                variable when not given

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            self._validate(name=name, width=width, height=height, num_steps=num_steps)

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        # Filled in lazily by request() / retrieve() / the properties below.
        self.request_id: Optional[str] = None
        self.result: Optional[dict] = None
        self._image_bytes: Optional[bytes] = None
        self._url: Optional[str] = None
        self.api_key = api_key if api_key is not None else os.environ.get("BFL_API_KEY")

        if launch:
            self.request()

    @staticmethod
    def _validate(name: str, width: int, height: int, num_steps: int) -> None:
        """Raise ValueError for parameters outside the ranges this client accepts."""
        if name not in ["flux.1-pro"]:
            raise ValueError(f"Invalid model {name}")
        if width % 32 != 0:
            raise ValueError(f"width must be divisible by 32, got {width}")
        if not (256 <= width <= 1440):
            raise ValueError(f"width must be between 256 and 1440, got {width}")
        if height % 32 != 0:
            raise ValueError(f"height must be divisible by 32, got {height}")
        if not (256 <= height <= 1440):
            raise ValueError(f"height must be between 256 and 1440, got {height}")
        if not (1 <= num_steps <= 50):
            raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

    @staticmethod
    def _json_or_raise(response: requests.Response) -> dict:
        """Decode a JSON-object response body.

        Raises:
            ApiException: If the body is not JSON (e.g. an HTML error page from
                a proxy) or is JSON but not an object. The HTTP status code is
                preserved so callers see the same exception type as for API
                errors.
        """
        try:
            payload = response.json()
        except ValueError as err:  # requests' JSONDecodeError subclasses ValueError
            raise ApiException(status_code=response.status_code, detail=response.text) from err
        if not isinstance(payload, dict):
            raise ApiException(status_code=response.status_code, detail=str(payload))
        return payload

    def request(self):
        """
        Request to generate the image.

        Does nothing if the request was already submitted (``request_id`` set).

        Raises:
            ApiException: If the API does not answer with HTTP 200.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
            timeout=self.timeout,
        )
        result = self._json_or_raise(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.

        Submits the request first if needed, then polls ``/v1/get_result``
        every 0.5 s for as long as the status is "Pending". Only each
        individual HTTP call is bounded (by ``timeout``), not the total wait.

        Returns:
            The ``"result"`` object of the first "Ready" response (cached).

        Raises:
            ApiException: If the response has no status or any status other
                than "Pending" / "Ready".
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
                timeout=self.timeout,
            )
            result = self._json_or_raise(response)
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes, downloaded from ``url`` on first access.

        Raises:
            ApiException: If the download does not answer with HTTP 200.
        """
        if self._image_bytes is None:
            response = requests.get(self.url, timeout=self.timeout)
            if response.status_code != 200:
                raise ApiException(status_code=response.status_code)
            self._image_bytes = response.content
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from (the result's ``"sample"`` field).
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: Union[str, os.PathLike]):
        """
        Save the generated image to a local path.

        The file extension of the image URL is appended to ``path`` unless
        ``path`` already ends with it. Missing parent directories are created.
        """
        # Take the extension from the URL *path* component only: if the URL
        # carries a query string (which may contain "/" or "."), running
        # Path() over the whole URL would produce a bogus or empty suffix.
        suffix = Path(urlparse(self.url).path).suffix
        path = os.fspath(path)
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: hand the ImageRequest constructor to
    # python-fire. Imported here so the module itself does not require fire.
    import fire

    fire.Fire(ImageRequest)
|
|