Files changed (1)
  1. convert.py +69 -68
convert.py CHANGED
@@ -1,69 +1,70 @@
- import gradio as gr
- import requests
- import os
- import shutil
- from pathlib import Path
- from typing import Any
- from tempfile import TemporaryDirectory
- from typing import Optional
-
- import torch
- from io import BytesIO
-
- from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
- from huggingface_hub.file_download import repo_folder_name
- from diffusers import StableDiffusionXLPipeline
- from transformers import CONFIG_MAPPING
-
-
- COMMIT_MESSAGE = " This PR adds fp32 and fp16 weights in safetensors format to {}"
-
-
- def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
-     progress(0, desc="Downloading model")
-     local_file = os.path.join(model_id, filename)
-     ckpt_file = local_file if os.path.isfile(local_file) else hf_hub_download(repo_id=model_id, filename=filename, token=token)
-
-     pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_file)
-
-     pipeline.save_pretrained(folder, safe_serialization=True)
-     pipeline = pipeline.to(torch_dtype=torch.float16)
-     pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")
-
-     return folder
-
-
- def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
-     try:
-         discussions = api.get_repo_discussions(repo_id=model_id)
-     except Exception:
-         return None
-     for discussion in discussions:
-         if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
-             details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
-             if details.target_branch == "refs/heads/main":
-                 return discussion
-
-
- def convert(token: str, model_id: str, filename: str, progress=gr.Progress()):
-     api = HfApi()
-
-     pr_title = "Adding `diffusers` weights of this model"
-
-     with TemporaryDirectory() as d:
-         folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
-         os.makedirs(folder)
-         new_pr = None
-         try:
-             folder = convert_single(model_id, filename, folder, progress, token)
-             progress(0.7, desc="Uploading to Hub")
-             new_pr = api.upload_folder(folder_path=folder, path_in_repo="./", repo_id=model_id, repo_type="model", token=token, commit_message=pr_title, commit_description=COMMIT_MESSAGE.format(model_id), create_pr=True)
-             pr_number = new_pr.split("%2F")[-1].split("/")[0]
-             link = f"Pr created at: {'https://huggingface.co/' + os.path.join(model_id, 'discussions', pr_number)}"
-             progress(1, desc="Done")
-         except Exception as e:
-             raise gr.exceptions.Error(str(e))
-         finally:
-             shutil.rmtree(folder)
-
+ import gradio as gr
+ import requests
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import Any
+ from tempfile import TemporaryDirectory
+ from typing import Optional
+
+ import torch
+ from io import BytesIO
+
+ from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
+ from huggingface_hub.file_download import repo_folder_name
+ from diffusers import StableDiffusionXLPipeline
+ from transformers import CONFIG_MAPPING
+
+
+ COMMIT_MESSAGE = " This PR adds fp32 and fp16 weights in safetensors format to {}"
+
+
+ def convert_single(model_id: str, filename: str, folder: str, progress: Any, token: str):
+     progress(0, desc="Downloading model")
+     local_file = os.path.join(model_id, filename)
+     ckpt_file = local_file if os.path.isfile(local_file) else hf_hub_download(repo_id=model_id, filename=filename, token=token)
+
+     pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_file)
+
+     pipeline.save_pretrained(folder, safe_serialization=True)
+     #pipeline = pipeline.to(torch_dtype=torch.float16)
+     pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_file, torch_dtype=torch.float16)
+     pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")
+
+     return folder
+
+
+ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
+     try:
+         discussions = api.get_repo_discussions(repo_id=model_id)
+     except Exception:
+         return None
+     for discussion in discussions:
+         if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
+             details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
+             if details.target_branch == "refs/heads/main":
+                 return discussion
+
+
+ def convert(token: str, model_id: str, filename: str, progress=gr.Progress()):
+     api = HfApi()
+
+     pr_title = "Adding `diffusers` weights of this model"
+
+     with TemporaryDirectory() as d:
+         folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
+         os.makedirs(folder)
+         new_pr = None
+         try:
+             folder = convert_single(model_id, filename, folder, progress, token)
+             progress(0.7, desc="Uploading to Hub")
+             new_pr = api.upload_folder(folder_path=folder, path_in_repo="./", repo_id=model_id, repo_type="model", token=token, commit_message=pr_title, commit_description=COMMIT_MESSAGE.format(model_id), create_pr=True)
+             pr_number = new_pr.split("%2F")[-1].split("/")[0]
+             link = f"Pr created at: {'https://huggingface.co/' + os.path.join(model_id, 'discussions', pr_number)}"
+             progress(1, desc="Done")
+         except Exception as e:
+             raise gr.exceptions.Error(str(e))
+         finally:
+             shutil.rmtree(folder)
+
      return link
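
For reference, the functional change is in how the fp16 weights are produced: instead of casting the already-loaded pipeline with pipeline.to(torch_dtype=torch.float16), the checkpoint is reloaded from the single file with torch_dtype=torch.float16 and then saved as the "fp16" variant next to the fp32 weights. Below is a minimal sketch of just that conversion step (not part of the PR), with placeholder paths in place of the Hub download and upload plumbing:

import torch
from diffusers import StableDiffusionXLPipeline

ckpt_file = "sd_xl_base_1.0.safetensors"  # placeholder: local single-file checkpoint
out_dir = "converted"                     # placeholder: output folder in diffusers layout

# fp32 weights: load the single-file checkpoint and save it in safetensors format.
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_file)
pipe.save_pretrained(out_dir, safe_serialization=True)

# fp16 variant: reload with torch_dtype=torch.float16 (as the new code does),
# then save alongside the fp32 weights under the "fp16" variant name.
pipe_fp16 = StableDiffusionXLPipeline.from_single_file(ckpt_file, torch_dtype=torch.float16)
pipe_fp16.save_pretrained(out_dir, safe_serialization=True, variant="fp16")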