DemiPoto commited on
Commit
bd02d5f
1 Parent(s): a98a8d4

Update all_models2.py

Browse files
Files changed (1) hide show
  1. all_models2.py +26 -0
all_models2.py CHANGED
@@ -33,6 +33,32 @@ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="l
33
  if len(models) == limit: break
34
  return models , models_plus_tags
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def process_tags(tags,bad_tags=bad_tags):
37
  t1=True
38
  new_tags=[]
 
33
  if len(models) == limit: break
34
  return models , models_plus_tags
35
 
36
+ def find_warm_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False, bad_models=bad_models):
37
+ from huggingface_hub import HfApi
38
+ api = HfApi()
39
+ default_tags = ["diffusers"]
40
+ if not sort: sort = "last_modified"
41
+ models = []
42
+ models_plus_tags=[]
43
+ try:
44
+ model_infos = api.list_models(author=author, task="text-to-image",
45
+ tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit, inference ="warm")
46
+ except Exception as e:
47
+ print(f"Error: Failed to list models.")
48
+ print(e)
49
+ return models
50
+ for model in model_infos:
51
+
52
+ if not model.private and not model.gated:
53
+ loadable = True
54
+ if not_tag and not_tag in model.tags or not loadable: continue
55
+ if model.id not in bad_models :
56
+ models.append(model.id)
57
+ #models_plus_tags.append([model.id,process_tags(model.tags)])
58
+ models_plus_tags.append([model.id,process_tags(model.cardData.tags)])
59
+ if len(models) == limit: break
60
+ return models , models_plus_tags
61
+
62
  def process_tags(tags,bad_tags=bad_tags):
63
  t1=True
64
  new_tags=[]