jiayong commited on
Commit
96538ab
1 Parent(s): b2f0a65

Update gen_client.py

Browse files
Files changed (1) hide show
  1. gen_client.py +19 -22
gen_client.py CHANGED
@@ -401,14 +401,12 @@ class HumanGenService:
401
 
402
  def click_button_func_async(self, user_id, request_id, input_mode, ref_image_path, ref_video_path, input_prompt='', prompt_template='',model_id=False):
403
  start_time = time.time()
404
- if is_wanx_platform:
405
- user_id = 'wanx_lab'
406
- request_id = get_random_string()
407
- print(f"request_id: {request_id}, generate user_id: {user_id} and request_id: {request_id}")
408
  if user_id is None or user_id == '':
409
- user_id = 'test_version_phone'
 
410
  if request_id is None or request_id == '':
411
  request_id = get_random_string()
 
412
  # key by: ref_video_name, digest(ref_image_path), prompt_template, input_prompt,
413
  # scale_depth, scale_pose
414
  #print("ref_image_path:%s ref_video_path:%s" % (ref_image_path, ref_video_path) )
@@ -543,12 +541,12 @@ class HumanGenService:
543
 
544
 
545
  def valid_check(self, user_id, request_id, input_mode, ref_image_path, ref_video_path, input_prompt='', prompt_template='',model_id=False):
546
- if is_wanx_platform:
547
- user_id = 'wanx_lab'
548
  if user_id is None or user_id == '':
549
- user_id = 'test_version_phone'
 
550
  if request_id is None or request_id == '':
551
  request_id = get_random_string()
 
552
  print(f"-----------------request_id: {request_id}, user_id: {user_id}---------------")
553
 
554
  self.lock.acquire()
@@ -577,22 +575,22 @@ class HumanGenService:
577
 
578
  if ref_image_path is None or ref_image_path == '' or (not os.path.exists(ref_image_path)):
579
  print(f"request_id: {request_id}, No image input, task over!")
580
- # raise gr.Error("请输入图片!")
581
  return "Please input a image."
582
 
583
  if input_mode == 'template_mode' and cartoon_recog == 'realhuman' and (prompt_template == '' or prompt_template == [] or prompt_template is None):
584
  print(f"request_id: {request_id}, No prompt input, task over!")
585
- # raise gr.Error("请输入prompt!")
586
  return "Please input a prompt."
587
 
588
  if input_mode == 'prompt_mode' and cartoon_recog == 'realhuman' and (input_prompt == '' or input_prompt == [] or input_prompt is None):
589
  print(f"request_id: {request_id}, No prompt input, task over!")
590
- # raise gr.Error("请输入prompt!")
591
  return "Please input a prompt."
592
 
593
  if input_mode == 'template_mode' and (ref_video_path is None or ref_video_path == ''):
594
  print(f"request_id: {request_id}, No video input, task over!")
595
- # raise gr.Error("请输入视频!")
596
  return "Please input a video."
597
 
598
  ref_video_name = ''
@@ -627,13 +625,12 @@ class HumanGenService:
627
  self.lock.release()
628
 
629
  def get_ranking_location(self, user_id):
630
- if is_wanx_platform:
631
- user_id = 'wanx_lab'
632
  if user_id is None or user_id == '':
633
- user_id = 'test_version_phone'
 
634
  process_status = ''
635
 
636
- print(f'[get_ranking_location] ------ clean timeout and process over request start ------ ')
637
  if len(self.all_requests) > 0:
638
  for i in range(min(num_instance_dashone, len(self.all_requests))):
639
  req = self.all_requests[i]
@@ -648,7 +645,7 @@ class HumanGenService:
648
  if req in self.all_user_requests[uid]:
649
  uuid = uid
650
  break
651
- print(f'[get_ranking_location] find timeout request: {req}, uuid: {uuid}')
652
  data = '{"header":{"request_id":"","service_id":"","task_id":""},"payload":{"input": {"ref_image_path": "", "ref_video_path": "", "ref_video_name": "", "input_prompt": "", "prompt_template": "", "scale_depth": 0.7, "scale_pose": 0.5},"parameters":{}}}'
653
  data = json.loads(data) # string to dict
654
  data['header']['service_id'] = DASHONE_SERVICE_ID
@@ -661,20 +658,20 @@ class HumanGenService:
661
  ret_status, ret_json = query_video_generation(request_id=req, data=data)
662
  # print(f'ret_json = {ret_json}')
663
  if ret_status == "SUCCESS" or ret_status == "FAILED":
664
- print(f'[get_ranking_location] query timeout request process over: {req}, uuid: {uuid}')
665
  if req in self.all_requests:
666
  self.all_requests.remove(req) # delete request_id
667
  if req in self.all_requests_time:
668
  del self.all_requests_time[req]
669
  else:
670
- print(f'[get_ranking_location] query timeout request process running: {req}, uuid: {uuid}')
671
  break
672
  else:
673
- print(f'[get_ranking_location] no timeout request.')
674
  break
675
  else:
676
  print(f'size of all_requests is empty.')
677
- print(f'[get_ranking_location] ------ clean timeout and process over request end ------ ')
678
 
679
  if user_id not in self.all_user_requests:
680
  return f'You have not request a video generation task.', ''
@@ -701,7 +698,7 @@ class HumanGenService:
701
  data['payload']['input']['user_id'] = user_id
702
  data = json.dumps(data) # to string
703
  ret_status, ret_json = query_video_generation(request_id=request_id, data=data)
704
- print(f'ret_json = {ret_json}')
705
  if ret_status == "SUCCESS":
706
  req = self.all_user_requests[user_id][0]
707
  self.delete_request_id(user_id, req) # delete request_id
 
401
 
402
  def click_button_func_async(self, user_id, request_id, input_mode, ref_image_path, ref_video_path, input_prompt='', prompt_template='',model_id=False):
403
  start_time = time.time()
 
 
 
 
404
  if user_id is None or user_id == '':
405
+ user_id = get_random_string()
406
+ print(f"[click_button_func_async] generate user_id: {user_id}")
407
  if request_id is None or request_id == '':
408
  request_id = get_random_string()
409
+ print(f"[click_button_func_async] generate request_id: {request_id}")
410
  # key by: ref_video_name, digest(ref_image_path), prompt_template, input_prompt,
411
  # scale_depth, scale_pose
412
  #print("ref_image_path:%s ref_video_path:%s" % (ref_image_path, ref_video_path) )
 
541
 
542
 
543
  def valid_check(self, user_id, request_id, input_mode, ref_image_path, ref_video_path, input_prompt='', prompt_template='',model_id=False):
 
 
544
  if user_id is None or user_id == '':
545
+ user_id = get_random_string()
546
+ print(f"[valid_check] generate user_id: {user_id}")
547
  if request_id is None or request_id == '':
548
  request_id = get_random_string()
549
+ print(f"[valid_check] generate request_id: {request_id}")
550
  print(f"-----------------request_id: {request_id}, user_id: {user_id}---------------")
551
 
552
  self.lock.acquire()
 
575
 
576
  if ref_image_path is None or ref_image_path == '' or (not os.path.exists(ref_image_path)):
577
  print(f"request_id: {request_id}, No image input, task over!")
578
+ # raise gr.Error("Please input a image!")
579
  return "Please input a image."
580
 
581
  if input_mode == 'template_mode' and cartoon_recog == 'realhuman' and (prompt_template == '' or prompt_template == [] or prompt_template is None):
582
  print(f"request_id: {request_id}, No prompt input, task over!")
583
+ # raise gr.Error("Please input a prompt!")
584
  return "Please input a prompt."
585
 
586
  if input_mode == 'prompt_mode' and cartoon_recog == 'realhuman' and (input_prompt == '' or input_prompt == [] or input_prompt is None):
587
  print(f"request_id: {request_id}, No prompt input, task over!")
588
+ # raise gr.Error("Please input a prompt!")
589
  return "Please input a prompt."
590
 
591
  if input_mode == 'template_mode' and (ref_video_path is None or ref_video_path == ''):
592
  print(f"request_id: {request_id}, No video input, task over!")
593
+ # raise gr.Error("Please input a video!")
594
  return "Please input a video."
595
 
596
  ref_video_name = ''
 
625
  self.lock.release()
626
 
627
  def get_ranking_location(self, user_id):
 
 
628
  if user_id is None or user_id == '':
629
+ user_id = get_random_string()
630
+ print(f"[get_ranking_location] generate user_id: {user_id}")
631
  process_status = ''
632
 
633
+ print(f'----------- [get_ranking_location] clean timeout and process over request start ----------- ')
634
  if len(self.all_requests) > 0:
635
  for i in range(min(num_instance_dashone, len(self.all_requests))):
636
  req = self.all_requests[i]
 
645
  if req in self.all_user_requests[uid]:
646
  uuid = uid
647
  break
648
+ print(f'find timeout request: {req}, uuid: {uuid}')
649
  data = '{"header":{"request_id":"","service_id":"","task_id":""},"payload":{"input": {"ref_image_path": "", "ref_video_path": "", "ref_video_name": "", "input_prompt": "", "prompt_template": "", "scale_depth": 0.7, "scale_pose": 0.5},"parameters":{}}}'
650
  data = json.loads(data) # string to dict
651
  data['header']['service_id'] = DASHONE_SERVICE_ID
 
658
  ret_status, ret_json = query_video_generation(request_id=req, data=data)
659
  # print(f'ret_json = {ret_json}')
660
  if ret_status == "SUCCESS" or ret_status == "FAILED":
661
+ print(f'query timeout request process over: {req}, uuid: {uuid}')
662
  if req in self.all_requests:
663
  self.all_requests.remove(req) # delete request_id
664
  if req in self.all_requests_time:
665
  del self.all_requests_time[req]
666
  else:
667
+ print(f'query timeout request process running: {req}, uuid: {uuid}')
668
  break
669
  else:
670
+ print(f'no timeout request.')
671
  break
672
  else:
673
  print(f'size of all_requests is empty.')
674
+ print(f'----------- [get_ranking_location] clean timeout and process over request end ----------- ')
675
 
676
  if user_id not in self.all_user_requests:
677
  return f'You have not request a video generation task.', ''
 
698
  data['payload']['input']['user_id'] = user_id
699
  data = json.dumps(data) # to string
700
  ret_status, ret_json = query_video_generation(request_id=request_id, data=data)
701
+ # print(f'ret_json = {ret_json}')
702
  if ret_status == "SUCCESS":
703
  req = self.all_user_requests[user_id][0]
704
  self.delete_request_id(user_id, req) # delete request_id