cbensimon HF staff commited on
Commit
8876d0e
1 Parent(s): ffcbbc8

- By wrapping GPU functions like this the pipeline doesn't have to be transferred from the main process to the GPU worker and GPU coldstart should be much faster (measured between 3s and 4s on this Space)
- prefetch_hf_cache actually calls the pipe, which can't be done outside of decorated functions on ZeroGPU.
When no coldstart happens, execution time does not seem to change, with or without prefetch_hf_cache

Files changed (1) hide show
  1. app.py +3 -12
app.py CHANGED
@@ -46,7 +46,6 @@ default_bas_frame_near = 1
46
  default_bas_frame_far = 1
47
 
48
 
49
- @spaces.GPU
50
  def process_image(
51
  pipe,
52
  path_input,
@@ -89,7 +88,6 @@ def process_image(
89
  )
90
 
91
 
92
- @spaces.GPU
93
  def process_video(
94
  pipe,
95
  path_input,
@@ -187,7 +185,6 @@ def process_video(
187
  )
188
 
189
 
190
- @spaces.GPU
191
  def process_bas(
192
  pipe,
193
  path_input,
@@ -282,9 +279,9 @@ def process_bas(
282
 
283
 
284
  def run_demo_server(pipe):
285
- process_pipe_image = functools.partial(process_image, pipe)
286
- process_pipe_video = functools.partial(process_video, pipe)
287
- process_pipe_bas = functools.partial(process_bas, pipe)
288
  os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
289
 
290
  gradio_theme = gr.themes.Default()
@@ -792,11 +789,6 @@ def run_demo_server(pipe):
792
  )
793
 
794
 
795
- def prefetch_hf_cache(pipe):
796
- process_image(pipe, "files/image/bee.jpg", 1, 1, 64)
797
- shutil.rmtree("files/image/bee_output")
798
-
799
-
800
  def main():
801
  CHECKPOINT = "prs-eth/marigold-v1-0"
802
  CHECKPOINT_UNET_LCM = "prs-eth/marigold-lcm-v1-0"
@@ -821,7 +813,6 @@ def main():
821
  pass # run without xformers
822
 
823
  pipe = pipe.to(device)
824
- prefetch_hf_cache(pipe)
825
  run_demo_server(pipe)
826
 
827
 
 
46
  default_bas_frame_far = 1
47
 
48
 
 
49
  def process_image(
50
  pipe,
51
  path_input,
 
88
  )
89
 
90
 
 
91
  def process_video(
92
  pipe,
93
  path_input,
 
185
  )
186
 
187
 
 
188
  def process_bas(
189
  pipe,
190
  path_input,
 
279
 
280
 
281
  def run_demo_server(pipe):
282
+ process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
283
+ process_pipe_video = spaces.GPU(functools.partial(process_video, pipe))
284
+ process_pipe_bas = spaces.GPU(functools.partial(process_bas, pipe))
285
  os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
286
 
287
  gradio_theme = gr.themes.Default()
 
789
  )
790
 
791
 
 
 
 
 
 
792
  def main():
793
  CHECKPOINT = "prs-eth/marigold-v1-0"
794
  CHECKPOINT_UNET_LCM = "prs-eth/marigold-lcm-v1-0"
 
813
  pass # run without xformers
814
 
815
  pipe = pipe.to(device)
 
816
  run_demo_server(pipe)
817
 
818