ZehanWang committed
Commit 7b88137
1 Parent(s): 864ec44

Upload ./infer_inpaint.py with huggingface_hub

Files changed (1)
  1. infer_inpaint.py +97 -0
infer_inpaint.py ADDED
@@ -0,0 +1,97 @@
+ import argparse
+ import logging
+ import os
+ import pdb
+
+ from peft import LoraConfig, get_peft_model
+ import torch
+ from safetensors.torch import load_model, save_model
+ from marigold.marigold_inpaint_pipeline import MarigoldInpaintPipeline
+ from marigold.duplicate_unet import DoubleUNet2DConditionModel
+ import json
+ from depth_anything_v2.dpt import DepthAnythingV2
+ from torchvision.transforms.functional import pil_to_tensor
+ from PIL import Image
+ import random
+ import numpy as np
+ from pycocotools import mask as coco_mask
+ from diffusers.schedulers import DDIMScheduler, PNDMScheduler
+ from torchvision.transforms import InterpolationMode, Resize, CenterCrop
+ import torchvision.transforms as transforms
+
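+ # Build the inpainting pipeline from Stable Diffusion 2 and swap its UNet for the duplicated RGB + depth DoubleUNet.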
+ model = MarigoldInpaintPipeline.from_pretrained('stabilityai/stable-diffusion-2')
+ unet_config_path = '/home/aiops/wangzh/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2/snapshots/1e128c8891e52218b74cde8f26dbfc701cb99d79/unet/config.json'
+ # unet_checkpoint_path = '/home/aiops/wangzh/marigold/768_gen/diffusion_pytorch_model.safetensors'
+ model.unet = DoubleUNet2DConditionModel(**json.load(open(unet_config_path)))
+ # model.unet.load_state_dict(torch.load(unet_checkpoint_path, map_location='cpu'), strict=False)
+
+ model.unet.config["in_channels"] = 13
+ model.unet.duplicate_model()
+ model.unet.inpaint_rgb_conv_in()
+ model.unet.inpaint_depth_conv_in()
+
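+ # Attach LoRA adapters (r=128) to the attention projections, then load the fine-tuned RGB-D inpainting checkpoint.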
+ unet_lora_config = LoraConfig(
+     r=128,
+     lora_alpha=128,
+     init_lora_weights="gaussian",
+     target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
+ )
+ model.unet = get_peft_model(model.unet, unet_lora_config)
+
+ sd2inpaint_ckpt = torch.load('/home/aiops/wangzh/marigold/output/512-inpaint-0.5-128-vitl-partition/checkpoint/latest/pytorch_model.bin', map_location='cpu')
+ model.unet.load_state_dict(sd2inpaint_ckpt)
+ model.to('cuda')
+
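+ # Depth Anything V2 backbone configurations (ViT-S/B/L/G).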
+ model_configs = {
+     'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+     'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+     'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+     'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+ }
+
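+ # Use DDIM schedulers for both the RGB and the depth denoising branches.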
+ model.rgb_scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
+ model.depth_scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
+
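+ # Load the ViT-L Depth Anything V2 checkpoint used to predict depth for the input images.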
+ depth_model = DepthAnythingV2(**model_configs['vitl'])
+ depth_model.load_state_dict(
+     torch.load(f'/home/aiops/wangzh/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth', map_location='cpu'))
+ depth_model = depth_model.to('cuda').eval()
+
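+ # SA-1B sample images and the text prompts used to fill their masked regions.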
+ image_path = ['/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
+               '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
+               '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg']
+
+ prompt = ['A white car is parked in front of the factory',
+           'church with cemetery next to it',
+           'A house with a red brick roof']
+
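+ # Read each image and predict its depth map with Depth Anything V2.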
+ imgs = [pil_to_tensor(Image.open(p)) for p in image_path]
+ depth_imgs = [depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs]
+
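+ # Build one inpainting mask per image by merging a handful of random SA-1B instance masks.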
+ masks = []
+ for rgb_path in image_path:
+     anno = json.load(open(rgb_path.replace('.jpg', '.json')))['annotations']
+     random.shuffle(anno)
+     object_num = random.randint(5, 10)
+     mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
+     for single_anno in (anno[0:object_num] if len(anno) > object_num else anno):
+         mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
+     mask = torch.stack([torch.tensor(mask) * 3], dim=0)
+     masks.append(mask)
+
+ # mask = torch.zeros((512,512))
+ # mask[100:300, 200:400] = 1
+ # masks.append(mask)
+
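+ # Resize images, depth maps and masks to the 512x512 processing resolution.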
+ resize_transform = Resize(size=[512, 512], interpolation=InterpolationMode.NEAREST_EXACT)
+ imgs = [resize_transform(img) for img in imgs]
+ depth_imgs = [resize_transform(depth_img.unsqueeze(0)) for depth_img in depth_imgs]
+ masks = [resize_transform(mask.unsqueeze(0)) for mask in masks]
+
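+ # Run joint RGB-D inpainting for each image and save the result.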
+ # for gs in [1,2,3,4,5]:
+ for i in range(len(imgs)):
+     output_image = model._rgbd_inpaint(imgs[i], depth_imgs[i].unsqueeze(0), masks[i], [prompt[i]], processing_res=512,
+                                        guidance_scale=3, mode='joint_inpaint'  # 'full_rgb_depth_inpaint', 'full_depth_rgb_inpaint', 'joint_inpaint'
+                                        )
+     output_image.save(f'./joint-{i}.jpg')