File size: 3,944 Bytes
1ec3b8e
a771b6a
5600a85
1ec3b8e
a771b6a
 
 
 
 
1ec3b8e
078bead
a771b6a
1ec3b8e
3519104
7c6ef59
 
 
3519104
a771b6a
7c6ef59
 
 
 
 
 
 
 
 
 
1ec3b8e
7c6ef59
 
 
1ec3b8e
7c6ef59
 
 
 
7e31dbb
1ec3b8e
 
5905366
 
 
 
 
 
 
 
 
7c6ef59
 
 
5905366
 
7c6ef59
5905366
7c6ef59
5905366
7c6ef59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5905366
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from PIL import Image
import torch

from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
)

# All pipelines run on GPU; fp16 halves VRAM for the 1B-parameter Chinese model.
device="cuda"
model_id = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"

# Base text-to-image pipeline; its components (UNet, VAE, tokenizer, ...) are
# shared with the img2img pipeline below to avoid loading the weights twice.
pipe_text2img = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

# Inpainting is currently disabled. The first variant (separate from_pretrained
# load) worked but doubles memory; the second (sharing components) did not work
# for this checkpoint. NOTE(review): infer_inpaint below still references
# pipe_inpaint and will raise NameError if ever wired up while this stays commented.
# pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(model_id).to(device)    # work
# pipe_inpaint = StableDiffusionInpaintPipeline(**pipe_text2img.components)   # not work
pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_text2img.components).to(device)

def infer_text2img(prompt, width, height):
    """Generate one image from a text prompt.

    Args:
        prompt: text prompt (Chinese prompts expected by this model).
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        The first generated PIL image.
    """
    result = pipe_text2img(
        prompt,
        width=width,
        height=height,
        guidance_scale=7.5,
        num_inference_steps=20,
    )
    return result.images[0]

def infer_img2img(prompt, width, height, image_in, strength):
    """Generate one image guided by a reference image (img2img).

    Args:
        prompt: text prompt (Chinese prompts expected by this model).
        width: target width; the reference image is resized to this.
        height: target height; the reference image is resized to this.
        image_in: reference PIL image uploaded by the user.
        strength: how far to deviate from the reference (0 = keep, 1 = ignore).

    Returns:
        The first generated PIL image.
    """
    # The img2img pipeline derives the output size from the init image, and its
    # __call__ does not accept width/height kwargs — passing them (as the old
    # code did) is rejected. Resizing up front achieves the requested size.
    init_image = image_in.convert("RGB").resize((width, height))
    output = pipe_img2img(
        prompt,
        init_image=init_image,
        strength=strength,
        guidance_scale=7.5,
        num_inference_steps=20,
    )
    return output.images[0]

def infer_inpaint(prompt, width, height, image_in): 
    """Inpaint the masked region of an uploaded image according to the prompt.

    Args:
        prompt: text prompt describing the desired fill.
        width: target width; image and mask are resized to this.
        height: target height; image and mask are resized to this.
        image_in: dict from a gr.Image sketch tool with "image" and "mask" PIL entries.

    Returns:
        The first generated PIL image.

    NOTE(review): pipe_inpaint is commented out at module level, so calling
    this function currently raises NameError. It is only referenced from
    commented-out UI wiring below.
    """
    init_image = image_in["image"].convert("RGB").resize((width, height))
    mask = image_in["mask"].convert("RGB").resize((width, height))

    output = pipe_inpaint(prompt, \
                        init_image=init_image, mask_image=mask, \
                        width=width, height=height, \
                        guidance_scale=7.5, num_inference_steps=20)
    image = output.images[0]
    return image

def _infer_dispatch(prompt, width, height, image_in, strength):
    """Route a click to img2img when a reference image was uploaded, else text2img.

    The original code did `isinstance(image_in, Image.Image)` at UI-build time,
    but `image_in` is then a gr.Image *component*, never a PIL image, so the
    check was always False and img2img was unreachable. The decision must be
    made per request, on the actual uploaded value.
    """
    if image_in is not None:
        return infer_img2img(prompt, width, height, image_in, strength)
    return infer_text2img(prompt, width, height)

with gr.Blocks() as demo:
    # Example Chinese poetry prompts for pure text-to-image generation.
    examples = [
                ["飞流直下三千尺, 疑是银河落九天, 瀑布, 插画"], 
                ["东临碣石, 以观沧海, 波涛汹涌, 插画"], 
                ["孤帆远影碧空尽,惟见长江天际流,油画"],
                ["女孩背影, 日落, 唯美插画"],
                ]
    with gr.Row():
        with gr.Column(scale=1, ):
            # Optional reference image: when supplied, generation switches to img2img.
            image_in = gr.Image(source='upload', elem_id="image_upload", type="pil", label="参考图")
            width = gr.Slider(256, 768, value = 512, step = 64, label = '宽度(width)')
            height = gr.Slider(256, 768, value = 512, step = 64, label = '高度(height)')
            strength = gr.Slider(0, 1.0, value = 0.8, step = 0.1, label = '参考图改变程度(strength)')
            prompt = gr.Textbox(label = '提示词(prompt)')
            submit_btn = gr.Button("生成图像(Generate)")

        with gr.Column(scale=1, ):
            image_out = gr.Image(label = '输出(output)')
            ex = gr.Examples(examples, fn=infer_text2img, inputs=[prompt, width, height], outputs=image_out)

        # Dispatch at request time: img2img only when the user uploaded a reference image.
        submit_btn.click(fn = _infer_dispatch, inputs = [prompt, width, height, image_in, strength], outputs = image_out)

demo.queue(concurrency_count=10).launch()