Beeniebeen committed on
Commit c0ff78c • 1 Parent(s): 434b4b7

Update app.py

Files changed (1)
  1. app.py +25 -98
app.py CHANGED
@@ -2,20 +2,24 @@ import gradio as gr
 import torch
 from diffusers import StableDiffusionPipeline
 from safetensors.torch import load_file
+import os
+
 
 # function to load a safetensors model
-def load_safetensors_model(model_id, model_file):
+def load_safetensors_model(model_id, model_file, use_auth_token):
     # load the safetensors file
     state_dict = load_file(model_file)
-    # load the model
-    model = StableDiffusionPipeline.from_pretrained(model_id, state_dict=state_dict)
+    # initialize the model
+    model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=use_auth_token)
+    # load the state_dict into the model
+    model.model.load_state_dict(state_dict)
     return model
 
 model_id = "gagong/Traditional-Korean-Painting-Model-v2.0"
-model_file = "Traditional_Korean_Painting_Model_2.safetensors"
+model_file = "./Traditional_Korean_Painting_Model_2.safetensors"  # path to the file in the root directory
 
 try:
-    pipe = load_safetensors_model(model_id, model_file)
+    pipe = load_safetensors_model(model_id, model_file, use_auth_token=HUGGINGFACE_TOKEN)
     pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 except Exception as e:
     print(f"Error loading model: {e}")
@@ -27,97 +31,20 @@ def generate_image(prompt):
     image = pipe(prompt).images[0]
     return image
 
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 520px;
-}
-"""
-
-if torch.cuda.is_available():
-    power_device = "GPU"
-else:
-    power_device = "CPU"
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""
-        # Text-to-Image Gradio Template
-        Currently running on {power_device}.
-        """)
-
-        with gr.Row():
-            prompt = gr.Textbox(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-            run_button = gr.Button("Run", scale=0)
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Textbox(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt"
-            )
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
-                )
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
-                )
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=7.5,
-                )
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=25,
-                )
-
-        gr.Examples(
-            examples=examples,
-            inputs=[prompt]
-        )
-
-        run_button.click(
-            fn=infer,
-            inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-            outputs=[result]
-        )
-
-demo.queue().launch()
+# Gradio interface setup
+with gr.Blocks() as demo:
+    gr.Markdown("# Traditional Korean Painting Generator")
+    gr.Markdown("Enter a prompt to generate a traditional Korean painting.")
+
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt", placeholder="Describe the scene...")
+            generate_btn = gr.Button("Generate")
+        with gr.Column():
+            output_image = gr.Image(label="Generated Image")
+
+    generate_btn.click(fn=generate_image, inputs=prompt, outputs=output_image)
+
+if __name__ == "__main__":
+    demo.launch()
 
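As committed, the loader still has two problems: StableDiffusionPipeline exposes no `.model` attribute, so `model.model.load_state_dict(state_dict)` raises an AttributeError, and `HUGGINGFACE_TOKEN` is referenced without ever being defined (the new `import os` is also unused). The removed version had the mirror-image bug of passing `state_dict` to `from_pretrained`, which that method does not support. A minimal sketch of one way to make the load work, assuming the `.safetensors` file is a complete Stable Diffusion checkpoint, that the same file is published in the model repo under the same name, and that the token lives in a Space secret named `HF_TOKEN` (all three are assumptions, not part of the commit):

import os

import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_download

# Local checkpoint path from the commit; if the file is not in the Space,
# fall back to downloading it from the model repo, reading the token from
# the environment (assumed Space secret HF_TOKEN) instead of an undefined name.
model_file = "./Traditional_Korean_Painting_Model_2.safetensors"
if not os.path.exists(model_file):
    model_file = hf_hub_download(
        repo_id="gagong/Traditional-Korean-Painting-Model-v2.0",
        filename="Traditional_Korean_Painting_Model_2.safetensors",
        token=os.environ.get("HF_TOKEN"),
    )

try:
    # from_single_file builds the entire pipeline from one .safetensors
    # checkpoint (diffusers >= 0.18), so no manual state_dict loading is needed.
    pipe = StableDiffusionPipeline.from_single_file(model_file)
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
except Exception as e:
    print(f"Error loading model: {e}")

If the weights are instead a partial fine-tune meant to overlay the base model, the state_dict would have to be applied to the matching component (pipe.unet, pipe.text_encoder, ...) rather than to a nonexistent pipe.model.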
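The commit also drops the template's advanced controls (negative prompt, seed, size, guidance scale, steps) along with the `infer`, `MAX_SEED`, and `MAX_IMAGE_SIZE` names that were referenced but never defined in the removed code shown. If those knobs are wanted back without restoring the whole template, the pipeline call already accepts them directly; a hedged sketch, reusing the module-level `pipe` from app.py, with defaults mirroring the removed sliders:

import torch

def generate_image(prompt, negative_prompt="", seed=0,
                   guidance_scale=7.5, num_inference_steps=25):
    # Seed a generator on the pipeline's device for reproducible output.
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    image = pipe(
        prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    return image

The extra parameters would then need matching Gradio components wired into the `generate_btn.click(...)` inputs list.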