# selfitcamera
# init
# b4297d4
# raw
# history blame
# 7.73 kB
from utils import *
from config import *
# Example data for the configured task type, fetched once at import time
# (template gallery, user-photo examples, and the showcase gallery).
temp_examples, user_examples, showcase_examples = (
    get_temps_examples(taskType),
    get_user_examples(taskType),
    get_showcase_examples(taskType),
)
# Single recorder shared by all requests; onClick/onLoad use it to check,
# save and read usage records keyed by client IP and token.
user_recorder = UserRecorder()
# Custom stylesheet passed to gr.Blocks: widen the main container.
css = "\n.gradio-container {width: 85% !important}\n"
def onClick(temp_image, user_image, caption_text, token_text,
            param4_text, param5_text, request: gr.Request):
    """Validate the form, submit a generation task, and stream its progress.

    This is a generator wired to ``run_button.click``. Every ``yield`` is a
    ``(result_image, info_message)`` pair that updates ``res_image`` and
    ``info_text``. Each early exit yields its message before returning,
    because a generator's return value never reaches the UI.

    Args:
        temp_image: Selected template image (``gr.Image`` with
            ``type="filepath"``).
        user_image: The user's photo. For ``taskType == '2'`` this is the
            ``gr.ImageEditor`` value (a dict with ``'background'`` and
            ``'layers'``). Otherwise it is a numpy image or ``None``.
        caption_text: Caption text; required when ``taskType == '7'``.
        token_text: Optional API key. It is passed with the client IP to the
            usage recorder.
        param4_text, param5_text: Width/height strings. They are validated as
            floats for ``taskType == '7'`` and replaced by ``''`` otherwise.
        request: Injected by Gradio; used to determine the client IP.

    Yields:
        ``(image_or_None, message)`` tuples for the result and status panes.
    """
    user_mask = None
    if taskType == '2':
        # The editor carries no usable image until a photo is uploaded. Guard
        # before indexing so the user gets a message instead of a crash.
        if user_image is None or user_image.get('background') is None:
            yield None, "please upload a photo!!!"
            return None, "please upload a photo!!!"
        if not user_image.get('layers'):
            yield None, "please draw a area!!!"
            return None, "please draw a area!!!"
        user_mask = user_image['layers'][0]
        user_image = user_image['background']
        # A pixel is masked when any channel of the drawn layer is non-zero.
        # NOTE(review): assumes the layer is an (H, W, C) array -- confirm.
        user_mask = (user_mask.sum(2) > 0).astype(np.uint8) * 255
        user_image = np.array(Image.fromarray(user_image).convert('RGB'))
        if user_image.sum() == 0:
            yield None, "please upload a photo!!!"
            return None, "please upload a photo!!!"
        if user_mask.sum() == 0:
            yield None, "please draw a area!!!"
            return None, "please draw a area!!!"
    if taskType == '7':
        try:
            # Normalize to canonical float strings (e.g. "1" -> "1.0").
            param4_text, param5_text = str(float(param4_text)), str(float(param5_text))
        except ValueError:
            yield None, "Invalid width/height: Please enter a valid float"
            return None, "Invalid width/height: Please enter a valid float"
        if not caption_text:
            yield None, "Please enter English caption text !!! "
            return None, "Please enter English caption text !!! "
    else:
        # The extra parameters are only sent for task type '7'.
        param4_text, param5_text = '', ''
    if temp_image is None:
        yield None, "please choose a template!!!"
        return None, "please choose a template!!!"
    if user_image is None:
        yield None, "please upload a photo!!!"
        return None, "please upload a photo!!!"
    try:
        client_ip = request.client.host
        x_forwarded_for = dict(request.headers).get('x-forwarded-for')
        if x_forwarded_for:
            # NOTE(review): X-Forwarded-For can be a comma-separated chain
            # ("client, proxy1, ..."). The whole header is used as the key
            # here, matching onLoad; consider taking the first entry in both.
            client_ip = x_forwarded_for
        if not check_region_warp(client_ip):
            # Yield first: a bare generator `return` never reaches the UI.
            yield None, "Failed !!! Our server is under maintenance, please try again later"
            return None, "Failed !!! Our server is under maintenance, please try again later"
        # Check whether this user may still submit tasks.
        check_res, info = user_recorder.check_record(ip=client_ip, token=token_text)
        if not check_res:
            yield None, info
            return None, info
        # Upload the user's photo (and the mask, if any).
        yield None, "start to upload, please wait..."
        upload_url, uploadm_url = upload_user_img_mask(client_ip, user_image, user_mask)
        if not upload_url:
            yield None, "fail to upload"
            return None, "fail to upload"
        # Publish the task.
        yield None, "start to public, please wait..."
        taskId = publicSelfitTask(upload_url, uploadm_url, temp_image,
                                  caption_text, param4_text, param5_text)
        if not taskId:
            yield None, "fail to public task..."
            return None, "fail to public task..."
        max_try = 30
        wait_s = 3
        yield None, "start to process, please wait..."
        # Poll the task status every `wait_s` seconds, at most `max_try` times.
        for i in range(max_try):
            time.sleep(wait_s)
            taskStatus = getTaskRes(taskId)
            if taskStatus is None:
                continue
            user_recorder.save_record(taskStatus, ip=client_ip, token=token_text)
            status = taskStatus['status']
            if status in ('FAILED', 'CANCELLED', 'TIMED_OUT'):
                yield None, f"task failed, query {i}, status {status}"
                return None, f"task failed, query {i}, status {status}"
            elif status in ('IN_QUEUE', 'IN_PROGRESS'):
                yield None, f"task is on processing, query {i}, status {status}"
            elif status == 'COMPLETED':
                out = taskStatus['output']['job_results']['output1']
                yield out, "task is COMPLETED"
                return out, f"{i} task COMPLETED"
        yield None, "fail to query task.."
        return None, "fail to query task.."
    except Exception:
        # Log the full traceback server-side and report the failure in the UI.
        # Re-raising here would leave the user without this status message.
        import traceback
        traceback.print_exc()
        yield None, "fail to create task"
        return None, "fail to create task"
def onLoad(token_text, request: gr.Request):
    """Produce the values for ``outputs_onload`` (page load / history refresh).

    The client IP comes from the request, overridden by the
    ``x-forwarded-for`` header when present. It is used together with the
    token to fetch the user's usage records.

    Args:
        token_text: Optional API key entered by the user.
        request: Injected by Gradio; used to determine the client IP.

    Returns:
        A new list: the history entries from ``user_recorder.get_record``,
        then the recorder's status message, then the Submit button label
        showing the remaining attempts (never below 0).
        NOTE(review): ``outputs_onload`` has 8 slots, so the history part is
        presumably 6 items -- confirm in ``UserRecorder.get_record``.
    """
    client_ip = request.client.host
    x_forwarded_for = dict(request.headers).get('x-forwarded-for')
    if x_forwarded_for:
        # NOTE(review): may be a comma-separated proxy chain; kept as-is to
        # match the key used by onClick.
        client_ip = x_forwarded_for
    his_datas, total_n, msg = user_recorder.get_record(ip=client_ip, token=token_text)
    left_n = max(0, LimitTask - total_n)
    # Build a fresh list rather than appending to `his_datas`, so a list owned
    # by the recorder is never modified in place across page loads.
    return [*his_datas, msg, f"Submit ({left_n} attempts left)"]
# Build the web UI: a template picker, the user's input (a mask editor for
# task '2'), a result panel, a showcase gallery and a history tab. Then wire
# the Submit / Refresh buttons and the page-load event.
# NOTE(review): this file's original indentation was lost. The Row/Column
# nesting below is a reconstruction and only affects layout -- confirm it
# against the upstream file.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Left column: template image plus the template example gallery.
        with gr.Column():
            with gr.Column():
                temp_image = gr.Image(sources='clipboard', type="filepath", label=TempLabel,
                    value=temp_examples[0][0], visible=TempVisible, interactive=TempInter)
                temp_example = gr.Examples(inputs=[temp_image], examples_per_page=9,
                    examples=temp_examples, visible=TempVisible)
        # Middle column: the user's photo and task-specific inputs.
        with gr.Column():
            with gr.Column():
                if taskType=='2':
                    # Task '2' also needs a mask: a fixed red brush, no eraser,
                    # no extra layers, upload-only source.
                    brush = gr.Brush(colors=['#FF0000'], color_mode='fixed')
                    user_image = gr.ImageEditor(value=None, type="numpy",
                        eraser=False, brush=brush ,layers=False, sources=['upload',],
                        transforms=[], label=UserLabel)
                else:
                    user_image = gr.Image(value=None, type="numpy", label=UserLabel)
                param4_text = gr.Textbox(value="0.5", interactive=True, label=Param4Label, visible=Param4Visible)
                param5_text = gr.Textbox(value="0.5", interactive=True, label=Param5Label, visible=Param5Visible)
                caption_text = gr.Textbox(value="", interactive=True, label=CaptionLabel, visible=CapVisible)
        # Right column: result image, runtime status, Submit, optional API key.
        with gr.Column():
            with gr.Column():
                res_image = gr.Image(label="generate image", value=None, type="filepath")
                info_text = gr.Markdown(value="", label='Runtime Info') # Markdown output component for runtime status
                run_button = gr.Button(value="Submit")
                token_text = gr.Textbox(value="", interactive=True,
                    label='Enter Your Api Key (optional)', visible=is_show_token)
    # Showcase gallery: clicking an example fills template, user photo and result.
    with gr.Column():
        show_case = gr.Examples(examples=showcase_examples,
            inputs=[temp_image, user_image, res_image, ],label=None)
    with gr.Tab('history'):
        with gr.Row(): # wrap the button in a Row
            # NOTE(review): recent Gradio versions expect an integer `scale`
            # and "sm"/"lg"-style Button sizes; confirm 0.5 / "small" render
            # as intended.
            with gr.Column(scale=0.5): # intended to give the Button half of the Row
                refresh_button = gr.Button("Refresh History", size="small")
            MK02 = gr.Markdown(value="") # Markdown slot; filled with the recorder's status message by onLoad
        # Three history rows; presumably each shows a past input next to its output.
        with gr.Row():
            his_input1 = gr.HTML()
            his_output1 = gr.HTML()
        with gr.Row():
            his_input2 = gr.HTML()
            his_output2 = gr.HTML()
        with gr.Row():
            his_input3 = gr.HTML()
            his_output3 = gr.HTML()
    # Components refreshed by onLoad, in the order of its returned list:
    # history panes, status Markdown, then the Submit button's label.
    outputs_onload = [his_input1, his_output1, his_input2, his_output2, his_input3, his_output3,
        MK02, run_button]
    # onClick is a generator, so result and status update as it yields.
    run_button.click(fn=onClick, inputs=[temp_image, user_image, caption_text,
        token_text, param4_text, param5_text], outputs=[res_image, info_text])
    refresh_button.click(fn=onLoad, inputs=[token_text], outputs=outputs_onload)
    demo.load(onLoad, inputs=[token_text], outputs=outputs_onload)
if __name__ == "__main__":
    # Enable the request queue (at most 50 pending events) and serve the app
    # on all network interfaces; `queue` hands back the Blocks app itself.
    demo.queue(max_size=50).launch(server_name='0.0.0.0')