weifeng-chen commited on
Commit
7c6ef59
1 Parent(s): a771b6a

add ref image

Browse files
Files changed (1) hide show
  1. app.py +40 -20
app.py CHANGED
@@ -11,19 +11,30 @@ device="cuda"
11
  model_id = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"
12
 
13
  pipe_text2img = StableDiffusionPipeline.from_pretrained(model_id).to(device)
 
 
 
14
  pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_text2img.components)
15
- pipe_inpaint = StableDiffusionInpaintPipeline(**pipe_text2img.components)
16
 
17
- # io1_text2img = gr.Interface.load("models/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1")
 
 
 
 
 
 
 
 
 
18
 
19
- def resize(w_val,l_val,img):
20
- img = Image.open(img)
21
- img = img.resize((w_val,l_val), Image.Resampling.LANCZOS)
22
- return img
23
 
24
- def infer(prompt, guide, steps, width, height):
25
- output = pipe_text2img(prompt, guidance_scale=guide, num_inference_steps=steps, width=width, height=height)
26
- # output = io1_text2img(prompt, guidance_scale=guide, num_inference_steps=steps, width=width, height=height)
 
27
  image = output.images[0]
28
  return image
29
 
@@ -35,20 +46,29 @@ with gr.Blocks() as demo:
35
  ["女孩背影, 日落, 唯美插画"],
36
  ]
37
  with gr.Row():
38
- with gr.Column(scale=2, ):
39
- output = gr.Image(label = '输出(output)')
40
-
41
  with gr.Column(scale=1, ):
42
- guide = gr.Slider(2, 15, value = 7, label = '文本引导强度(guidance scale)')
43
- steps = gr.Slider(10, 30, value = 20, step = 1, label = '迭代次数(inference steps)')
 
44
  width = gr.Slider(256, 768, value = 512, step = 64, label = '宽度(width)')
45
  height = gr.Slider(256, 768, value = 512, step = 64, label = '高度(height)')
46
-
47
  prompt = gr.Textbox(label = '提示词(prompt)')
48
- submit_btn = gr.Button("生成图片(Generate)").style(margin=False, rounded=(False, True, True, False), full_width=False,)
49
-
50
- ex = gr.Examples(examples, fn=infer, inputs=[prompt, guide, steps, width, height], outputs=output)
51
-
52
- submit_btn.click(fn = infer, inputs = [prompt, guide, steps, width, height], outputs = output)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  demo.queue(concurrency_count=10).launch()
 
11
  model_id = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"
12
 
13
  pipe_text2img = StableDiffusionPipeline.from_pretrained(model_id).to(device)
14
+
15
+ # pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(model_id).to(device) # work
16
+ # pipe_inpaint = StableDiffusionInpaintPipeline(**pipe_text2img.components) # not work
17
  pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_text2img.components)
 
18
 
19
+ def infer_text2img(prompt, width, height):
20
+ output = pipe_text2img(prompt, width=width, height=height, guidance_scale=7.5, num_inference_steps=20,)
21
+ image = output.images[0]
22
+ return image
23
+
24
+ def infer_img2img(prompt, width, height, image_in, strength ):
25
+ init_image = image_in.convert("RGB").resize((width, height))
26
+ output = pipe_img2img(prompt, init_image=init_image, strength=strength, width=width, height=height, guidance_scale=7.5, num_inference_steps=20)
27
+ image = output.images[0]
28
+ return image
29
 
30
+ def infer_inpaint(prompt, width, height, image_in):
31
+ init_image = image_in["image"].convert("RGB").resize((width, height))
32
+ mask = image_in["mask"].convert("RGB").resize((width, height))
 
33
 
34
+ output = pipe_inpaint(prompt, \
35
+ init_image=init_image, mask_image=mask, \
36
+ width=width, height=height, \
37
+ guidance_scale=7.5, num_inference_steps=20)
38
  image = output.images[0]
39
  return image
40
 
 
46
  ["女孩背影, 日落, 唯美插画"],
47
  ]
48
  with gr.Row():
 
 
 
49
  with gr.Column(scale=1, ):
50
+ # guide = gr.Slider(2, 15, value = 7, label = '文本引导强度(guidance scale)')
51
+ # steps = gr.Slider(10, 30, value = 20, step = 1, label = '迭代次数(inference steps)')
52
+ image_in = gr.Image(source='upload', elem_id="image_upload", type="pil", label="参考图")
53
  width = gr.Slider(256, 768, value = 512, step = 64, label = '宽度(width)')
54
  height = gr.Slider(256, 768, value = 512, step = 64, label = '高度(height)')
55
+ strength = gr.Slider(0, 1.0, value = 0.8, step = 0.1, label = '参考图改变程度(strength)')
56
  prompt = gr.Textbox(label = '提示词(prompt)')
57
+ submit_btn = gr.Button("生成图像(Generate)")
 
 
 
 
58
 
59
+ with gr.Column(scale=1, ):
60
+ image_out = gr.Image(label = '输出(output)')
61
+ ex = gr.Examples(examples, fn=infer_text2img, inputs=[prompt, width, height], outputs=image_out)
62
+ # with gr.Column(scale=1, ):
63
+ # image_in = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload")
64
+ # inpaint_prompt = gr.Textbox(label = '提示词(prompt)')
65
+ # inpaint_btn = gr.Button("图像编辑(Inpaint)")
66
+ # img2img_prompt = gr.Textbox(label = '提示词(prompt)')
67
+ # img2img_btn = gr.Button("图像编辑(Inpaint)")
68
+ if isinstance(image_in, Image.Image):
69
+ submit_btn.click(fn = infer_img2img, inputs = [prompt, width, height, image_in, strength], outputs = image_out)
70
+ else:
71
+ submit_btn.click(fn = infer_text2img, inputs = [prompt, width, height], outputs = image_out)
72
+ # inpaint_btn.click(fn = infer_inpaint, inputs = [inpaint_prompt, width, height, image_in], outputs = image_out)
73
+ # img2img_btn.click(fn = infer_img2img, inputs = [img2img_prompt, width, height, image_in], outputs = image_out)
74
  demo.queue(concurrency_count=10).launch()