yyk19 committed on
Commit b0e2a49
Parent: 92b8ea7

Exchange the checkpoint for one trained on LAION-OCR 10M.

.gitignore CHANGED
@@ -1 +1,3 @@
- *.pyc
+ *.pyc
+ *__pycache__/*
+ *__pycache__
config_cuda_ema.yaml ADDED
@@ -0,0 +1,88 @@
+ model:
+   base_learning_rate: 1.0e-6 #1.0e-5 #1.0e-4
+   target: cldm.cldm.ControlLDM
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "jpg"
+     cond_stage_key: "txt"
+     control_key: "hint"
+     image_size: 64
+     channels: 4
+     cond_stage_trainable: false
+     conditioning_key: crossattn
+     monitor: #val/loss_simple_ema
+     scale_factor: 0.18215
+     only_mid_control: False
+     sd_locked: True
+     use_ema: True #TODO: specify
+
+     control_stage_config:
+       target: cldm.cldm.ControlNet
+       params:
+         use_checkpoint: True
+         image_size: 32 # unused
+         in_channels: 4
+         hint_channels: 3
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         legacy: False
+
+     unet_config:
+       target: cldm.cldm.ControlledUnetModel
+       params:
+         use_checkpoint: True
+         image_size: 32 # unused
+         in_channels: 4
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         legacy: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           #attn_type: "vanilla-xformers"
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+       params:
+         freeze: True
+         layer: "penultimate"
+         # device: "cpu" #TODO: specify
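The added config enables EMA weights (`use_ema: True`) for the same ControlLDM/ControlNet setup used elsewhere in the repo. A minimal sketch of instantiating the model from this file, assuming the repository's `ldm`/`cldm` packages are importable and using `model_wo_ema.ckpt` purely as an illustrative checkpoint path:

```python
# Minimal sketch (not the repo's own loader): build ControlLDM from config_cuda_ema.yaml
# and load a checkpoint. Assumes `ldm`/`cldm` from this repo are on PYTHONPATH;
# the checkpoint filename is illustrative.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.load("config_cuda_ema.yaml")
model = instantiate_from_config(cfg.model)        # cldm.cldm.ControlLDM with use_ema=True
sd = torch.load("model_wo_ema.ckpt", map_location="cpu")
sd = sd.get("state_dict", sd)                     # some checkpoints nest weights under "state_dict"
missing, unexpected = model.load_state_dict(sd, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
model.eval()
```

In practice the repo's `scripts.rendertext_tool.load_model_from_config` wraps these steps, as the transfer.py change below shows.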
model.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f5f82f4af7d69b0ffdff6bf3d1b8dc6b13bbf81e28ea0fbacbf68824d2c1f652
- size 8129070351
model_wo_ema.ckpt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:0b86b22188bf580e80773a5ae101bf9787eb258349f3f1acf0ae50fd10cb3fec
- size 6671922039
+ oid sha256:c7ae9f29a41152a85bc5001811f80f701fb8b845b44526483497fdd6f4946e4b
+ size 6671914001
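Only the Git LFS pointer changes here. To confirm that a locally downloaded `model_wo_ema.ckpt` really is the new checkpoint, a small sketch that recomputes the SHA-256 and size against the pointer values above:

```python
# Sketch: verify a downloaded model_wo_ema.ckpt against the new LFS pointer above.
import hashlib
import os

EXPECTED_OID = "c7ae9f29a41152a85bc5001811f80f701fb8b845b44526483497fdd6f4946e4b"
EXPECTED_SIZE = 6671914001

path = "model_wo_ema.ckpt"
sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha256.update(chunk)

assert os.path.getsize(path) == EXPECTED_SIZE, "size does not match the LFS pointer"
assert sha256.hexdigest() == EXPECTED_OID, "sha256 does not match the LFS pointer"
print("checkpoint matches the new LFS pointer")
```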
scripts/gradio_rendertext.py DELETED
@@ -1,314 +0,0 @@
-
- from cldm.model import load_state_dict
- from cldm.ddim_hacked import DDIMSampler
- from ldm.util import instantiate_from_config
- import os
- from omegaconf import OmegaConf
- import argparse, os
- from torchvision.transforms import ToTensor
- from torch import autocast
- from contextlib import nullcontext
- from scripts.rendertext_tool import Render_Text, load_model_from_config
- # def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
- #     sd = load_state_dict(ckpt, location='cpu')
-
- #     if "model_ema.input_blocks10in_layers0weight" not in sd:
- #         cfg.model.params.use_ema = False
- #     model = instantiate_from_config(cfg.model)
-
- #     if not not_use_ckpt:
- #         m, u = model.load_state_dict(sd, strict=False)
- #         if len(m) > 0 and verbose:
- #             print("missing keys: {}".format(len(m)))
- #             print(m)
- #         if len(u) > 0 and verbose:
- #             print("unexpected keys: {}".format(len(u)))
- #             print(u)
-
- #     model.cuda()
- #     model.eval()
- #     return model
-
-
-
-
- def parse_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--cfg",
-         type=str,
-         default="configs/stable-diffusion/textcaps_cldm_v20.yaml",
-         help="path to config which constructs model",
-     )
-     parser.add_argument(
-         "--ckpt",
-         type=str,
-         help="path to checkpoint of model",
-     )
-     parser.add_argument(
-         "--hint_range_m11",
-         action="store_true",
-         help="the range of the hint image ([-1, 1])",
-     )
-     parser.add_argument(
-         "--precision",
-         type=str,
-         help="evaluate at this precision",
-         choices=["full", "autocast"],
-         default="full" #"autocast"
-     )
-     parser.add_argument(
-         "--not_use_ckpt",
-         action="store_true",
-         help="not to use the ckpt",
-     )
-     parser.add_argument(
-         "--build_demo",
-         action="store_true",
-         help="whether to build the demo",
-     )
-     parser.add_argument(
-         "--sep_prompt",
-         action="store_true",
-         help="whether to sep the prompt",
-     )
-     parser.add_argument(
-         "--spell_prompt_type",
-         type=int,
-         default=1,
-         help="1: A sign with the word 'xxx' written on it; 2: A sign that says 'xxx'",
-     )
-     parser.add_argument(
-         "--max_num_prompts",
-         type=int,
-         default=None,
-         help="max num of the used prompts",
-     )
-     parser.add_argument(
-         "--grams",
-         type=int,
-         default=1,
-         help="How many grams (words or symbols) to form the to-be-rendered text (used for DrawSpelling Benchmark)",
-     )
-     parser.add_argument(
-         "--num_samples",
-         type=int,
-         default=1,
-         help="how many samples to produce for each given prompt. A.k.a batch size",
-     )
-     parser.add_argument(
-         "--from-file",
-         type=str,
-         help="if specified, load prompts from this file, separated by newlines",
-     )
-     parser.add_argument(
-         "--prompt",
-         type=str,
-         nargs="?",
-         default="a sign that says 'Stable Diffusion'",
-         help="the prompt"
-     )
-     parser.add_argument(
-         "--rendered_txt",
-         type=str,
-         nargs="?",
-         default="Stable Diffusion",
-         help="the text to render"
-     )
-     parser.add_argument(
-         "--uncond_glycon_img",
-         action="store_true",
-         help="whether to set glyph embedding as None while using unconditional conditioning",
-     )
-     parser.add_argument(
-         "--deepspeed_ckpt",
-         action="store_true",
-         help="whether to use deepspeed while training",
-     )
-     parser.add_argument(
-         "--glyph_img_size",
-         type=int,
-         default=256,
-         help="the size of input images of the glyph image encoder",
-     )
-     parser.add_argument(
-         "--uncond_glyph_image_type",
-         type=str,
-         default="white",
-         help="the type of rendered glyph images as unconditional conditions while using classifier-free guidance"
-     )
-     parser.add_argument(
-         "--remove_txt_in_prompt",
-         action="store_true",
-         help="whether to remove text in the prompt",
-     )
-     parser.add_argument(
-         "--replace_token",
-         type=str,
-         default="",
-         help="the token used to replace"
-     )
-     return parser
-
- if not os.path.basename(os.getcwd()) == "stablediffusion":
-     os.chdir(os.path.join(os.getcwd(), "stablediffusion"))
-     print(os.getcwd())
- parser = parse_args()
- opt = parser.parse_args()
-
- if opt.deepspeed_ckpt:
-     assert os.path.isdir(opt.ckpt)
-     opt.ckpt = os.path.join(opt.ckpt, "checkpoint", "mp_rank_00_model_states.pt")
-     assert os.path.exists(opt.ckpt)
-
- cfg = OmegaConf.load(f"{opt.cfg}")
- model = load_model_from_config(cfg, f"{opt.ckpt}", verbose=True, not_use_ckpt=opt.not_use_ckpt)
- hint_range_m11 = opt.hint_range_m11
- sep_prompt = opt.sep_prompt
-
- ddim_sampler = DDIMSampler(model)
- precision_scope = autocast if opt.precision == "autocast" else nullcontext
- trans = ToTensor()
- render_tool = Render_Text(
-     model, precision_scope,
-     trans,
-     hint_range_m11,
-     sep_prompt,
-     uncond_glycon_img= cfg.uncond_glycon_img if hasattr(cfg, "uncond_glycon_img") else opt.uncond_glycon_img,
-     glyph_control_proc_config= cfg.glyph_control_proc_config if hasattr(cfg, "glyph_control_proc_config") else None,
-     glyph_img_size = opt.glyph_img_size,
-     uncond_glyph_image_type = cfg.uncond_glyph_image_type if hasattr(cfg, "uncond_glyph_image_type") else opt.uncond_glyph_image_type,
-     remove_txt_in_prompt = cfg.remove_txt_in_prompt if hasattr(cfg, "remove_txt_in_prompt") else opt.remove_txt_in_prompt,
-     replace_token = cfg.replace_token if hasattr(cfg, "replace_token") else opt.replace_token,
- )
-
-
- if opt.build_demo:
-     import gradio as gr
-     block = gr.Blocks().queue()
-     with block:
-         with gr.Row():
-             gr.Markdown("## Control Stable Diffusion with Glyph Images")
-         with gr.Row():
-             with gr.Column():
-                 # input_image = gr.Image(source='upload', type="numpy")
-                 rendered_txt = gr.Textbox(label="rendered_txt")
-                 prompt = gr.Textbox(label="Prompt")
-                 if sep_prompt:
-                     prompt_2 = gr.Textbox(label="Prompt_ControlNet")
-                 else:
-                     prompt_2 = gr.Number(value = 0, visible = False) #None #""
-                 run_button = gr.Button(label="Run")
-                 with gr.Accordion("Advanced options", open=False):
-                     width = gr.Slider(label="bbox_width", minimum=0., maximum=1, value=0.3, step=0.01)
-                     # height = gr.Slider(label="bbox_height", minimum=0., maximum=1, value=0.2, step=0.01)
-                     ratio = gr.Slider(label="bbox_width_height_ratio", minimum=0., maximum=5, value=0., step=0.02)
-                     top_left_x = gr.Slider(label="bbox_top_left_x", minimum=0., maximum=1, value=0.5, step=0.01)
-                     top_left_y = gr.Slider(label="bbox_top_left_y", minimum=0., maximum=1, value=0.5, step=0.01)
-                     yaw = gr.Slider(label="bbox_yaw", minimum=-180, maximum=180, value=0, step=5)
-                     num_rows = gr.Slider(label="num_rows", minimum=1, maximum=4, value=1, step=1)
-                     num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
-                     image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
-                     strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
-                     guess_mode = gr.Checkbox(label='Guess Mode', value=False)
-                     # low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
-                     # high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
-                     ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
-                     scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
-                     seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
-                     eta = gr.Number(label="eta (DDIM)", value=0.0)
-                     a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
-                     n_prompt = gr.Textbox(label="Negative Prompt",
-                                           value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
-             with gr.Column():
-                 result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
-         ips = [
-             rendered_txt, prompt,
-             width, ratio, # height,
-             top_left_x, top_left_y, yaw, num_rows,
-             a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta,
-             prompt_2
-         ]
-         run_button.click(fn=render_tool.process, inputs=ips, outputs=[result_gallery])
-         # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
-
-
-     block.launch(server_name='0.0.0.0', share=True)
- else:
-     import easyocr
-     reader = easyocr.Reader(['en'])
-     # num_samples = 1
-     # rendered_txt = "happy"
-     # prompt = "A sign that says 'happy'"
-
-     num_samples = opt.num_samples
-     print("the num of samples is {}".format(num_samples))
-     if not opt.from_file:
-         prompts = [opt.prompt]
-         data = [opt.rendered_txt]
-         print("the prompt is {}".format(prompts))
-         print("the rendered_txt is {}".format(data))
-         assert prompts is not None
-     else:
-         print(f"reading prompts from {opt.from_file}")
-         with open(opt.from_file, "r") as f:
-             data = f.read().splitlines()
-         if "gram" in os.path.basename(opt.from_file):
-             data = [item.split("\t")[0] for item in data]
-             if opt.grams > 1:
-                 data = [" ".join(data[i:i + opt.grams]) for i in range(0, len(data), opt.grams)]
-         if "DrawText_Spelling" in os.path.basename(opt.from_file) or "gram" in os.path.basename(opt.from_file):
-             if opt.spell_prompt_type == 1:
-                 prompts = ['A sign with the word "{}" written on it'.format(line.strip()) for line in data]
-             elif opt.spell_prompt_type == 2:
-                 prompts = ["A sign that says '{}'".format(line.strip()) for line in data]
-             elif opt.spell_prompt_type == 20:
-                 prompts = ['A sign that says "{}"'.format(line.strip()) for line in data]
-             elif opt.spell_prompt_type == 3:
-                 prompts = ["A whiteboard that says '{}'".format(line.strip()) for line in data]
-             elif opt.spell_prompt_type == 30:
-                 prompts = ['A whiteboard that says "{}"'.format(line.strip()) for line in data]
-             else:
-                 print("Only five types of prompt templates are supported currently")
-                 raise ValueError
-         # if opt.verbose_all_prompts:
-         #     show_num = opt.max_num_prompts if (opt.max_num_prompts is not None and opt.max_num_prompts >0) else 10
-         #     for i in range(show_num):
-         #         print("embed the word into the prompt template for {} Benchmark: {}".format(
-         #             os.path.basename(opt.from_file), data[i])
-         #         )
-         # else:
-         #     print("embed the word into the prompt template for {} Benchmark: e.g., {}".format(
-         #         os.path.basename(opt.from_file), data[0])
-         #     )
-         if opt.max_num_prompts is not None and opt.max_num_prompts >0:
-             print("only use {} prompts to test the model".format(opt.max_num_prompts))
-             data = data[:opt.max_num_prompts]
-             prompts = prompts[:opt.max_num_prompts]
-
-     width, ratio, top_left_x, top_left_y, yaw, num_rows = 0.3, 0, 0.5, 0.5, 0, 1
-     image_resolution = 512
-     strength = 1
-     guess_mode = False
-     ddim_steps = 20
-     scale = 9.0
-     seed = 1945923867
-     eta = 0
-     a_prompt = 'best quality, extremely detailed'
-     n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
-
-     all_results_list = []
-     for i in range(len(data)):
-         ips = (
-             data[i], prompts[i],
-             width, ratio, top_left_x, top_left_y, yaw, num_rows,
-             a_prompt, n_prompt,
-             num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta
-         )
-         all_results = render_tool.process(*ips) #process(*ips)
-         all_results_list.extend(all_results[1:] if data[i] != "" else all_results)
-     all_ocr_info = []
-     for image_array in all_results_list:
-         ocr_result = reader.readtext(image_array)
-         all_ocr_info.append(ocr_result)
-
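The deleted script's non-demo branch collects `all_ocr_info` (one EasyOCR result list per generated image) but never scores it. Purely as an illustration of what one might do with that output, a word-level exact-match check, assuming one generated image per rendered text:

```python
# Illustrative only (not part of the deleted script): score the collected OCR results.
# Assumes one generated image per target text; easyocr's readtext returns
# (bbox, text, confidence) tuples.
def ocr_exact_match(all_ocr_info, rendered_texts):
    hits = 0
    for ocr_result, target in zip(all_ocr_info, rendered_texts):
        detected = " ".join(text for _bbox, text, _conf in ocr_result)
        hits += int(detected.strip().lower() == target.strip().lower())
    return hits / max(len(rendered_texts), 1)

# e.g. accuracy = ocr_exact_match(all_ocr_info, data)
```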
transfer.py CHANGED
@@ -1,7 +1,7 @@
  from omegaconf import OmegaConf
  from scripts.rendertext_tool import Render_Text, load_model_from_config
  import torch
- cfg = OmegaConf.load("config_cuda.yaml")
+ cfg = OmegaConf.load("config_cuda_ema.yaml")
  model = load_model_from_config(cfg, "model_states.pt", verbose=True)

  from pytorch_lightning.callbacks import ModelCheckpoint
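transfer.py now reads `config_cuda_ema.yaml`, whose `use_ema: True` makes ControlLDM expect `model_ema.*` parameters; the commented-out loader in the deleted script checked for exactly such a key before enabling EMA. A small sketch of the same sanity check on a checkpoint before pointing this config at it (the filename and the handled nestings are assumptions):

```python
# Sketch: check whether a checkpoint contains EMA weights before using it with
# config_cuda_ema.yaml (use_ema: True). The filename and nesting keys are assumptions.
import torch

sd = torch.load("model_states.pt", map_location="cpu")
for key in ("state_dict", "module"):   # common nestings in Lightning / DeepSpeed dumps
    if isinstance(sd, dict) and key in sd:
        sd = sd[key]
has_ema = any(k.startswith("model_ema.") for k in sd)
print("EMA weights present:", has_ema)
```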