Spaces:
Running
on
Zero
Running
on
Zero
import base64
import io
import json
import os
import random
import time
from io import BytesIO

import PIL
import gradio as gr
import requests
import socketio
from PIL import Image
from requests_toolbelt.multipart.encoder import MultipartEncoder

from constant import *
def login(email, password):
    """Authenticate against ``{BASE_URL}/user/login`` and return the session.

    Args:
        email: Account e-mail; left out of the payload when falsy.
        password: Account password.

    Returns:
        A ``(user_uuid, token)`` tuple read from the JSON response.

    Raises:
        ValueError: the response body is not valid JSON (logged, then re-raised).
        Exception: the server answered with a truthy ``error`` field.
        KeyError: the response lacks ``user_uuid`` or ``token``.
    """
    payload = {'password': password}
    if email:
        payload['email'] = email
    response = requests.post(f"{BASE_URL}/user/login", json=payload)
    try:
        response_data = response.json()
    except ValueError:
        # Response.json() raises a JSONDecodeError whose base class depends on
        # whether requests picked stdlib json or simplejson; both derive from
        # ValueError, so catching that covers every backend.
        log("ERROR", f"Error in login: {response}")
        raise
    if 'error' in response_data and response_data['error']:
        raise Exception(response_data['error'])
    log("INFO", "Logged successfully")
    return response_data['user_uuid'], response_data['token']
def rodin_history(task_uuid, token):
    """Return the decoded JSON history of the Rodin task ``task_uuid``.

    The task id is posted as form data to ``{BASE_URL}/task/rodin_history``,
    authenticated with ``token`` as a Bearer token.
    """
    endpoint = f"{BASE_URL}/task/rodin_history"
    form = {"uuid": task_uuid}
    reply = requests.post(endpoint, data=form, headers={'Authorization': f'Bearer {token}'})
    return reply.json()
def rodin_preprocess_image(generate_prompt, image, name, token):
    """Upload one image to ``{BASE_URL}/task/rodin_mesh_image_process``.

    Args:
        generate_prompt: When truthy, the form field ``generate_prompt`` is
            ``"true"``, otherwise ``"false"``.
        image: Image payload for the ``images`` part (bytes or file-like).
        name: File name reported for the uploaded part.
        token: Bearer token for the API.

    Returns:
        The decoded JSON response body.

    Note: the part is always labelled ``image/jpeg``, whatever the actual
    encoding of ``image`` is.
    """
    flag = "true" if generate_prompt else "false"
    encoder = MultipartEncoder(fields={
        'generate_prompt': flag,
        'images': (name, image, 'image/jpeg'),
    })
    request_headers = {
        'Content-Type': encoder.content_type,
        'Authorization': f'Bearer {token}',
    }
    reply = requests.post(
        f"{BASE_URL}/task/rodin_mesh_image_process",
        data=encoder,
        headers=request_headers,
    )
    return reply.json()
def crop_image(image, type, *, tile_size=360, stride=720, strip_width=11520):
    """Cut one row of square tiles out of a wide image and join them side by side.

    Starting at x = ``type[1]`` and y = ``type[0]``, a ``tile_size`` x
    ``tile_size`` tile is taken every ``stride`` pixels across ``strip_width``
    pixels. The tiles are pasted left to right into a new RGB image of size
    ``(tile_size * (strip_width // stride), tile_size)``. With the defaults this
    means 16 tiles of 360 px taken every 720 px, giving a 5760 x 360 result.

    Args:
        image: PIL image to crop. Assumes it is at least ``strip_width`` px wide;
            PIL pads out-of-bounds crops rather than failing.
        type: ``(y_offset, x_offset)`` pair selecting the tile row and the
            horizontal offset within each stride. The name shadows the builtin
            but is kept for caller compatibility.
        tile_size: Edge length of each square tile.
        stride: Horizontal distance between the left edges of consecutive tiles.
        strip_width: Total source width that the tiles are sampled from.

    Returns:
        The stitched PIL image.

    Raises:
        gr.Error: ``image`` is None (nothing has been generated yet).
    """
    if image is None:
        raise gr.Error("Please generate the object first")
    upper, x_offset = type[0], type[1]
    tile_count = strip_width // stride
    new_image = Image.new('RGB', (tile_size * tile_count, tile_size))
    for i in range(tile_count):
        left = i * stride + x_offset
        tile = image.crop((left, upper, left + tile_size, upper + tile_size))
        new_image.paste(tile, (i * tile_size, 0))
    return new_image
def rodin_mesh(prompt, group_uuid, settings, images, name, token):
    """Submit a Rodin mesh-generation job to ``{BASE_URL}/task/rodin_mesh``.

    Args:
        prompt: Text prompt for the generation.
        group_uuid: Group id sent as a form field (callers here pass None).
        settings: Generation settings; serialized with ``json.dumps``.
        images: Base64 strings or ``data:`` URIs, as returned by the
            preprocessing step. Each one is uploaded as its own ``images`` part.
        name: File name reported for every image part.
        token: Bearer token for the API.

    Returns:
        The decoded JSON response body.

    Note: every image part is labelled ``image/jpeg``, whatever its actual
    encoding is.
    """
    binary_images = [convert_base64_to_binary(img) for img in images]
    # Use a list of (field, value) pairs rather than a dict. Every image shares
    # the field name ``images``, so a dict would keep only the last one.
    fields = [
        ('prompt', prompt),
        ('group_uuid', group_uuid),
        ('settings', json.dumps(settings)),
    ]
    fields.extend(('images', (name, image, 'image/jpeg')) for image in binary_images)
    m = MultipartEncoder(fields=fields)
    headers = {
        'Content-Type': m.content_type,
        'Authorization': f'Bearer {token}'
    }
    response = requests.post(f"{BASE_URL}/task/rodin_mesh", data=m, headers=headers)
    return response.json()
def convert_base64_to_binary(base64_string):
    """Decode a base64 string (or ``data:...;base64,`` URI) into an in-memory buffer.

    The preprocessing endpoint returns images base64-encoded. When the string
    contains a comma, only the text between the first and second comma is
    decoded, which strips a data-URI header.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 holding the decoded bytes.
    """
    encoded = base64_string.split(',')[1] if ',' in base64_string else base64_string
    return io.BytesIO(base64.b64decode(encoded))
def rodin_update(prompt, task_uuid, token, settings):
    """Ask Rodin to regenerate the existing task ``task_uuid``.

    Args:
        prompt: New text prompt for the task.
        task_uuid: Id of the task to update.
        token: Bearer token for the API.
        settings: Generation settings. A dict (or other non-string object) is
            serialized to JSON, matching ``rodin_mesh``. A str/bytes value is
            sent as-is, and None leaves the field out.

    Returns:
        The decoded JSON response body.
    """
    # requests form-encodes an iterable value by iterating it. For a dict that
    # would send only the setting *names* (settings=view_weights&settings=seed)
    # and silently drop every value, so send a JSON string instead.
    if settings is not None and not isinstance(settings, (str, bytes)):
        settings = json.dumps(settings)
    headers = {
        'Authorization': f'Bearer {token}'
    }
    response = requests.post(
        f"{BASE_URL}/task/rodin_update",
        data={"uuid": task_uuid, "prompt": prompt, "settings": settings},
        headers=headers,
    )
    return response.json()
def load_image(img_path):
    """Open an image, rescale it so its longer side is 1024 px, and return PNG bytes.

    The aspect ratio is preserved. Images smaller than 1024 px are scaled *up*
    as well as larger ones being scaled down. Each side is at least 1 px, so
    extreme aspect ratios cannot yield an empty image.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The resized image encoded as PNG, as ``bytes``.

    Raises:
        gr.Error: PIL cannot identify the file as an image.
    """
    try:
        image = Image.open(img_path)
    except PIL.UnidentifiedImageError as e:
        raise gr.Error("Invalid Image Format") from e
    # Image.open keeps the file open for lazy decoding. resize() forces a load,
    # so the file can be closed once the resized copy exists.
    with image:
        width, height = image.size
        scale = 1024 / max(width, height)
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        resized_image = image.resize(new_size)
    # Encode the PIL image into an in-memory PNG byte string.
    byte_io = BytesIO()
    resized_image.save(byte_io, format='PNG')
    return byte_io.getvalue()
def log(level, info_text):
    """Print one log line: ``[ LEVEL ] - YYYYmmdd_HH:MM:SS - text`` (local time)."""
    timestamp = time.strftime('%Y%m%d_%H:%M:%S', time.localtime())
    line = f"[ {level} ] - {timestamp} - {info_text}"
    print(line)
class Generator:
    """Drive the Rodin preprocess -> generate -> preview workflow for one account.

    Holds the API token, plus the credentials used to refresh it when the
    preprocessing endpoint reports HTTP 401.
    """

    # Upper bound on token refreshes per ``preprocess`` call. Without it, a
    # persistent 401 would make the retry loop spin forever.
    _MAX_RELOGINS = 3

    def __init__(self, user_id, password, token) -> None:
        # The caller supplies the token (no login here). The credentials are kept
        # so ``preprocess`` can log in again when the token expires.
        self.token = token
        self.user_id = user_id
        self.password = password
        # Not read by any method in this class; presumably used by callers.
        self.task_uuid = None
        self.processed_image = None

    def preprocess(self, prompt, image_path, processed_image, task_uuid=""):
        """Return ``(prompt, processed_image)`` ready for ``generate_mesh``.

        If ``processed_image`` and ``prompt`` are both set and ``task_uuid`` is
        empty, they are returned unchanged. Otherwise the image at
        ``image_path`` is uploaded for preprocessing. The server-generated
        prompt replaces ``prompt`` unless both ``prompt`` and ``task_uuid`` were
        given.

        Returns:
            ``(prompt, processed_image)``, where ``processed_image`` is a
            ``data:image/png;base64,...`` string.

        Raises:
            RuntimeError: the endpoint answered with an ``error`` field.
            gr.Error: the response has no usable ``processed_image``, or the
                token kept being rejected after repeated re-logins.
        """
        if processed_image and prompt and (not task_uuid):
            log("INFO", "Using cached image and prompt...")
            return prompt, processed_image
        log("INFO", "Preprocessing image...")
        # Keep the user's prompt only when refining an existing task with one.
        keep_prompt = bool(prompt and task_uuid)
        # The file does not change between retries, so load it only once.
        image_file = load_image(image_path)
        log("INFO", "Image loaded, processing...")
        relogins = 0
        while True:
            preprocess_response = rodin_preprocess_image(
                generate_prompt=not keep_prompt,
                image=image_file,
                name=os.path.basename(image_path),
                token=self.token,
            )
            log("INFO", f"Image preprocessed: {preprocess_response.get('statusCode')}")
            if 'error' in preprocess_response:
                log("ERROR", f"Error in image preprocessing: {preprocess_response['error']}")
                raise RuntimeError(preprocess_response['error'])
            if preprocess_response.get("statusCode") == 401:
                if relogins >= self._MAX_RELOGINS:
                    log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                    raise gr.Error("Busy connection, please try again later.")
                relogins += 1
                log("WARNING", "Token expired. Logging in again...")
                _, self.token = login(self.user_id, self.password)
                continue
            encoded_image = preprocess_response.get('processed_image')
            if not isinstance(encoded_image, str):
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Busy connection, please try again later.")
            if not keep_prompt:
                prompt = preprocess_response.get('prompt', None)
            return prompt, "data:image/png;base64," + encoded_image

    def generate_mesh(self, prompt, processed_image, task_uuid=""):
        """Start (or refine) a Rodin job, wait for it, and download its preview.

        With an empty ``task_uuid`` a new job is submitted for
        ``processed_image``. Otherwise task ``task_uuid`` is updated with
        ``prompt`` and a random seed.

        Returns:
            ``(image, task_uuid, crop_image(image, DEFAULT))``, where ``image``
            is the preview as a PIL image.

        Raises:
            gr.Error: the task history could not be read.
            RuntimeError: the preview image download failed.
        """
        log("INFO", "Generating mesh...")
        if task_uuid == "":
            # One weight per image; with several images use e.g. [0.5, 0.5].
            settings = {'view_weights': [1]}
            # Every image must already have been processed by ``preprocess``.
            images = [processed_image]
            mesh_response = rodin_mesh(prompt=prompt, group_uuid=None, settings=settings,
                                       images=images, name="images.jpeg", token=self.token)
            self._wait_for_job(mesh_response['job']['subscription_key'], "Error in generating mesh")
            # The task_uuid stays the same for the whole generation process.
            task_uuid = mesh_response['uuid']
        else:
            settings = {
                "view_weights": [1],
                "seed": random.randint(0, 10000),  # customize the seed here
                "escore": 5.5,  # temperature
            }
            update_response = rodin_update(prompt, task_uuid, self.token, settings)
            self._wait_for_job(update_response['job']['subscription_key'], "Error in updating mesh")
        image = self._fetch_preview(task_uuid)
        return image, task_uuid, crop_image(image, DEFAULT)

    @staticmethod
    def _wait_for_job(subscription_key, error_prefix):
        """Block on the scheduler socket for ``subscription_key``; log (don't raise) failures."""
        checker = JobStatusChecker(BASE_URL, subscription_key)
        try:
            checker.start()
        except Exception as e:
            log("ERROR", f"{error_prefix}: {e}")
            # NOTE(review): the source's indentation was lost. This back-off is
            # assumed to run only after a socket failure — confirm.
            time.sleep(5)

    def _fetch_preview(self, task_uuid):
        """Download the preview image of the last entry in ``task_uuid``'s history."""
        # Bound before the try so the error log cannot fail with
        # UnboundLocalError when rodin_history itself raises.
        history = None
        try:
            history = rodin_history(task_uuid, self.token)
            # Takes the last entry, presumably the newest (assumes insertion
            # order is chronological — confirm).
            preview_image = next(reversed(history.items()))[1]["preview_image"]
        except Exception as e:
            log("ERROR", f"Error in generating mesh: {history}")
            raise gr.Error("Busy connection, please try again later.") from e
        # The with-block closes the connection on both success and failure.
        with requests.get(preview_image) as response:
            if response.status_code != 200:
                log("ERROR", f"Error in generating mesh: {response}")
                raise RuntimeError(f"Preview download failed with HTTP {response.status_code}")
            # ``.content`` is the fully read body with any Content-Encoding
            # undone, which reading ``response.raw`` directly would not do.
            return Image.open(BytesIO(response.content))
class JobStatusChecker:
    """Wait on the Rodin scheduler socket until a job reports success.

    ``JobStatusChecker(base_url, subscription_key).start()`` connects and blocks
    until the client disconnects. The message handler disconnects when it sees
    ``jobStatus == 'Succeeded'``.
    """

    def __init__(self, base_url, subscription_key):
        self.base_url = base_url
        self.subscription_key = subscription_key
        self.sio = socketio.Client(logger=True, engineio_logger=True)

        # The handlers accept *args because handlers registered on the '*'
        # namespace receive the namespace (and, for the '*' event, the event
        # name) as leading positional arguments.
        def connect(*args, **kwargs):
            print("Connected to the server.")

        def disconnect(*args, **kwargs):
            print("Disconnected from server.")

        def message(*args, **kwargs):
            # For the catch-all event in the catch-all namespace the arguments
            # are (event, namespace, data, ...).
            if len(args) > 2:
                data = args[2]
                if isinstance(data, dict) and data.get('jobStatus') == 'Succeeded':
                    print("Job Succeeded! Please find the SDF image in history")
                    # Disconnecting is what lets ``start``'s wait() return.
                    self.sio.disconnect()
            else:
                print("Received event with insufficient arguments.")

        # The callbacks only run once they are attached to the client. Without
        # this, success is never seen and start() blocks until the server
        # closes the connection.
        self.sio.on('connect', connect, namespace='*')
        self.sio.on('disconnect', disconnect, namespace='*')
        self.sio.on('*', message, namespace='*')

    def start(self):
        """Connect to the scheduler socket for this subscription and block until disconnected."""
        self.sio.connect(f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
                         namespaces=['/api/scheduler_socket'], transports='websocket')
        self.sio.wait()