# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
"""Gradio demo comparing surface-normal estimators side by side.

Models shown: DSINE, Marigold, GeoWizard, StableNormal and YOSO-Normal.
"""

import functools
import logging
import os
import sys

import gradio as gr
import numpy as np
import torch
from PIL import Image

import spaces
import diffusers
import cv2

from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline
from stablenormal.pipeline_stablenormal import StableNormalPipeline
from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler
from data_utils import HWC3, resize_image

# The GeoWizard code lives in ./geowizard and is imported as a top-level
# ``models`` package, so the path must be extended before the import below.
sys.path.append('./geowizard')
from models.geowizard_pipeline import DepthNormalEstimationPipeline  # noqa: E402

logger = logging.getLogger(__name__)


def _normal_to_uint8(pred_normal):
    """Map a normal map with components in [-1, 1] to uint8 colour values.

    The result is clipped to [0, 255] before the cast so that slight
    overshoot (e.g. from fp16 arithmetic) cannot wrap around in the uint8
    conversion.
    """
    scaled = (np.asarray(pred_normal, dtype=np.float32) + 1) / 2 * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)


class Geowizard(object):
    """Wrapper around the GeoWizard depth/normal pipeline.

    The device/mode helpers forward to ``self.model`` and return ``self`` so
    calls can be chained. Subclasses replace ``self.model`` and ``__call__``.
    """

    def __init__(self):
        self.model = DepthNormalEstimationPipeline.from_pretrained(
            "lemonaddie/Geowizard", torch_dtype=torch.float16
        )

    def cuda(self):
        self.model.cuda()
        return self

    def cpu(self):
        self.model.cpu()
        return self

    def float(self):
        self.model.float()
        return self

    def to(self, device):
        self.model.to(device)
        return self

    def eval(self):
        self.model.eval()
        return self

    def train(self):
        self.model.train()
        return self

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        """Predict a normal map for ``img``.

        Args:
            img: BGR image as read by OpenCV.
            image_resolution: processing resolution passed to the pipeline.

        Returns:
            uint8 array of the predicted normals scaled to [0, 255]
            (presumably HxWx3, as it is later resized and shown as an image).
        """
        pipe_out = self.model(
            Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)),
            denoising_steps=10,
            ensemble_size=1,
            processing_res=image_resolution,
            match_input_res=True,
            domain="indoor",
            color_map="Spectral",
            show_progress_bar=False,
        )
        return _normal_to_uint8(pipe_out.normal_np)

    def __repr__(self):
        return f"model: \n{self.model}"


class Marigold(Geowizard):
    """Wrapper around diffusers' Marigold normals pipeline."""

    def __init__(self):
        # Deliberately does not call super().__init__(): that would also
        # download and load the GeoWizard weights.
        self.model = diffusers.MarigoldNormalsPipeline.from_pretrained(
            "prs-eth/marigold-normals-v0-1", torch_dtype=torch.float16
        )

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        """Predict a normal map for the BGR image ``img``.

        ``image_resolution`` is accepted for a uniform call signature but is
        not forwarded to the pipeline.
        """
        pipe_out = self.model(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
        pred_normal = np.array(pipe_out.prediction[0], dtype=np.float32)
        # Flip the X component, presumably to match the other models'
        # normal-map convention.
        pred_normal[..., 0] = -pred_normal[..., 0]
        return _normal_to_uint8(pred_normal)


class StableNormal(Geowizard):
    """Two-stage StableNormal pipeline, seeded by a YOSO pipeline."""

    def __init__(self):
        x_start_pipeline = YOSONormalsPipeline.from_pretrained(
            'Stable-X/yoso-normal-v0-3',
            trust_remote_code=True,
            variant="fp16",
            torch_dtype=torch.float16,
        )
        self.model = StableNormalPipeline.from_pretrained(
            'Stable-X/stable-normal-v0-1',
            trust_remote_code=True,
            variant="fp16",
            torch_dtype=torch.float16,
            scheduler=HEURI_DDIMScheduler(
                prediction_type='sample',
                beta_start=0.00085,
                beta_end=0.0120,
                beta_schedule="scaled_linear",
            ),
        )
        # Two-stage concat: attach the YOSO pipeline (presumably providing
        # the x_start estimate) and pin it plus the prior to CUDA here.
        # NOTE(review): to() below only moves ``self.model`` — confirm these
        # attached sub-modules are meant to stay on the GPU.
        self.model.x_start_pipeline = x_start_pipeline
        self.model.x_start_pipeline.to('cuda', torch.float16)
        self.model.prior.to('cuda', torch.float16)

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        """Predict a normal map for the BGR image ``img``.

        ``image_resolution`` is accepted for a uniform call signature but is
        not forwarded to the pipeline.
        """
        pipe_out = self.model(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
        return _normal_to_uint8(pipe_out.prediction[0])

    def to(self, device):
        # Keep fp16 weights when moving; return self like the other wrappers.
        self.model.to(device, torch.float16)
        return self


class YosoNormal(StableNormal):
    """Single-stage YOSO-Normal pipeline.

    Reuses StableNormal's ``__call__`` and ``to``; only the model differs.
    """

    def __init__(self):
        # Single-stage: no x_start pipeline or prior to attach. (A previous
        # version copied StableNormal's wiring and referenced an undefined
        # ``x_start_pipeline``, crashing at construction.)
        self.model = YOSONormalsPipeline.from_pretrained(
            'Stable-X/yoso-normal-v0-3',
            trust_remote_code=True,
            variant="fp16",
            torch_dtype=torch.float16,
            t_start=0,
        )


class DSINE(object):
    """Wrapper around the DSINE model loaded through ``torch.hub``."""

    def __init__(self):
        self.model = torch.hub.load(
            "hugoycj/DSINE-hub",
            "DSINE",
            local_file_path='./models/dsine.pt',
            trust_repo=True,
        )

    def cuda(self):
        self.model.cuda()
        return self

    def cpu(self):
        self.model.cpu()
        return self

    def float(self):
        self.model.float()
        return self

    def to(self, device):
        self.model.to(device)
        return self

    def eval(self):
        self.model.eval()
        return self

    def train(self):
        self.model.train()
        return self

    @torch.no_grad()
    def __call__(self, img, image_resolution=768):
        """Predict a normal map for the BGR image ``img``.

        ``image_resolution`` is accepted for a uniform call signature but is
        not used.
        """
        pred_normal = self.model.infer_cv2(img)[0]  # (3, H, W)
        pred_normal = pred_normal.cpu().numpy().transpose(1, 2, 0)  # rgb, (H, W, 3)
        return _normal_to_uint8(pred_normal)

    def __repr__(self):
        return f"model: \n{self.model}"


def _move_pipe(pipe, device):
    """Best-effort move of ``pipe`` to ``device``; failures are logged, not raised."""
    try:
        pipe.to(device)
    except Exception:
        logger.warning(
            "Could not move %s to %s", type(pipe).__name__, device, exc_info=True
        )


def process(
    pipe_list,
    path_input,
):
    """Run every pipeline on one image, yielding results as they finish.

    Args:
        pipe_list: normal-estimation wrappers, in the same order as the UI
            output slots.
        path_input: filesystem path of the input image.

    Yields:
        A list with one entry per pipeline: a PIL image for each finished
        pipeline and ``None`` for those not yet run.

    Raises:
        gr.Error: if no image was provided or it cannot be read.
    """
    if path_input is None:
        raise gr.Error("Please upload an input image first.")
    # Read and preprocess once; the input is identical for every pipeline.
    img = cv2.imread(path_input)
    if img is None:
        raise gr.Error(f"Could not read image: {path_input}")
    raw_input_image = HWC3(img)
    orig_h, orig_w, _ = raw_input_image.shape
    resized = resize_image(raw_input_image, 768)

    results = []
    for pipe in pipe_list:
        _move_pipe(pipe, 'cuda')
        try:
            pipe_out = pipe(resized, 768)
        finally:
            # Free GPU memory for the next model even if inference fails.
            _move_pipe(pipe, 'cpu')
        pred_normal = cv2.resize(pipe_out, (orig_w, orig_h))
        results.append(Image.fromarray(pred_normal))
        # Pad with None so the number of values always matches the outputs.
        yield results + [None] * (len(pipe_list) - len(results))


def run_demo_server(pipe):
    """Build and launch the Gradio comparison UI.

    Args:
        pipe: list of five wrappers ordered as DSINE, Marigold, GeoWizard,
            StableNormal, YOSO-Normal (one per output slot below).
    """
    process_pipe = spaces.GPU(functools.partial(process, pipe), duration=120)
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="Normal Estimation Arena",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    height=256,
                )
            with gr.Column():
                submit_btn = gr.Button(value="Compute normal", variant="primary")
                clear_btn = gr.Button(value="Clear")

        with gr.Row():
            with gr.Column():
                DSINE_output_slider = gr.Image(
                    label="DSINE",
                    type="filepath",
                )
            with gr.Column():
                marigold_output_slider = gr.Image(
                    label="Marigold",
                    type="filepath",
                )

        with gr.Row():
            with gr.Column():
                geowizard_output_slider = gr.Image(
                    label="Geowizard",
                    type="filepath",
                )
            with gr.Column():
                Ours_slider = gr.Image(
                    label="StableNormal",
                    type="filepath",
                )

        with gr.Row():
            with gr.Column():
                yoso_output_slider = gr.Image(
                    label="YOSO-Normal",
                    type="filepath",
                )

        # Must follow the pipeline order passed to process().
        outputs = [
            DSINE_output_slider,
            marigold_output_slider,
            geowizard_output_slider,
            Ours_slider,
            yoso_output_slider,
        ]

        submit_btn.click(
            fn=process_pipe,
            inputs=input_image,
            outputs=outputs,
            concurrency_limit=1,
        )

        gr.Examples(
            fn=process_pipe,
            examples=sorted([
                os.path.join("files", "images", name)
                for name in os.listdir(os.path.join("files", "images"))
            ]),
            inputs=input_image,
            outputs=outputs,
            cache_examples=False,
        )

        def clear_fn():
            # One value per component in the clear_btn outputs, same order:
            # submit_btn, clear_btn, input_image, then every result slot.
            return [
                gr.Button(interactive=True),
                gr.Button(interactive=True),
                gr.Image(value=None, interactive=True),
            ] + [None] * len(outputs)

        clear_btn.click(
            fn=clear_fn,
            inputs=[],
            outputs=[submit_btn, clear_btn, input_image, *outputs],
        )

        demo.queue(
            api_open=False,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )


def main():
    """Load all normal-estimation models and start the comparison demo."""
    dsine_pipe = DSINE()
    marigold_pipe = Marigold()
    geowizard_pipe = Geowizard()
    our_pipe = StableNormal()
    yoso_pipe = YosoNormal()

    run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe, yoso_pipe])


if __name__ == "__main__":
    main()