Stable-X commited on
Commit
a5a71d2
β€’
1 Parent(s): 65d11db

feat: Add Yoso-normal

Browse files
Files changed (2) hide show
  1. app.py +30 -2
  2. stablenormal/pipeline_yoso_normal.py +1 -1
app.py CHANGED
@@ -131,7 +131,7 @@ class StableNormal(Geowizard):
131
  '''
132
 
133
  def __init__(self):
134
- x_start_pipeline = YOSONormalsPipeline.from_pretrained('Stable-X/yoso-normal-v0-2', trust_remote_code=True,
135
  variant="fp16", torch_dtype=torch.float16)
136
  self.model = StableNormalPipeline.from_pretrained('Stable-X/stable-normal-v0-1', trust_remote_code=True,
137
  variant="fp16", torch_dtype=torch.float16,
@@ -162,6 +162,33 @@ class StableNormal(Geowizard):
162
 
163
  return f"model: \n{self.model}"
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  class DSINE(object):
166
  '''
167
  Simple Stable Diffusion Package
@@ -380,8 +407,9 @@ def main():
380
  marigold_pipe = Marigold()
381
  geowizard_pipe = Geowizard()
382
  our_pipe = StableNormal()
 
383
 
384
- run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe])
385
 
386
 
387
  if __name__ == "__main__":
 
131
  '''
132
 
133
  def __init__(self):
134
+ x_start_pipeline = YOSONormalsPipeline.from_pretrained('Stable-X/yoso-normal-v0-3', trust_remote_code=True,
135
  variant="fp16", torch_dtype=torch.float16)
136
  self.model = StableNormalPipeline.from_pretrained('Stable-X/stable-normal-v0-1', trust_remote_code=True,
137
  variant="fp16", torch_dtype=torch.float16,
 
162
 
163
  return f"model: \n{self.model}"
164
 
165
+ class YosoNormal(Geowizard):
166
+ def __init__(self):
167
+ self.model = YOSONormalsPipeline.from_pretrained('Stable-X/yoso-normal-v0-3', trust_remote_code=True,
168
+ variant="fp16", torch_dtype=torch.float16, t_start=0)
169
+
170
+ # two stage concat
171
+ self.model.x_start_pipeline = x_start_pipeline
172
+ self.model.x_start_pipeline.to('cuda', torch.float16)
173
+ self.model.prior.to('cuda', torch.float16)
174
+
175
+
176
+ @torch.no_grad()
177
+ def __call__(self, img, image_resolution=768):
178
+ pipe_out = self.model(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
179
+ pred_normal = pipe_out.prediction[0]
180
+ pred_normal = (pred_normal + 1) / 2 * 255
181
+ pred_normal = pred_normal.astype(np.uint8)
182
+
183
+ return pred_normal
184
+
185
+ def to(self, device):
186
+ self.model.to(device, torch.float16)
187
+
188
+ def __repr__(self):
189
+
190
+ return f"model: \n{self.model}"
191
+
192
  class DSINE(object):
193
  '''
194
  Simple Stable Diffusion Package
 
407
  marigold_pipe = Marigold()
408
  geowizard_pipe = Geowizard()
409
  our_pipe = StableNormal()
410
+ yoso_pipe = YosoNormal()
411
 
412
+ run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe, yoso_pipe])
413
 
414
 
415
  if __name__ == "__main__":
stablenormal/pipeline_yoso_normal.py CHANGED
@@ -588,7 +588,7 @@ class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
588
  image_latent = image_latent * self.vae.config.scaling_factor
589
  image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
590
 
591
- pred_latent = latents
592
  if pred_latent is None:
593
  pred_latent = randn_tensor(
594
  image_latent.shape,
 
588
  image_latent = image_latent * self.vae.config.scaling_factor
589
  image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
590
 
591
+ pred_latent = torch.zeros_like(latents)
592
  if pred_latent is None:
593
  pred_latent = randn_tensor(
594
  image_latent.shape,