"""Lotus depth / normal inference: library helpers and command-line entry point.

* ``lotus`` / ``lotus_video`` load a Lotus "G" (generation) checkpoint for the
  requested task and run it on a single image file or on every frame of a video.
* ``main`` runs a user-chosen checkpoint (``--mode generation|regression``) over
  every ``*.png`` / ``*.jpg`` found recursively under ``--input_dir`` and writes,
  per image, an 8-bit visualisation PNG and the raw prediction as ``.npy``.
"""
import argparse
import logging
import os
from contextlib import nullcontext
from pathlib import Path

import cv2
import numpy as np
import torch
from diffusers.utils import check_min_version
from PIL import Image
from tqdm.auto import tqdm
from tqdm import tqdm  # NOTE: rebinds the tqdm.auto import above; this console tqdm is the one used below.

from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from utils.seed_all import seed_all

check_min_version('0.28.0.dev0')


# -------------------- Shared helpers --------------------

def _g_model_id(task_name):
    """Return the G-pipeline checkpoint id for ``task_name`` ('depth'; anything else selects normals)."""
    if task_name == 'depth':
        return 'jingheya/lotus-depth-g-v1-0'
    return 'jingheya/lotus-normal-g-v1-0'


def _load_g_pipeline(task_name, device, dtype=torch.float16):
    """Load the Lotus G pipeline for ``task_name``, move it to ``device`` and silence its progress bar."""
    model_g = _g_model_id(task_name)
    pipe_g = LotusGPipeline.from_pretrained(
        model_g,
        torch_dtype=dtype,
    )
    pipe_g.to(device)
    pipe_g.set_progress_bar_config(disable=True)
    logging.info(f"Successfully loading pipeline from {model_g}.")
    return pipe_g


def _make_generator(seed, device):
    """Return a ``torch.Generator`` on ``device`` seeded with ``seed``, or ``None`` if ``seed`` is None."""
    if seed is None:
        return None
    return torch.Generator(device=device).manual_seed(seed)


def _task_embedding(device):
    """Build the fixed task embedding passed to every pipeline call.

    The one-hot ``[1, 0]`` becomes ``[sin(1), sin(0), cos(1), cos(0)]`` with shape
    (1, 4), dtype float32, on ``device``. It does not depend on the input image,
    so callers build it once and reuse it.
    """
    task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
    return torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)


def _autocast_context(pipe):
    """Mixed-precision context for ``pipe``: a no-op when MPS is available, else ``torch.autocast``."""
    if torch.backends.mps.is_available():
        return nullcontext()
    return torch.autocast(pipe.device.type)


def _load_rgb(image_input):
    """Open an image (path or file object) and return an RGB copy, closing the file afterwards."""
    with Image.open(image_input) as img:
        return img.convert('RGB')


def _to_model_input(rgb, device, dtype=torch.float16):
    """Convert an RGB image (PIL image or HxWx3 array in 0..255) into a (1, 3, H, W) tensor in [-1, 1].

    Normalisation is computed in float16, then the tensor is moved to ``device``
    and cast to ``dtype`` so it matches the dtype the pipeline was loaded with.
    """
    arr = np.array(rgb).astype(np.float16)
    tensor = torch.tensor(arr).permute(2, 0, 1).unsqueeze(0)
    tensor = tensor / 127.5 - 1.0
    return tensor.to(device=device, dtype=dtype)


def _predict(pipe, rgb_in, task_emb, generator, timestep=999):
    """Run a single-step, empty-prompt pipeline call and return the first output image as a numpy array."""
    return pipe(
        rgb_in=rgb_in,
        prompt='',
        num_inference_steps=1,
        generator=generator,
        output_type='np',
        timesteps=[timestep],
        task_emb=task_emb,
    ).images[0]


def _postprocess(pred, task_name):
    """Split a prediction into (raw array to save, PIL visualisation).

    For 'depth' the channels are averaged into one map, visualised as grayscale
    (a 2-D uint8 array becomes a mode 'L' image). Any other task keeps ``pred``
    as-is and visualises it as a colour image. ``pred`` is presumably
    (H, W, 3) floats in [0, 1] -- TODO confirm against the pipeline.
    Values are clipped to [0, 1] only for the 8-bit visualisation, so
    out-of-range values saturate instead of wrapping around in the uint8 cast.
    The returned raw array is left unclipped.
    """
    if task_name == 'depth':
        output_npy = pred.mean(axis=-1)  # average channels -> single-channel (grayscale) depth
    else:
        output_npy = pred
    vis = (np.clip(output_npy, 0.0, 1.0) * 255).astype(np.uint8)
    return output_npy, Image.fromarray(vis)


def _read_video_frames_rgb(input_video):
    """Decode every frame of ``input_video`` with OpenCV and return them as RGB arrays.

    ``cv2.VideoCapture.read`` yields BGR frames; they are converted to RGB so video
    frames are preprocessed the same way as still images (``.convert('RGB')``).
    The capture is always released, even if decoding raises.
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        logging.warning(f"Could not open video {input_video}; no frames will be processed.")
    frames = []
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    return frames


# -------------------- Public API --------------------

def infer_pipe(pipe, image_input, task_name, seed, device):
    """Run ``pipe`` on one image file and return the visualisation as a PIL image.

    Args:
        pipe: a loaded Lotus pipeline.
        image_input: path (or file object) of the input image.
        task_name: 'depth' gives a grayscale visualisation; anything else is treated as normals.
        seed: optional int seed for the pipeline generator (None -> unseeded).
        device: torch device the inputs are moved to.
    """
    generator = _make_generator(seed, device)
    rgb = _load_rgb(image_input)
    with torch.no_grad(), _autocast_context(pipe):
        test_image = _to_model_input(rgb, device)
        task_emb = _task_embedding(device)
        pred = _predict(pipe, test_image, task_emb, generator)
    _, output_color = _postprocess(pred, task_name)
    return output_color


def lotus_video(input_video, task_name, seed, device):
    """Load the G pipeline for ``task_name`` and return one PIL visualisation per frame of ``input_video``.

    A single generator (seeded when ``seed`` is given) is shared across all
    frames, so its state advances from frame to frame.
    """
    pipe_g = _load_g_pipeline(task_name, device)

    frames = _read_video_frames_rgb(input_video)
    logging.info(f"There are {len(frames)} frames in the video.")

    generator = _make_generator(seed, device)
    task_emb = _task_embedding(device)

    output_g = []
    for frame in tqdm(frames, desc="Processing frames"):
        # A fresh autocast context is created for every frame.
        with torch.no_grad(), _autocast_context(pipe_g):
            test_image = _to_model_input(frame, device)
            pred_g = _predict(pipe_g, test_image, task_emb, generator)
        _, output_color_g = _postprocess(pred_g, task_name)
        output_g.append(output_color_g)
    return output_g


def lotus(image_input, task_name, seed, device):
    """Load the G pipeline for ``task_name`` and return the visualisation for a single image."""
    pipe_g = _load_g_pipeline(task_name, device)
    return infer_pipe(pipe_g, image_input, task_name, seed, device)


def parse_args():
    """Parse the command-line options of the batch-inference entry point."""
    parser = argparse.ArgumentParser(
        description="Run Lotus..."
    )
    # model settings
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="pretrained model path from hugging face or local dir",
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="sample",
        help="The used prediction_type. ",
    )
    parser.add_argument(
        "--timestep",
        type=int,
        default=999,
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="regression",  # or "generation"
        help="Whether to use the generation or regression pipeline."
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="depth",  # or "normal"
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers."
    )

    # inference settings
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory."
    )
    parser.add_argument(
        "--input_dir", type=str, required=True, help="Input directory."
    )
    parser.add_argument(
        "--half_precision",
        action="store_true",
        help="Run with half-precision (16-bit float), might lead to suboptimal result.",
    )
    return parser.parse_args()


def main():
    """Run the selected Lotus pipeline over every .png/.jpg under --input_dir and save the results.

    Outputs go to ``<output_dir>/<task_name>_vis/<stem>.png`` (8-bit visualisation)
    and ``<output_dir>/<task_name>/<stem>.npy`` (raw prediction).
    """
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Run inference...")

    args = parse_args()

    # -------------------- Preparation --------------------
    # Random seed (project helper)
    if args.seed is not None:
        seed_all(args.seed)

    # Output directories
    os.makedirs(args.output_dir, exist_ok=True)
    logging.info(f"Output dir = {args.output_dir}")

    output_dir_color = os.path.join(args.output_dir, f'{args.task_name}_vis')
    output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}')
    os.makedirs(output_dir_color, exist_ok=True)
    os.makedirs(output_dir_npy, exist_ok=True)

    # half_precision: previously both branches chose float16, which made the flag a no-op.
    if args.half_precision:
        dtype = torch.float16
        logging.info(f"Running with half precision ({dtype}).")
    else:
        dtype = torch.float32

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"Device = {device}")

    # -------------------- Data --------------------
    root_dir = Path(args.input_dir)
    test_images = sorted(list(root_dir.rglob('*.png')) + list(root_dir.rglob('*.jpg')))
    print('==> There are', len(test_images), 'images for validation.')

    # -------------------- Model --------------------
    if args.mode == 'generation':
        pipeline_cls = LotusGPipeline
    elif args.mode == 'regression':
        pipeline_cls = LotusDPipeline
    else:
        raise ValueError(f'Invalid mode: {args.mode}')
    pipeline = pipeline_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=dtype,
    )
    logging.info(f"Successfully loading pipeline from {args.pretrained_model_name_or_path}.")

    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    generator = _make_generator(args.seed, device)
    task_emb = _task_embedding(device)  # loop-invariant: identical for every image

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for image_path in tqdm(test_images):
            # Input is cast to the pipeline dtype so it matches the weights.
            test_image = _to_model_input(_load_rgb(image_path), device, dtype)

            pred = _predict(pipeline, test_image, task_emb, generator, timestep=args.timestep)

            output_npy, output_color = _postprocess(pred, args.task_name)

            # Only .png/.jpg are globbed, so .stem equals the old basename[:-4].
            save_file_name = image_path.stem
            output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
            np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)

    print('==> Inference is done. \n==> Results saved to:', args.output_dir)


if __name__ == '__main__':
    main()