File size: 3,002 Bytes
aca81a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from kornia.morphology import dilation, closing
import requests
from transformers import SamModel, SamProcessor

# One-time module initialisation: load the pretrained SAM checkpoint and its
# matching processor so the mask builders below can share them.
print('Loading SAM...')

# Prefer the GPU whenever CUDA reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the weights first, then move the model to the selected device.
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model = model.to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

print('DONE')

def _segment_points(image, input_points):
    """Run SAM on ``image`` prompted with ``input_points``.

    Returns whatever ``post_process_masks`` produces for the prediction,
    computed from tensors that were moved to the CPU first.
    """
    with torch.no_grad():
        inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
        outputs = model(**inputs)

        return processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )


def build_mask(image, faces, hairs):
    """Build a single boolean mask covering the face and hair regions.

    SAM (the module-level ``model``/``processor``) is prompted twice: once
    with ``faces`` and once with ``hairs``.  For each prompt the masks of the
    first image / first prompt entry are reduced with a logical AND along
    their leading axis (presumably SAM's candidate masks -- confirm), the two
    results are OR-ed together, and the union is cleaned up with a 3x3
    morphological closing.

    Args:
        image: Image handed unchanged to ``processor``.
        faces: Point prompt(s) for the face, forwarded as ``input_points``.
        hairs: Point prompt(s) for the hair, forwarded as ``input_points``.

    Returns:
        A 2-D boolean ``torch.Tensor`` holding the combined mask.
    """
    # 1. Segmentation: one SAM pass per prompt set.
    face_masks = _segment_points(image, faces)
    hair_masks = _segment_points(image, hairs)

    # 2. Post-processing: keep only pixels every mask along axis 0 agrees on,
    #    then take the union of the face and hair regions.
    merged = face_masks[0][0].all(0) | hair_masks[0][0].all(0)

    # Morphological closing (dilation followed by erosion) with a 3x3 kernel.
    # kornia's morphology operators are documented for floating-point images,
    # so cast explicitly instead of relying on how a bool tensor happens to be
    # padded/promoted internally.
    batched = merged[None, None, :, :].float()
    kernel = torch.ones(3, 3)
    closed = closing(batched, kernel)

    return closed[0, 0].bool()

def build_mask_multi(image, faces, hairs):
    """Build one combined mask from several face/hair prompt pairs.

    ``faces`` and ``hairs`` are walked in lockstep (``zip`` stops at the
    shorter of the two).  Each pair is segmented independently through
    ``build_mask`` -- wrapping the pair's prompts in a one-element list, as
    the per-pair code did before -- and the resulting masks are OR-ed
    together.

    Args:
        image: Image handed unchanged to ``build_mask``.
        faces: Iterable of face prompts, one per subject.
        hairs: Iterable of hair prompts, one per subject.

    Returns:
        The element-wise OR of every per-pair mask.

    Raises:
        IndexError: If ``faces``/``hairs`` yield no pairs at all.
    """
    # Delegate per-pair segmentation and clean-up to build_mask so the two
    # entry points cannot drift apart.
    pair_masks = [
        build_mask(image, [face], [hair])
        for face, hair in zip(faces, hairs)
    ]

    # Same exception type the bare ``all_masks[0]`` lookup used to raise, but
    # with a message that says what went wrong.
    if not pair_masks:
        raise IndexError("build_mask_multi requires at least one (face, hair) pair")

    combined = pair_masks[0]
    for extra in pair_masks[1:]:
        combined = combined | extra

    return combined