Spaces:
Sleeping
Sleeping
import asyncio | |
import io | |
import json | |
import os | |
import sys | |
from typing import IO | |
import click | |
from PIL import Image | |
from ..bg import remove | |
from ..session_factory import new_session | |
from ..sessions import sessions_names | |
def rs_command(
    model: str,
    extras: str,
    image_width: int,
    image_height: int,
    output_specifier: str,
    **kwargs,
) -> None:
    """Remove the background from a stream of raw RGB frames read from stdin.

    Each input frame is exactly ``image_width * image_height * 3`` bytes of
    packed RGB pixel data. Every frame is decoded with PIL, passed through
    ``remove`` using one session created once for ``model``, and the result
    is encoded as PNG.

    Args:
        model: Model name handed to ``new_session``.
        extras: JSON string whose parsed content is merged into ``kwargs``
            (and so forwarded to ``remove``). A missing or unparsable value
            is ignored; a non-empty unparsable value emits a warning on
            stderr.
        image_width: Width in pixels of every input frame.
        image_height: Height in pixels of every input frame.
        output_specifier: If non-empty, a ``%``-style path template that is
            formatted with the zero-based frame index (e.g. ``out/%d.png``,
            ``~`` is expanded). Each PNG is saved to that path. If empty,
            PNG bytes are written to stdout back-to-back with no framing.
        **kwargs: Extra keyword arguments forwarded to ``remove``.

    Processing stops when stdin reaches EOF. A trailing partial frame is
    discarded.
    """
    try:
        kwargs.update(json.loads(extras))
    except (TypeError, ValueError) as err:
        # Best-effort: bad or missing extras must not abort the run, but a
        # non-empty value that fails to parse is most likely a user mistake.
        # (json.JSONDecodeError is a ValueError; None extras -> TypeError.)
        if extras:
            print(f"Ignoring invalid extras {extras!r}: {err}", file=sys.stderr)

    session = new_session(model)
    bytes_per_img = image_width * image_height * 3

    if output_specifier:
        output_dir = os.path.dirname(
            os.path.abspath(os.path.expanduser(output_specifier))
        )
        # exist_ok=True already tolerates an existing directory, so a
        # separate isdir() check is redundant (and racy).
        os.makedirs(output_dir, exist_ok=True)

    def img_to_byte_array(img: Image.Image) -> bytes:
        """Encode *img* as PNG and return the encoded bytes."""
        buff = io.BytesIO()
        img.save(buff, format="PNG")
        return buff.getvalue()

    async def connect_stdin_stdout():
        """Wrap stdin/stdout in an asyncio ``StreamReader``/``StreamWriter`` pair."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, sys.stdin)
        w_transport, w_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, sys.stdout
        )
        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer

    async def main():
        """Read frames until EOF, process each, and emit the PNG results."""
        reader, writer = await connect_stdin_stdout()
        idx = 0
        while True:
            try:
                img_bytes = await reader.readexactly(bytes_per_img)
            except asyncio.IncompleteReadError:
                # stdin hit EOF (possibly mid-frame); the partial frame is dropped.
                break
            if not img_bytes:
                # readexactly() only returns b"" when bytes_per_img == 0;
                # stop rather than loop forever on zero-sized frames.
                break

            img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
            output = remove(img, session=session, **kwargs)

            if output_specifier:
                # Expand "~" the same way the directory creation above does,
                # so the file lands in the directory that was actually created.
                output.save(
                    os.path.expanduser(output_specifier % idx), format="PNG"
                )
            else:
                writer.write(img_to_byte_array(output))
                # Apply backpressure and flush: without drain() the transport
                # buffers without bound, and bytes still pending when the
                # event loop shuts down may never reach stdout.
                await writer.drain()
            idx += 1

    asyncio.run(main())