Ming Li commited on
Commit
a9864eb
1 Parent(s): ea0de4a

set fp32 to avoid errors

Browse files
Files changed (2) hide show
  1. model.py +4 -4
  2. requirements.txt +1 -0
model.py CHANGED
@@ -1,10 +1,10 @@
1
  from __future__ import annotations
2
 
3
  import gc
4
- import spaces
5
 
6
  import numpy as np
7
  import PIL.Image
 
8
  import torch
9
  from controlnet_aux.util import HWC3
10
  from diffusers import (
@@ -53,9 +53,9 @@ class Model:
53
  ):
54
  return self.pipe
55
  model_id = CONTROLNET_MODEL_IDS[task_name]
56
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
57
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
58
- base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
59
  )
60
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
61
  if self.device.type == "cuda":
@@ -87,7 +87,7 @@ class Model:
87
  torch.cuda.empty_cache()
88
  gc.collect()
89
  model_id = CONTROLNET_MODEL_IDS[task_name]
90
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
91
  controlnet.to(self.device)
92
  torch.cuda.empty_cache()
93
  gc.collect()
 
1
  from __future__ import annotations
2
 
3
  import gc
 
4
 
5
  import numpy as np
6
  import PIL.Image
7
+ import spaces
8
  import torch
9
  from controlnet_aux.util import HWC3
10
  from diffusers import (
 
53
  ):
54
  return self.pipe
55
  model_id = CONTROLNET_MODEL_IDS[task_name]
56
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
57
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
58
+ base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float32
59
  )
60
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
61
  if self.device.type == "cuda":
 
87
  torch.cuda.empty_cache()
88
  gc.collect()
89
  model_id = CONTROLNET_MODEL_IDS[task_name]
90
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
91
  controlnet.to(self.device)
92
  torch.cuda.empty_cache()
93
  gc.collect()
requirements.txt CHANGED
@@ -8,3 +8,4 @@ kornia
8
  huggingface_hub
9
  gradio==4.26.0
10
  xformers
 
 
8
  huggingface_hub
9
  gradio==4.26.0
10
  xformers
11
+ mediapipe