skkk commited on
Commit
20829d2
1 Parent(s): 6be6b29

Fix a bug of preprocess

Browse files
Files changed (1) hide show
  1. Rodin.py +30 -16
Rodin.py CHANGED
@@ -6,6 +6,7 @@ import random
6
  import base64
7
  import io
8
  from PIL import Image
 
9
  from requests_toolbelt.multipart.encoder import MultipartEncoder
10
  from constant import *
11
 
@@ -98,6 +99,24 @@ def rodin_update(prompt, task_uuid, token, settings):
98
  response = requests.post(f"{BASE_URL}/task/rodin_update", data={"uuid": task_uuid, "prompt": prompt, "settings": settings}, headers=headers)
99
  return response.json()
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  class Generator:
103
  def __init__(self, user_id, password) -> None:
@@ -109,22 +128,17 @@ class Generator:
109
  return prompt, cache_image_base64
110
  print("Preprocessing image...")
111
 
112
- image_file = open(image_path, 'rb')
113
- if image_file == None:
114
- print("Invalid image file.")
115
- try:
116
- if prompt and task_uuid:
117
- preprocess_response = rodin_preprocess_image(generate_prompt=False, image=image_file, name="images.jpeg", token=self.token)
118
- else:
119
- preprocess_response = rodin_preprocess_image(generate_prompt=True, image=image_file, name="images.jpeg", token=self.token)
120
- if 'error' in preprocess_response:
121
- print("Error in image preprocessing:", preprocess_response['error'])
122
- else:
123
- if not (prompt and task_uuid):
124
- prompt = preprocess_response.get('prompt', 'Default prompt if none returned')
125
- processed_image = "data:image/png;base64," + preprocess_response.get('processed_image', None)
126
- finally:
127
- image_file.close()
128
 
129
  return prompt, processed_image
130
 
 
6
  import base64
7
  import io
8
  from PIL import Image
9
+ from io import BytesIO
10
  from requests_toolbelt.multipart.encoder import MultipartEncoder
11
  from constant import *
12
 
 
99
  response = requests.post(f"{BASE_URL}/task/rodin_update", data={"uuid": task_uuid, "prompt": prompt, "settings": settings}, headers=headers)
100
  return response.json()
101
 
102
+ def load_image(img_path):
103
+ image = Image.open(img_path)
104
+
105
+ # 按比例缩小图像到长度为1024
106
+ width, height = image.size
107
+ if width > height:
108
+ scale = 1024 / width
109
+ else:
110
+ scale = 1024 / height
111
+ new_width = int(width * scale)
112
+ new_height = int(height * scale)
113
+ resized_image = image.resize((new_width, new_height))
114
+
115
+ # 将 PIL.Image 对象转换为字节流
116
+ byte_io = BytesIO()
117
+ resized_image.save(byte_io, format='PNG')
118
+ image_bytes = byte_io.getvalue()
119
+ return image_bytes
120
 
121
  class Generator:
122
  def __init__(self, user_id, password) -> None:
 
128
  return prompt, cache_image_base64
129
  print("Preprocessing image...")
130
 
131
+ image_file = load_image(image_path)
132
+ if prompt and task_uuid:
133
+ preprocess_response = rodin_preprocess_image(generate_prompt=False, image=image_file, name="images.png", token=self.token)
134
+ else:
135
+ preprocess_response = rodin_preprocess_image(generate_prompt=True, image=image_file, name="images.png", token=self.token)
136
+ if 'error' in preprocess_response:
137
+ print("Error in image preprocessing:", preprocess_response['error'])
138
+ else:
139
+ if not (prompt and task_uuid):
140
+ prompt = preprocess_response.get('prompt', 'Default prompt if none returned')
141
+ processed_image = "data:image/png;base64," + preprocess_response.get('processed_image', None)
 
 
 
 
 
142
 
143
  return prompt, processed_image
144