devlim commited on
Commit
c67f090
โ€ข
1 Parent(s): e5ea0bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -28
app.py CHANGED
@@ -1,36 +1,15 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionPipeline
4
- from safetensors.torch import load_file
5
- import os
6
 
7
- # Hugging Face ํ† ํฐ ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ ๊ฐ€์ ธ์˜ค๊ธฐ
8
- HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
 
9
 
10
- # safetensors ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜
11
- def load_safetensors_model(model_id, model_file, use_auth_token):
12
- try:
13
- # safetensors ํŒŒ์ผ ๋กœ๋“œ
14
- state_dict = load_file(model_file)
15
- # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
16
- model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=use_auth_token)
17
- # ๋ชจ๋ธ์— state_dict ๋กœ๋“œ
18
- model.model.load_state_dict(state_dict)
19
- return model
20
- except Exception as e:
21
- print(f"Error loading model: {e}")
22
- return None
23
-
24
- # ์˜ฌ๋ฐ”๋ฅธ ๋ชจ๋ธ ID๋กœ ์„ค์ •
25
- model_id = "stabilityai/stable-diffusion-2"
26
- model_file = "./Traditional_Korean_Painting_Model_2.safetensors" # ๋ฃจํŠธ ๋””๋ ‰ํ† ๋ฆฌ์— ์žˆ๋Š” ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ •
27
 
28
- pipe = load_safetensors_model(model_id, model_file, use_auth_token=HUGGINGFACE_TOKEN)
29
- if pipe:
30
- pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
31
- print("Model loaded successfully")
32
- else:
33
- print("Failed to load the model")
34
 
35
  def generate_image(prompt):
36
  if not pipe:
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
 
 
4
 
5
+ model_id = "gagong/Traditional-Korean-Painting-Model-v2.0"
6
+ scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
7
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
8
+ pipe = pipe.to("cuda")
9
 
10
+ # prompt = "a photo of an astronaut riding a horse on mars"
11
+ # image = pipe(prompt).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
 
 
 
13
 
14
  def generate_image(prompt):
15
  if not pipe: