File size: 1,998 Bytes
aca81a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from controlnet_aux import OpenposeDetector
from diffusers import StableDiffusionInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
from src.ControlNetInpaint.src.pipeline_stable_diffusion_controlnet_inpaint import *
from kornia.filters import gaussian_blur2d

# Each model is built only if its name is not yet bound at module scope, so
# re-executing this code in the same namespace (e.g. re-running a notebook cell
# or importlib.reload) reuses the already-loaded objects instead of reloading.
if 'controlnet' not in globals():
    print('Loading ControlNet...')
    # ControlNet checkpoint published for OpenPose conditioning, loaded in fp16.
    controlnet = ControlNetModel.from_pretrained(
        "fusing/stable-diffusion-v1-5-controlnet-openpose", torch_dtype=torch.float16
    )

if 'pipe' not in globals():
    print('Loading SD...')
    # Inpainting pipeline wired to the ControlNet above, fp16, moved to the GPU.
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
    ).to('cuda')
    # Swap in the UniPC scheduler, reusing the pipeline's existing scheduler config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # Report completion only once the pipeline is fully configured.
    print('DONE')

if 'openpose' not in globals():
    print('Loading OpenPose...')
    # Pose detector used by ``synthesis`` to build the ControlNet conditioning image.
    openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
    print('DONE')

def synthesis(image, mask, prompt="", n_prompt="", num_steps=20, seed=0, remix=True):
    """Inpaint ``image`` with the pose-conditioned ControlNet inpainting pipeline.

    The module-level ``openpose`` detector extracts a pose image from ``image``.
    That pose image conditions the module-level ``pipe``, which repaints the
    regions where ``mask`` is zero/False.

    Args:
        image: Source image. It must support ``.size`` and ``np.asarray`` and be
            accepted by ``openpose`` and ``pipe``. Presumably a PIL image
            (TODO confirm).
        mask: 2-D torch tensor whose shape matches the first two axes of
            ``np.asarray(image)``. Nonzero/True marks pixels to KEEP; the
            pipeline receives the logical inverse as ``mask_image``.
        prompt: Positive text prompt.
        n_prompt: Negative text prompt.
        num_steps: Number of inference steps passed to the pipeline.
        seed: Seed given to ``torch.manual_seed``; this also reseeds torch's
            global RNG.
        remix: If True, blend every generated image back into ``image`` using a
            Gaussian-feathered ``mask``, so kept regions come from the original.

    Returns:
        A list built from the pipeline's ``.images``. With ``remix=False`` the
        entries are exactly what ``pipe`` produced. With ``remix=True`` each
        entry is an int32 numpy array of shape (H, W, C).
    """
    # 1. Extract the pose that conditions the ControlNet, at the input resolution.
    with torch.no_grad():
        pose_image = openpose(image).resize(image.size)

    # 2. Generate. ``logical_not`` turns the keep-mask into the repaint-mask handed
    #    to the pipeline (same values as the elementwise ``mask == False``).
    generator = torch.manual_seed(seed)
    new_image = pipe(
        prompt,
        negative_prompt=n_prompt,
        generator=generator,
        num_inference_steps=num_steps,
        image=image,
        control_image=pose_image,
        mask_image=torch.logical_not(mask).float().numpy(),
    ).images

    if not remix:
        return new_image

    # 3. Feather the mask once, not per image, so every output is blended with
    #    the same soft mask. Re-blurring inside the loop would feather the mask
    #    further for each additional image.
    #    Indexing with [0, 0] (rather than .squeeze()) keeps both spatial axes
    #    even when H or W is 1; the trailing None broadcasts over channels.
    soft_mask = gaussian_blur2d(
        1.0 * mask[None, None, :, :],
        kernel_size=(11, 11),
        sigma=(29, 29),
    )[0, 0].clip(0, 1)[:, :, None]
    original = np.asarray(image)

    # Where soft_mask ~ 1 keep the original pixel; where ~ 0 take the generated one.
    # Generated images are resized because the pipeline may not return image.size.
    return [
        (
            soft_mask * original
            + (1 - soft_mask) * np.asarray(generated.resize(image.size))
        ).int().numpy()
        for generated in new_image
    ]