Upload 15 files
- config.json +42 -0
- control_v2p_sd15_mediapipe_face.full.ckpt +3 -0
- control_v2p_sd15_mediapipe_face.pth +3 -0
- control_v2p_sd15_mediapipe_face.safetensors +3 -0
- control_v2p_sd15_mediapipe_face.yaml +79 -0
- diffusion_pytorch_model.bin +3 -0
- diffusion_pytorch_model.fp16.bin +3 -0
- diffusion_pytorch_model.fp16.safetensors +3 -0
- gradio_face2image.py +105 -0
- laion_face_common.py +180 -0
- laion_face_dataset.py +55 -0
- tool_download_face_targets.py +86 -0
- tool_generate_face_poses.py +180 -0
- train_laion_face.py +46 -0
- train_laion_face_sd15.py +42 -0
config.json
ADDED
@@ -0,0 +1,42 @@
{
  "_class_name": "ControlNetModel",
  "_diffusers_version": "0.15.0.dev0",
  "_name_or_path": "/home/josephcatrambone/ControlNet/models",
  "act_fn": "silu",
  "attention_head_dim": 8,
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "class_embed_type": null,
  "conditioning_embedding_out_channels": [
    16,
    32,
    96,
    256
  ],
  "controlnet_conditioning_channel_order": "rgb",
  "cross_attention_dim": 768,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 4,
  "layers_per_block": 2,
  "mid_block_scale_factor": 1,
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "projection_class_embeddings_input_dim": null,
  "resnet_time_scale_shift": "default",
  "upcast_attention": null,
  "use_linear_projection": false
}
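Note: config.json and the diffusion_pytorch_model.* weights follow the Diffusers ControlNetModel layout, so they can be loaded directly into a StableDiffusionControlNetPipeline. A minimal sketch, assuming the repo id and the conditioning-image path (any MediaPipe face annotation produced by laion_face_common.generate_annotation works as the control image):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

# Assumed repo id; point this at wherever these files are actually hosted, or at a local clone.
controlnet = ControlNetModel.from_pretrained("CrucibleAI/ControlNetMediaPipeFace", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

# A MediaPipe face-mesh drawing (colored landmark lines on black) as produced by generate_annotation below.
control_image = load_image("face_landmarks.png")  # hypothetical path
image = pipe("a portrait photo, best quality", image=control_image, num_inference_steps=30).images[0]
image.save("output.png")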
control_v2p_sd15_mediapipe_face.full.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a2a71953d7372d5585899b44693a7532ebbf80c091108ae2b8987ca93cc2dac2
size 8601300183
control_v2p_sd15_mediapipe_face.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f2ccead3a8c0b9fbf9cad7b8eaa29834983ced916c766a92fb84db34ff29e43
size 1445239863
control_v2p_sd15_mediapipe_face.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5be501156709895f0b14a7ec76faae7cf0a105f76895252a2c69db541629628f
size 1445154814
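Note: the control_v2p_sd15_mediapipe_face.* and diffusion_pytorch_model.* entries are Git LFS pointer files; the oid and size describe the real weights, which must be fetched through Git LFS or the Hub API. A small sketch using huggingface_hub (the repo id is an assumption):

from huggingface_hub import hf_hub_download

# Assumed repo id; substitute the repository that actually hosts these files.
ckpt_path = hf_hub_download(
    repo_id="CrucibleAI/ControlNetMediaPipeFace",
    filename="control_v2p_sd15_mediapipe_face.safetensors",
)
print(ckpt_path)  # Local cache path of the resolved (non-pointer) weight file.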
control_v2p_sd15_mediapipe_face.yaml
ADDED
@@ -0,0 +1,79 @@
model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
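Note: this YAML is the cldm-style model definition used by the original ControlNet code path (the Gradio demo and the train_* scripts below), as opposed to the Diffusers layout above. A minimal loading sketch, assuming the lllyasviel/ControlNet repository (which provides the cldm package) is on PYTHONPATH and the full checkpoint sits next to the YAML:

from cldm.model import create_model, load_state_dict

# Build the ControlLDM graph from the YAML, then load the paired full checkpoint.
model = create_model('./control_v2p_sd15_mediapipe_face.yaml').cpu()
model.load_state_dict(load_state_dict('./control_v2p_sd15_mediapipe_face.full.ckpt', location='cpu'))
model = model.cuda()  # Move to GPU for sampling, as the Gradio demo does.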
diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f63de389f776b75bb11f10487a187573aea84f9a51debd08f314bd084e7fb362
size 1445254969
diffusion_pytorch_model.fp16.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c37b3dd41e956160909129b50f84fd938116550727b491192cbdbe6f896cd7b
size 722696633
diffusion_pytorch_model.fp16.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9fb50465b4fd7e15f0dc7df8031767e57309cfda2917082485bcf6c11bedb540
size 722598642
gradio_face2image.py
ADDED
@@ -0,0 +1,105 @@
import os
from typing import Mapping

import gradio as gr
import numpy
import torch
import random
from PIL import Image

from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from laion_face_common import generate_annotation
from share import *


model = create_model('./control_v2p_sd21_mediapipe_face.yaml').cpu()
model.load_state_dict(load_state_dict('./control_v2p_sd21_mediapipe_face.full.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)  # ControlNet _only_ works with DDIM.


def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta):
    with torch.no_grad():
        empty = generate_annotation(input_image, max_faces)
        visualization = Image.fromarray(empty)  # Save to help debug.

        empty = numpy.moveaxis(empty, 2, 0)  # h, w, c -> c, h, w
        control = torch.from_numpy(empty.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        # control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        # Sanity check the dimensions.
        B, C, H, W = control.shape
        assert C == 3
        assert B == num_samples

        if seed != -1:
            random.seed(seed)
            os.environ['PYTHONHASHSEED'] = str(seed)
            numpy.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        samples, intermediates = ddim_sampler.sample(
            ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=un_cond
        )

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        # x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8)
        x_samples = numpy.moveaxis((x_samples * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8), 1, -1)  # b, c, h, w -> b, h, w, c
        results = [visualization] + [x_samples[i] for i in range(num_samples)]

    return results


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with a Facial Pose")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                max_faces = gr.Slider(label="Max Faces", minimum=1, maximum=5, value=1, step=1)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_image, prompt, a_prompt, n_prompt, max_faces, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


block.launch(server_name='0.0.0.0')
laion_face_common.py
ADDED
@@ -0,0 +1,180 @@
from typing import Mapping

import mediapipe as mp
import numpy
from PIL import Image


mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_face_detection = mp.solutions.face_detection  # Only for counting faces.
mp_face_mesh = mp.solutions.face_mesh
mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS

DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
PoseLandmark = mp.solutions.drawing_styles.PoseLandmark

f_thick = 2
f_rad = 1
right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)

# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
face_connection_spec = {}
for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
    face_connection_spec[edge] = head_draw
for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
    face_connection_spec[edge] = left_eye_draw
for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
    face_connection_spec[edge] = left_eyebrow_draw
# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
#     face_connection_spec[edge] = left_iris_draw
for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
    face_connection_spec[edge] = right_eye_draw
for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
    face_connection_spec[edge] = right_eyebrow_draw
# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
#     face_connection_spec[edge] = right_iris_draw
for edge in mp_face_mesh.FACEMESH_LIPS:
    face_connection_spec[edge] = mouth_draw
iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}


def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
    """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
    landmarks. Until our PR is merged into mediapipe, we need this separate method."""
    if len(image.shape) != 3:
        raise ValueError("Input image must be H,W,C.")
    image_rows, image_cols, image_channels = image.shape
    if image_channels != 3:  # BGR channels
        raise ValueError('Input image must contain three channel bgr data.')
    for idx, landmark in enumerate(landmark_list.landmark):
        if (
                (landmark.HasField('visibility') and landmark.visibility < 0.9) or
                (landmark.HasField('presence') and landmark.presence < 0.5)
        ):
            continue
        if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
            continue
        image_x = int(image_cols*landmark.x)
        image_y = int(image_rows*landmark.y)
        draw_color = None
        if isinstance(drawing_spec, Mapping):
            if drawing_spec.get(idx) is None:
                continue
            else:
                draw_color = drawing_spec[idx].color
        elif isinstance(drawing_spec, DrawingSpec):
            draw_color = drawing_spec.color
        image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color


def reverse_channels(image):
    """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB."""
    # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
    # im[:,:,::[2,1,0]] would also work but makes a copy of the data.
    return image[:, :, ::-1]


def generate_annotation(
        input_image: Image.Image,
        max_faces: int,
        min_face_size_pixels: int = 0,
        return_annotation_data: bool = False
):
    """
    Find up to 'max_faces' inside the provided input image.
    If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many
    pixels in the image.
    If return_annotation_data is TRUE (default: false) then in addition to returning the 'detected face' image, three
    additional parameters will be returned: faces before filtering, faces after filtering, and an annotation image.
    The faces_before_filtering return value is the number of faces detected in an image with no filtering.
    faces_after_filtering is the number of faces remaining after filtering small faces.

    :return:
        If 'return_annotation_data==True', returns (numpy array, numpy array, int, int).
        If 'return_annotation_data==False' (default), returns a numpy array.
    """
    with mp_face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=max_faces,
            refine_landmarks=True,
            min_detection_confidence=0.5,
    ) as facemesh:
        img_rgb = numpy.asarray(input_image)
        # multi_face_landmarks is None when no faces are found; treat that as an empty list.
        results = facemesh.process(img_rgb).multi_face_landmarks or []

        faces_found_before_filtering = len(results)

        # Filter faces that are too small
        filtered_landmarks = []
        for lm in results:
            landmarks = lm.landmark
            face_rect = [
                landmarks[0].x,
                landmarks[0].y,
                landmarks[0].x,
                landmarks[0].y,
            ]  # Left, up, right, down.
            for i in range(len(landmarks)):
                face_rect[0] = min(face_rect[0], landmarks[i].x)
                face_rect[1] = min(face_rect[1], landmarks[i].y)
                face_rect[2] = max(face_rect[2], landmarks[i].x)
                face_rect[3] = max(face_rect[3], landmarks[i].y)
            if min_face_size_pixels > 0:
                face_width = abs(face_rect[2] - face_rect[0])
                face_height = abs(face_rect[3] - face_rect[1])
                face_width_pixels = face_width * input_image.size[0]
                face_height_pixels = face_height * input_image.size[1]
                face_size = min(face_width_pixels, face_height_pixels)
                if face_size >= min_face_size_pixels:
                    filtered_landmarks.append(lm)
            else:
                filtered_landmarks.append(lm)

        faces_remaining_after_filtering = len(filtered_landmarks)

        # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
        empty = numpy.zeros_like(img_rgb)

        # Draw detected faces:
        for face_landmarks in filtered_landmarks:
            mp_drawing.draw_landmarks(
                empty,
                face_landmarks,
                connections=face_connection_spec.keys(),
                landmark_drawing_spec=None,
                connection_drawing_spec=face_connection_spec
            )
            draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)

        # Flip BGR back to RGB.
        empty = reverse_channels(empty)

        # We might have to generate a composite.
        if return_annotation_data:
            # Note that we're copying the input image AND flipping the channels so we can draw on top of it.
            annotated = reverse_channels(numpy.asarray(input_image)).copy()
            for face_landmarks in filtered_landmarks:
                mp_drawing.draw_landmarks(
                    annotated,
                    face_landmarks,
                    connections=face_connection_spec.keys(),
                    landmark_drawing_spec=None,
                    connection_drawing_spec=face_connection_spec
                )
                draw_pupils(annotated, face_landmarks, iris_landmark_spec, 2)
            annotated = reverse_channels(annotated)

        if not return_annotation_data:
            return empty
        else:
            return empty, annotated, faces_found_before_filtering, faces_remaining_after_filtering
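Note: generate_annotation is the piece shared by the Gradio demo and the dataset tooling; it runs MediaPipe FaceMesh on a PIL image and draws the tessellated landmarks onto a black canvas. A quick standalone sketch (the input filename is hypothetical):

from PIL import Image
from laion_face_common import generate_annotation

img = Image.open("portrait.jpg").convert("RGB")   # hypothetical input photo
control = generate_annotation(img, max_faces=1)   # H x W x 3 uint8 RGB, landmark lines on black
Image.fromarray(control).save("portrait_landmarks.png")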
laion_face_dataset.py
ADDED
@@ -0,0 +1,55 @@
import json
import numpy
import os
from PIL import Image
from torch.utils.data import Dataset


class LaionDataset(Dataset):
    def __init__(self):
        self.data = []
        with open('./training/laion-face-processed/prompt.jsonl', 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        source_filename = os.path.split(item['source'])[-1]
        target_filename = os.path.split(item['target'])[-1]
        prompt = item['prompt']

        # If prompt is "" or null, make it something simple.
        if not prompt:
            print(f"Image with index {idx} / {source_filename} has no text.")
            prompt = "an image"

        source_image = Image.open('./training/laion-face-processed/source/' + source_filename).convert("RGB")
        target_image = Image.open('./training/laion-face-processed/target/' + target_filename).convert("RGB")
        # Resize the image so that the minimum edge is bigger than 512x512, then crop center.
        # This may cut off some parts of the face image, but in general they're smaller than 512x512 and we still want
        # to cover the literal edge cases.
        img_size = source_image.size
        scale_factor = 512/min(img_size)
        source_image = source_image.resize((1+int(img_size[0]*scale_factor), 1+int(img_size[1]*scale_factor)))
        target_image = target_image.resize((1+int(img_size[0]*scale_factor), 1+int(img_size[1]*scale_factor)))
        img_size = source_image.size
        left_padding = (img_size[0] - 512)//2
        top_padding = (img_size[1] - 512)//2
        source_image = source_image.crop((left_padding, top_padding, left_padding+512, top_padding+512))
        target_image = target_image.crop((left_padding, top_padding, left_padding+512, top_padding+512))

        source = numpy.asarray(source_image)
        target = numpy.asarray(target_image)

        # Normalize source images to [0, 1].
        source = source.astype(numpy.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(numpy.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)
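Note: LaionDataset yields the dict keys ControlNet training expects (jpg for the target image in [-1, 1], hint for the annotation in [0, 1], txt for the caption). A quick sanity check, assuming ./training/laion-face-processed/ has been populated by the tools below:

from torch.utils.data import DataLoader
from laion_face_dataset import LaionDataset

dataset = LaionDataset()
batch = next(iter(DataLoader(dataset, batch_size=2, shuffle=True)))
print(batch['jpg'].shape, batch['hint'].shape)  # Both should be (2, 512, 512, 3).
print(batch['txt'])                             # Two caption strings.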
tool_download_face_targets.py
ADDED
@@ -0,0 +1,86 @@
#!/usr/bin/python3
"""
tool_download_face_targets.py

Reads in the metadata from the LAION images and begins downloading all images.
"""

import json
import os
import sys
import time
import urllib
import urllib.request
try:
    from tqdm import tqdm
except ImportError:
    # Wrap this method into the identity.
    print("TQDM not found. Progress will be quiet without 'verbose'.")
    def tqdm(x):
        return x


def main(logfile_path: str, verbose: bool = False, pause_between_fetches: float = 0.0):
    """Open the metadata.json file from the training directory and fetch all target images."""
    # Toggle a function pointer so we don't have to check verbosity everywhere.
    def out(x):
        pass
    if verbose:
        out = print

    log = open(logfile_path, 'at')
    skipped_image_count = 0
    errored_image_count = 0
    successful_image_count = 0
    if not os.path.exists("training"):
        print("ERROR: training directory does not exist in the current directory.")
        print("Has the archive been unzipped?")
        print("Are you running from the project root?")
        return 2  # BASH: No such directory.
    if not os.path.exists("training/laion-face-processed/metadata.json"):
        print("ERROR: metadata.json was not found in training/laion-face-processed.")
        return 2
    with open("training/laion-face-processed/metadata.json", 'rt') as md_in:
        metadata = json.load(md_in)
    # Create the directory for targets if it does not exist.
    if not os.path.exists("training/laion-face-processed/target"):
        os.mkdir("training/laion-face-processed/target")
    for image_id, image_data in tqdm(metadata.items()):
        filename = f"training/laion-face-processed/target/{image_id}.jpg"
        if os.path.exists(filename):
            out(f"Skipping {image_id}: file exists.")
            skipped_image_count += 1
            continue
        if not download_file(image_data['url'], filename, verbose):
            error_message = f"Problem downloading {image_id}"
            out(error_message)
            log.write(error_message + "\n")
            log.flush()  # Flush often in case we crash.
            errored_image_count += 1
        else:
            # Only count the fetch as a success when the download actually completed.
            successful_image_count += 1
        if pause_between_fetches > 0.0:
            time.sleep(pause_between_fetches)
    log.close()
    print("Run success.")
    print(f"{skipped_image_count} images skipped")
    print(f"{errored_image_count} images failed to download")
    print(f"{successful_image_count} images downloaded")


def download_file(url: str, output_path: str, verbose: bool = False) -> bool:
    """Download the file with the given URL and save it to the specified path. Return true on success."""
    try:
        r = urllib.request.urlopen(url)
        if not r.status == 200:
            return False
        with open(output_path, 'wb') as fout:
            fout.write(r.read())
        return True
    except Exception as e:
        if verbose:
            print(e)
        return False


if __name__ == "__main__":
    main("downloads.log", verbose="-v" in sys.argv)
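Note: the downloader only relies on metadata.json being a mapping from image id to a record with a "url" field; any other fields in each record are ignored here. A hypothetical, minimal example of that layout (the id, URL, and any extra fields are assumptions):

import json

# Minimal assumed shape of training/laion-face-processed/metadata.json.
example_metadata = {
    "00001234": {"url": "https://example.com/some_face_photo.jpg"},
}
with open("training/laion-face-processed/metadata.json", "wt") as fout:
    json.dump(example_metadata, fout)
# With this in place, `python tool_download_face_targets.py -v` would fetch the image into
# training/laion-face-processed/target/00001234.jpg.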
tool_generate_face_poses.py
ADDED
@@ -0,0 +1,180 @@
import json
import os
import sys
from dataclasses import dataclass, field
from glob import glob
from typing import Mapping

from PIL import Image
from tqdm import tqdm

from laion_face_common import generate_annotation


@dataclass
class RunProgress:
    pending: list = field(default_factory=list)
    success: list = field(default_factory=list)
    skipped_size: list = field(default_factory=list)
    skipped_nsfw: list = field(default_factory=list)
    skipped_noface: list = field(default_factory=list)
    skipped_smallface: list = field(default_factory=list)


def main(
        status_filename: str,
        prompt_filename: str,
        input_glob: str,
        output_directory: str,
        annotated_output_directory: str = "",
        min_image_size: int = 384,
        max_image_size: int = 32766,
        min_face_size_pixels: int = 64,
        prompt_mapping: dict = None,  # If present, maps a filename to a text prompt.
):
    status = RunProgress()

    if os.path.exists(status_filename):
        print("Continuing from checkpoint.")
        # Restore a saved state:
        status_temp = json.load(open(status_filename, 'rt'))
        for k in status.__dict__.keys():
            status.__setattr__(k, status_temp[k])
        # Output label file:
        pout = open(prompt_filename, 'at')
    else:
        print("Starting run.")
        status = RunProgress()
        status.pending = list(glob(input_glob))
        # Output label file:
        pout = open(prompt_filename, 'wt')
        with open(status_filename, 'wt') as fout:
            json.dump(status.__dict__, fout)

    print(f"{len(status.pending)} images remaining")

    # If we don't have a preexisting set of labels (like for ImageNet/MSCOCO), just null-fill the mapping.
    # We will try on a per-image basis to see if there's a metadata .json.
    if prompt_mapping is None:
        prompt_mapping = dict()

    step = 0
    with tqdm(total=len(status.pending)) as pbar:
        while len(status.pending) > 0:
            full_filename = status.pending.pop()
            pbar.update(1)
            step += 1

            if step % 100 == 0:
                # Checkpoint save:
                with open(status_filename, 'wt') as fout:
                    json.dump(status.__dict__, fout)

            _fpath, fname = os.path.split(full_filename)

            # Make our output filenames.
            # We used to do this here so we could check if a file existed before writing, then skip it, but since we
            # have a 'status' that we cache and update, we no longer have to do this check.
            annotation_filename = ""
            if annotated_output_directory:
                annotation_filename = os.path.join(annotated_output_directory, fname)
            output_filename = os.path.join(output_directory, fname)

            # The LAION dataset has accompanying .json files with each image.
            partial_filename, extension = os.path.splitext(full_filename)
            candidate_json_fullpath = partial_filename + ".json"
            image_metadata = {}
            if os.path.exists(candidate_json_fullpath):
                try:
                    image_metadata = json.load(open(candidate_json_fullpath, 'rt'))
                except Exception as e:
                    print(e)
            if "NSFW" in image_metadata:
                nsfw_marker = image_metadata.get("NSFW")  # This can be "", None, or other weird things.
                if nsfw_marker is not None and nsfw_marker.lower() != "unlikely":
                    # Skip NSFW images.
                    status.skipped_nsfw.append(full_filename)
                    continue

            # Try to get a prompt/caption from the metadata or the prompt mapping.
            image_prompt = image_metadata.get("caption", prompt_mapping.get(fname, ""))

            # Load image:
            img = Image.open(full_filename).convert("RGB")
            img_width = img.size[0]
            img_height = img.size[1]
            img_size = min(img.size[0], img.size[1])
            if img_size < min_image_size or max(img_width, img_height) > max_image_size:
                status.skipped_size.append(full_filename)
                continue

            # We re-initialize the detector every time because it has a habit of triggering weird race conditions.
            empty, annotated, faces_before_filtering, faces_after_filtering = generate_annotation(
                img,
                max_faces=5,
                min_face_size_pixels=min_face_size_pixels,
                return_annotation_data=True
            )
            if faces_before_filtering == 0:
                # Skip images with no faces.
                status.skipped_noface.append(full_filename)
                continue
            if faces_after_filtering == 0:
                # Skip images with no faces large enough.
                status.skipped_smallface.append(full_filename)
                continue

            Image.fromarray(empty).save(output_filename)
            if annotation_filename:
                Image.fromarray(annotated).save(annotation_filename)

            # See https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md for the training file format.
            # prompt.json
            #   a JSONL file with {"source": "source/0.jpg", "target": "target/0.jpg", "prompt": "..."}.
            #   a source/xxxxx.jpg or source/xxxx.png file for each of the inputs.
            #   a target/xxxxx.jpg for each of the outputs.
            pout.write(json.dumps({
                "source": os.path.join(output_directory, fname),
                "target": full_filename,
                "prompt": image_prompt,
            }) + "\n")
            pout.flush()
            status.success.append(full_filename)

    # We do save every 100 iterations, but it's good to save on completion, too.
    with open(status_filename, 'wt') as fout:
        json.dump(status.__dict__, fout)

    pout.close()
    print("Done!")
    print(f"{len(status.success)} images added to dataset.")
    print(f"{len(status.skipped_size)} images rejected for size.")
    print(f"{len(status.skipped_smallface)} images rejected for having faces too small.")
    print(f"{len(status.skipped_noface)} images rejected for not having faces.")
    print(f"{len(status.skipped_nsfw)} images rejected for NSFW.")


if __name__ == "__main__":
    # Requires at least prompt.jsonl, an input glob, and an output directory.
    if len(sys.argv) >= 4 and "-h" not in sys.argv:
        prompt_jsonl = sys.argv[1]
        in_glob = sys.argv[2]  # Should probably be in a directory called "target/*.jpg".
        output_dir = sys.argv[3]  # Should probably be a directory called "source".
        annotation_dir = ""
        if len(sys.argv) > 4:
            annotation_dir = sys.argv[4]
        main("generate_face_poses_checkpoint.json", prompt_jsonl, in_glob, output_dir, annotation_dir)
    else:
        print(f"""Usage:
python {sys.argv[0]} prompt.jsonl target/*.jpg source/ [annotated/]
source and target are slightly confusing in this context. We are writing the image names to prompt.jsonl, so
the naming system has to be consistent with what ControlNet expects. In ControlNet, the source is the input and
target is the output. We are generating source images from targets in this application, so the second argument
should be a folder full of images. The third argument should be 'source', where the images should be placed.
Optionally, an 'annotated' directory can be provided. Augmented images will be placed here.

A checkpoint file named 'generate_face_poses_checkpoint.json' will be created in the place where the script is
run. If a run is cancelled, it can be resumed from this checkpoint.

If invoking the script from bash, do not forget to enclose globs with quotes. Example usage:
`python ./tool_generate_face_poses.py ./face_prompt.jsonl "/home/josephcatrambone/training_data/data-mscoco/images/train2017/*" /home/josephcatrambone/training_data/data-mscoco/images/source_2017/`
""")
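Note: each kept image adds one JSONL line pairing the drawn annotation (source) with the original photo (target) and its caption, following the ControlNet train.md format referenced above. An illustrative line (paths and caption are hypothetical):

import json

# What one line written by pout.write(...) above looks like:
line = json.dumps({
    "source": "source/000123.jpg",             # MediaPipe landmark drawing written by this tool
    "target": "target/000123.jpg",             # original photograph from the input glob
    "prompt": "a woman smiling at the camera"  # caption from LAION metadata or the prompt mapping
})
print(line)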
train_laion_face.py
ADDED
@@ -0,0 +1,46 @@
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from laion_face_dataset import LaionDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/controlnet_sd21_laion_face.ckpt'
batch_size = 4
logger_freq = 2500
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Save every so often:
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath="./checkpoints/",
    filename="ckpt_controlnet_sd21_{epoch}_{step}_{loss}",
    monitor='train/loss_simple_step',
    save_top_k=5,
    every_n_train_steps=5000,
    save_last=True,
)


# Misc
dataset = LaionDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])


# Train!
trainer.fit(model, dataloader)
train_laion_face_sd15.py
ADDED
@@ -0,0 +1,42 @@
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from laion_face_dataset import LaionDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/controlnet_sd15_laion_face.ckpt'
batch_size = 8
logger_freq = 2500
learning_rate = 1e-5
sd_locked = True
only_mid_control = False

# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control

# Save every so often:
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath="./checkpoints/",
    filename="controlnet_sd15_laion_face_{epoch}_{step}_{loss}.ckpt",
    monitor='train/loss_simple_step',
    save_top_k=5,
    every_n_train_steps=5000,
    save_last=True,
)

# Misc
dataset = LaionDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])

# Train!
trainer.fit(model, dataloader)