John6666 committed on
Commit a6bc924
1 Parent(s): 3bf56fd

Upload 13 files

Files changed (4):
  1. app.py +9 -3
  2. requirements.txt +1 -1
  3. tagger/fl2sd3longcap.py +4 -2
  4. tagger/utils.py +5 -0
app.py CHANGED
@@ -21,6 +21,7 @@ from tagger.tagger import (
     remove_specific_prompt,
     convert_danbooru_to_e621_prompt,
     insert_recom_prompt,
+    compose_prompt_to_copy,
 )
 from tagger.fl2sd3longcap import predict_tags_fl2_sd3
 from tagger.v2 import (
@@ -40,12 +41,12 @@ load_models(models, 5)
 
 
 css = """
-#model_info { text-align: center; display:block; }
+#model_info { text-align: center; }
 """
 
 with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
     with gr.Column():
-        with gr.Accordion("Advanced settings", open=True):
+        with gr.Accordion("Advanced settings", open=False):
             with gr.Accordion("Recommended Prompt", open=False):
                 recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
                 with gr.Row():
@@ -61,6 +62,7 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
                 v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
                 v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
                 v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
+                v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
         with gr.Accordion("Model", open=True):
             model_name = gr.Dropdown(label="Select Model", show_label=False, choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
             model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_id="model_info")
@@ -142,9 +144,11 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
         convert_danbooru_to_e621_prompt, [prompt, v2_tag_type], [prompt], queue=False, show_api=False,
     )
     tagger_generate_from_image.click(
+        lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
+    ).success(
         predict_tags_wd,
         [tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
-        [v2_series, v2_character, prompt, gr.Button(visible=False)],
+        [v2_series, v2_character, prompt, v2_copy],
         show_api=False,
     ).success(
         predict_tags_fl2_sd3, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
@@ -154,6 +158,8 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
         convert_danbooru_to_e621_prompt, [prompt, tagger_tag_type], [prompt], queue=False, show_api=False,
     ).success(
         insert_recom_prompt, [prompt, neg_prompt, tagger_recom_prompt], [prompt, neg_prompt], queue=False, show_api=False,
+    ).success(
+        compose_prompt_to_copy, [v2_character, v2_series, prompt], [prompt], queue=False, show_api=False,
     )
 
 demo.queue()
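The app.py changes wire a new "Copy to clipboard" button (v2_copy) into the tagger flow: the click chain first clears the previous outputs, then runs predict_tags_wd with the button as an extra output (so it can be re-enabled), and finally composes the copyable prompt via compose_prompt_to_copy. Below is a minimal, self-contained sketch of the same clear-then-predict chaining pattern in plain Gradio; the tag_image helper and the component names are illustrative stand-ins, not this Space's code.

# Sketch of the .click(...).success(...) chaining pattern used above (plain Gradio).
import gradio as gr

def tag_image(image, prompt):
    # Stand-in for predict_tags_wd: returns series/character tags, an updated
    # prompt, and an update that re-enables the copy button.
    tags = "1girl, solo"
    new_prompt = ", ".join(x for x in [prompt, tags] if x)
    return "series", "character", new_prompt, gr.update(interactive=True)

with gr.Blocks() as demo:
    image = gr.Image(type="pil")
    prompt = gr.Textbox(label="Prompt")
    series = gr.Textbox(label="Series")
    character = gr.Textbox(label="Character")
    copy_btn = gr.Button("Copy to clipboard", interactive=False)
    run_btn = gr.Button("Generate tags from image")

    run_btn.click(
        lambda: ("", "", ""), None, [series, character, prompt], queue=False,  # clear previous outputs first
    ).success(
        tag_image, [image, prompt], [series, character, prompt, copy_btn],     # then run the tagger
    )

demo.launch()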
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 huggingface_hub
-torch
+torch==2.2.0
 torchvision
 accelerate
 transformers
tagger/fl2sd3longcap.py CHANGED
@@ -2,11 +2,13 @@ from transformers import AutoProcessor, AutoModelForCausalLM
2
  import spaces
3
  import re
4
  from PIL import Image
 
5
 
6
  import subprocess
7
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
8
 
9
- fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).eval()
 
10
  fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
11
 
12
 
@@ -48,7 +50,7 @@ def fl_run_example(image):
48
  if image.mode != "RGB":
49
  image = image.convert("RGB")
50
 
51
- inputs = fl_processor(text=prompt, images=image, return_tensors="pt")
52
  generated_ids = fl_model.generate(
53
  input_ids=inputs["input_ids"],
54
  pixel_values=inputs["pixel_values"],
 
2
  import spaces
3
  import re
4
  from PIL import Image
5
+ import torch
6
 
7
  import subprocess
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
12
  fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
13
 
14
 
 
50
  if image.mode != "RGB":
51
  image = image.convert("RGB")
52
 
53
+ inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
54
  generated_ids = fl_model.generate(
55
  input_ids=inputs["input_ids"],
56
  pixel_values=inputs["pixel_values"],
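The fl2sd3longcap.py change makes the Florence-2 captioner device-aware: the model is moved to CUDA when available, and the processor's batch is moved to the same device before generate(). Below is a standalone sketch of that pattern, assuming the transformers AutoModelForCausalLM/AutoProcessor pair from the diff; the task prompt and generation arguments are placeholders, since the surrounding fl_run_example code is not part of this commit.

# Device-placement sketch for the Florence-2 SD3 captioner checkpoint.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same checkpoint as the diff; trust_remote_code pulls in the Florence-2 model code.
model = AutoModelForCausalLM.from_pretrained(
    "gokaygokay/Florence-2-SD3-Captioner", trust_remote_code=True
).to(device).eval()
processor = AutoProcessor.from_pretrained(
    "gokaygokay/Florence-2-SD3-Captioner", trust_remote_code=True
)

def caption(image: Image.Image, prompt: str) -> str:
    if image.mode != "RGB":
        image = image.convert("RGB")
    # .to(device) moves input_ids and pixel_values together,
    # so the tensors end up on the same device as the model.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=256,  # placeholder; the Space's own generation settings aren't shown in this diff
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]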
tagger/utils.py CHANGED
@@ -43,3 +43,8 @@ COPY_ACTION_JS = """\
43
  navigator.clipboard.writeText(inputs);
44
  }
45
  }"""
 
 
 
 
 
 
43
  navigator.clipboard.writeText(inputs);
44
  }
45
  }"""
46
+
47
+
48
+ def gradio_copy_prompt(prompt: str):
49
+ gr.Info("Copied!")
50
+ return prompt
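tagger/utils.py gains gradio_copy_prompt, which fires a gr.Info toast and echoes the prompt back unchanged. A plausible wiring for the new v2_copy button is sketched below, assuming Gradio 4's js= hook on .click(); the COPY_ACTION_JS body here is a minimal stand-in (only its tail appears in the diff), and the actual hookup in app.py is not part of this commit.

# Hypothetical wiring of the copy button: browser-side clipboard write plus server-side toast.
import gradio as gr

COPY_ACTION_JS = """\
(inputs) => {
  if (inputs.trim() !== "") {
    navigator.clipboard.writeText(inputs);
  }
}"""

def gradio_copy_prompt(prompt: str):
    gr.Info("Copied!")   # toast shown from the Python side
    return prompt        # echo the text back unchanged

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    copy_btn = gr.Button("Copy to clipboard")
    # js= runs in the browser (clipboard access) before fn= runs on the server (toast).
    copy_btn.click(gradio_copy_prompt, [prompt], [prompt], queue=False, js=COPY_ACTION_JS)

demo.launch()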