File size: 5,441 Bytes
238cf85
 
 
afcd47e
 
238cf85
 
 
 
 
d488098
238cf85
 
 
d488098
238cf85
 
 
 
 
fc5c76a
356456f
910a5d8
238cf85
 
d1dbfbe
238cf85
 
 
345ab15
 
 
 
 
 
 
238cf85
 
 
 
a187191
6af3625
 
 
a187191
 
345ab15
238cf85
 
 
 
 
 
 
 
 
 
 
 
afcd47e
238cf85
 
afcd47e
238cf85
 
afcd47e
238cf85
afcd47e
345ab15
 
 
 
 
 
e073e66
 
345ab15
a187191
345ab15
cd3e4b1
910a5d8
 
0a363ad
910a5d8
 
 
345ab15
fc5c76a
910a5d8
 
fc5c76a
910a5d8
 
 
345ab15
c5f68e0
afcd47e
 
46ab2e4
238cf85
 
afcd47e
345ab15
8f9829d
345ab15
 
 
 
e073e66
 
345ab15
7bb5d75
a187191
afcd47e
 
 
 
 
 
345ab15
afcd47e
 
 
 
238cf85
afcd47e
 
 
 
 
 
 
238cf85
afcd47e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a187191
afcd47e
 
 
 
 
a187191
afcd47e
a187191
afcd47e
a187191
 
345ab15
 
a187191
238cf85
 
345ab15
 
 
238cf85
 
910a5d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch

# Pick the compute device once; the pipeline is moved to it below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Keyword arguments for loading the model. On a CUDA device the float16
# ("fp16" variant) weights are requested; otherwise the library defaults
# are used.
_load_kwargs = {"use_safetensors": True}
if device == "cuda":
    _load_kwargs.update(torch_dtype=torch.float16, variant="fp16")

# NOTE: the original code also called torch.cuda.max_memory_allocated() here
# and discarded the result. That function only reads a statistic, so the
# call had no effect and has been dropped.
pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", **_load_kwargs)
if device == "cuda":
    # Only enabled on the CUDA path, exactly as before.
    pipe.enable_xformers_memory_efficient_attention()
pipe = pipe.to(device)

# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders.
MAX_IMAGE_SIZE = 1024

def infer(prompt_part1, color, dress_type, design, prompt_part5, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    """Build the garment prompt from its parts and generate one image.

    The final prompt has the form
    ``"<prompt_part1> <color> colored plain <dress_type> with <design> design, <prompt_part5>"``.

    Args:
        prompt_part1: Leading text of the prompt.
        color: Colour text inserted into the prompt.
        dress_type: Garment type inserted into the prompt.
        design: Design description inserted into the prompt.
        prompt_part5: Trailing text of the prompt.
        negative_prompt: Passed through to the pipeline unchanged.
        seed: Seed for the random generator; ignored when
            ``randomize_seed`` is truthy.
        randomize_seed: When truthy, a seed is drawn uniformly from
            ``[0, MAX_SEED]`` instead of using ``seed``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        num_inference_steps: Forwarded to the pipeline.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    full_prompt = (
        f"{prompt_part1} {color} colored plain {dress_type} "
        f"with {design} design, {prompt_part5}"
    )

    # Note: the seed actually used is not reported back to the caller.
    chosen_seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    rng = torch.Generator()
    rng.manual_seed(chosen_seed)

    pipeline_kwargs = dict(
        prompt=full_prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
    )
    output = pipe(**pipeline_kwargs)
    return output.images[0]

# Label shown in the page header describing the hardware in use.
power_device = "GPU" if torch.cuda.is_available() else "CPU"

# Page-level CSS: centre the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Sample inputs offered in the UI; each entry is a comma-separated
# "color, dress type, design" triple.
examples = [
    "red, t-shirt, yellow stripes",
    "blue, hoodie, minimalist",
    "red, sweat shirt, geometric design",
]

# Build the UI: a single centred column with the prompt inputs, the result
# image, an "Advanced Settings" accordion and clickable examples.
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Text-to-Image Gradio Template
        Currently running on {power_device}.
        """)
        
        with gr.Row():
            
            # Fixed prompt prefix; hidden and non-interactive, but still passed
            # to infer() as its first argument.
            prompt_part1 = gr.Textbox(
                value="a single", 
                label="Prompt Part 1",
                show_label=False,
                interactive=False,
                container=False,
                elem_id="prompt_part1",
                visible=False,
            )
            
            # User-editable prompt pieces: color, garment type and design.
            prompt_part2 = gr.Textbox(
                label="color",
                show_label=False,
                max_lines=1,
                placeholder="color (e.g., color category)",
                container=False,
            )
            
            prompt_part3 = gr.Textbox(
                label="dress_type",
                show_label=False,
                max_lines=1,
                placeholder="dress_type (e.g., t-shirt, sweatshirt, shirt, hoodie)",
                container=False,
            )
            
            prompt_part4 = gr.Textbox(
                label="design",
                show_label=False,
                max_lines=1,
                placeholder="design",
                container=False,
            )
            
            # Fixed prompt suffix; hidden and non-interactive like the prefix.
            prompt_part5 = gr.Textbox(
                value="hanging on the plain wall", 
                label="Prompt Part 5",
                show_label=False,
                interactive=False,
                container=False,
                elem_id="prompt_part5",
                visible=False,
            )
            
            
            run_button = gr.Button("Run", scale=0)
        
        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            
            # Hidden from the user, but its value is still forwarded to infer().
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )
            
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            # When checked, infer() ignores the slider and draws a random seed.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            with gr.Row():
                
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
            
            with gr.Row():
                
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,
                )
                
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=12,
                    step=1,
                    value=2,
                )
        
        # Each example is a "color, dress_type, design" string. Split it into
        # its three fields so that clicking an example fills the three
        # textboxes separately, instead of putting the whole string into the
        # color box. maxsplit=2 caps the result at three fields.
        example_rows = [
            [field.strip() for field in example.split(",", 2)]
            for example in examples
        ]
        gr.Examples(
            examples=example_rows,
            inputs=[prompt_part2, prompt_part3, prompt_part4],
        )

    # The order of `inputs` must match the positional parameters of infer().
    run_button.click(
        fn=infer,
        inputs=[prompt_part1, prompt_part2, prompt_part3, prompt_part4, prompt_part5, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result]
    )

# Enable request queuing and start the app.
demo.queue().launch()