Beeniebeen committed on
Commit
e5ea0bc
โ€ข
1 Parent(s): 1c8941e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -4,33 +4,43 @@ from diffusers import StableDiffusionPipeline
4
  from safetensors.torch import load_file
5
  import os
6
 
 
 
 
7
  # safetensors ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜
8
  def load_safetensors_model(model_id, model_file, use_auth_token):
9
- # safetensors ํŒŒ์ผ ๋กœ๋“œ
10
- state_dict = load_file(model_file)
11
- # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
12
- model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=use_auth_token)
13
- # ๋ชจ๋ธ์— state_dict ๋กœ๋“œ
14
- model.model.load_state_dict(state_dict)
15
- return model
 
 
 
 
16
 
17
  # ์˜ฌ๋ฐ”๋ฅธ ๋ชจ๋ธ ID๋กœ ์„ค์ •
18
  model_id = "stabilityai/stable-diffusion-2"
19
  model_file = "./Traditional_Korean_Painting_Model_2.safetensors" # ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ์— ์žˆ๋Š” ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ •
20
 
21
- try:
22
- pipe = load_safetensors_model(model_id, model_file, use_auth_token=HUGGINGFACE_TOKEN)
23
  pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
24
  print("Model loaded successfully")
25
- except Exception as e:
26
- print(f"Error loading model: {e}")
27
- pipe = None
28
 
29
  def generate_image(prompt):
30
  if not pipe:
31
  return "Model not loaded properly"
32
- image = pipe(prompt).images[0]
33
- return image
 
 
 
 
34
 
35
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
36
  with gr.Blocks() as demo:
@@ -42,10 +52,9 @@ with gr.Blocks() as demo:
42
  prompt = gr.Textbox(label="Prompt", placeholder="Describe the scene...")
43
  generate_btn = gr.Button("Generate")
44
  with gr.Column():
45
- output_image = gr.Image(label="Generated Image")
46
-
47
  generate_btn.click(fn=generate_image, inputs=prompt, outputs=output_image)
48
 
49
  if __name__ == "__main__":
50
  demo.launch()
51
-
 
4
  from safetensors.torch import load_file
5
  import os
6
 
7
# Read the Hugging Face access token from the environment (None if unset).
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
10
# Loader for a fine-tuned safetensors checkpoint on top of a base SD pipeline.
def load_safetensors_model(model_id, model_file, use_auth_token):
    """Build a StableDiffusionPipeline and overlay weights from a .safetensors file.

    Args:
        model_id: Hub id of the base pipeline (e.g. "stabilityai/stable-diffusion-2").
        model_file: Path to the fine-tuned .safetensors state dict.
        use_auth_token: Hugging Face token (or None) forwarded to from_pretrained.

    Returns:
        The pipeline with the weights applied, or None if loading failed.
    """
    try:
        # Read the raw tensors from disk.
        state_dict = load_file(model_file)
        # Instantiate the base pipeline from the Hub.
        model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=use_auth_token)
        # BUG FIX: StableDiffusionPipeline has no `.model` attribute, so the
        # original `model.model.load_state_dict(...)` always raised and this
        # function always returned None. The denoising weights live on `unet`.
        # NOTE(review): if the file is a full single-file checkpoint rather than
        # UNet-only weights, StableDiffusionPipeline.from_single_file(model_file)
        # is the appropriate loader — confirm which format this file uses.
        model.unet.load_state_dict(state_dict)
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
23
 
24
  # ์˜ฌ๋ฐ”๋ฅธ ๋ชจ๋ธ ID๋กœ ์„ค์ •
25
  model_id = "stabilityai/stable-diffusion-2"
26
  model_file = "./Traditional_Korean_Painting_Model_2.safetensors" # ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ์— ์žˆ๋Š” ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ •
27
 
28
+ pipe = load_safetensors_model(model_id, model_file, use_auth_token=HUGGINGFACE_TOKEN)
29
+ if pipe:
30
  pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
31
  print("Model loaded successfully")
32
+ else:
33
+ print("Failed to load the model")
 
34
 
35
def generate_image(prompt):
    """Run the module-level pipeline on *prompt*.

    Returns the first generated image, or an error string when the pipeline
    is unavailable or inference fails (Gradio displays the string as-is).
    """
    if not pipe:
        return "Model not loaded properly"
    try:
        result = pipe(prompt)
        return result.images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        return "Error generating image"
44
 
45
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
46
  with gr.Blocks() as demo:
 
52
  prompt = gr.Textbox(label="Prompt", placeholder="Describe the scene...")
53
  generate_btn = gr.Button("Generate")
54
  with gr.Column():
55
+ output_image = gr.Image(label="Generated Image", type="pil")
56
+
57
  generate_btn.click(fn=generate_image, inputs=prompt, outputs=output_image)
58
 
59
  if __name__ == "__main__":
60
  demo.launch()