import os
import torch
import gradio as gr
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline

torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

# Load the pipeline
model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(torch_device)
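
# Note (assumption): float32 keeps the pipeline compatible with CPU and MPS.
# On a CUDA GPU you could likely reduce memory use by passing
# torch_dtype=torch.float16 above, but that is not done here.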

# Load textual inversions
sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")

# Map UI style names to their textual inversion tokens
style_token_dict = {
    "Illustration Style": '<illustration-style>',
    "Line Art": '<line-art>',
    "Hitokomoru Style": '<hitokomoru-style-nao>',
    "Marc Allante": '<Marc_Allante>',
    "Midjourney": '<midjourney-style>',
    "Hanfu Anime": '<hanfu-anime-style>',
    "Birb Style": '<birb-style>'
}


# Helper utilities for working with latents directly (useful when building a
# custom, guidance-aware sampling loop).
def set_timesteps(scheduler, num_inference_steps):
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)

def pil_to_latent(input_im):
    # Encode a PIL image to VAE latents; pixels are scaled from [0, 1] to [-1, 1]
    with torch.no_grad():
        latent = sd_pipeline.vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)
    return 0.18215 * latent.latent_dist.sample()  # 0.18215 is SD v1's latent scaling factor

def latents_to_pil(latents):
    # Decode VAE latents back to a list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = sd_pipeline.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

def generate_with_pipeline(prompt, num_inference_steps, guidance_scale, seed):
    generator = torch.Generator(device=torch_device).manual_seed(seed)
    image = sd_pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]
    return image

def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
    prompt = text + " " + style_token_dict[style]

    # Generate the un-guided image with the standard pipeline
    image_pipeline = generate_with_pipeline(prompt, inference_step, guidance_scale, seed)

    # The guided image requires a custom sampling loop that applies guidance_method
    # and loss_scale at each denoising step (see the sketch below). Until that is
    # implemented, the un-guided image is returned as a placeholder.
    image_guide = image_pipeline

    return image_pipeline, image_guide
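
# --- Sketch: a possible implementation of the guided branch (not wired in). ---
# A minimal, untested illustration of loss-guided sampling, assuming a custom
# denoising loop with an LMS-style scheduler where the per-step sigma is known.
# At each step the predicted clean latents are decoded, a pixel-space loss is
# computed (here: distance from grayscale, matching the "Grayscale" option),
# and the latents are nudged along the negative gradient scaled by loss_scale.

def grayscale_loss(images):
    # images: (B, 3, H, W) in [0, 1]; penalise deviation from the per-pixel mean
    gray = images.mean(dim=1, keepdim=True)
    return F.mse_loss(images, gray.expand_as(images))

def guide_latents(latents, noise_pred, sigma, loss_fn, loss_scale):
    # Apply one step of loss guidance to the latents inside the sampling loop.
    latents = latents.detach().requires_grad_()
    pred_x0 = latents - sigma * noise_pred                     # predicted denoised latents
    decoded = sd_pipeline.vae.decode((1 / 0.18215) * pred_x0).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)                  # map to [0, 1] pixel space
    loss = loss_fn(decoded) * loss_scale
    grad = torch.autograd.grad(loss, latents)[0]
    return latents.detach() - grad * sigma ** 2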

title = "Generative with Textual Inversion"
description = "A simple Gradio interface to infer Stable Diffusion and generate images with different art styles"
examples = [
    ["A majestic castle on a floating island", 'Illustration Style', 20, 7.5, 42, 'Grayscale', 200],
    ["A cyberpunk cityscape at night", 'Midjourney', 25, 8.0, 123, 'Contrast', 300]
]

demo = gr.Interface(inference,
                    inputs=[gr.Textbox(label="Prompt", type="text"),
                            gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"),
                            gr.Slider(1, 50, 10, step=1, label="Inference steps"),
                            gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale"),
                            gr.Slider(0, 10000, 42, step=1, label="Seed"),
                            gr.Dropdown(label="Guidance method",
                                        choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'],
                                        value="Grayscale"),
                            gr.Slider(100, 10000, 200, step=100, label="Loss scale")],
                    outputs=[gr.Image(width=512, height=512, label="Generated art"),
                             gr.Image(width=512, height=512, label="Generated art with guidance")],
                    title=title,
                    description=description,
                    examples=examples)

demo.launch()