Beeniebeen commited on
Commit
1c8941e
โ€ข
1 Parent(s): c0ff78c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -4,7 +4,6 @@ from diffusers import StableDiffusionPipeline
4
  from safetensors.torch import load_file
5
  import os
6
 
7
-
8
  # safetensors ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜
9
  def load_safetensors_model(model_id, model_file, use_auth_token):
10
  # safetensors ํŒŒ์ผ ๋กœ๋“œ
@@ -15,19 +14,21 @@ def load_safetensors_model(model_id, model_file, use_auth_token):
15
  model.model.load_state_dict(state_dict)
16
  return model
17
 
18
- model_id = "gagong/Traditional-Korean-Painting-Model-v2.0"
 
19
  model_file = "./Traditional_Korean_Painting_Model_2.safetensors" # ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ์— ์žˆ๋Š” ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ •
20
 
21
  try:
22
  pipe = load_safetensors_model(model_id, model_file, use_auth_token=HUGGINGFACE_TOKEN)
23
  pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 
24
  except Exception as e:
25
  print(f"Error loading model: {e}")
26
  pipe = None
27
 
28
  def generate_image(prompt):
29
  if not pipe:
30
- return None
31
  image = pipe(prompt).images[0]
32
  return image
33
 
 
4
  from safetensors.torch import load_file
5
  import os
6
 
 
7
  # safetensors ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜
8
  def load_safetensors_model(model_id, model_file, use_auth_token):
9
  # safetensors ํŒŒ์ผ ๋กœ๋“œ
 
14
  model.model.load_state_dict(state_dict)
15
  return model
16
 
17
+ # ์˜ฌ๋ฐ”๋ฅธ ๋ชจ๋ธ ID๋กœ ์„ค์ •
18
+ model_id = "stabilityai/stable-diffusion-2"
19
  model_file = "./Traditional_Korean_Painting_Model_2.safetensors" # ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ์— ์žˆ๋Š” ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ •
20
 
21
  try:
22
  pipe = load_safetensors_model(model_id, model_file, use_auth_token=HUGGINGFACE_TOKEN)
23
  pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
24
+ print("Model loaded successfully")
25
  except Exception as e:
26
  print(f"Error loading model: {e}")
27
  pipe = None
28
 
29
  def generate_image(prompt):
30
  if not pipe:
31
+ return "Model not loaded properly"
32
  image = pipe(prompt).images[0]
33
  return image
34