File size: 3,590 Bytes
d6ac3c5
 
 
 
 
 
 
 
 
 
16786bc
d6ac3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16786bc
d6ac3c5
 
 
 
 
 
 
 
 
 
 
16786bc
 
 
 
 
d6ac3c5
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import os
from src.utils import create_zip, update_dropdown
from src.rep_api import process_images, traning_function, update_replicate_api_key

def create_train_tab():
    """Build the "Model Trainer" tab and wire up its event handlers.

    Intended to be called inside an enclosing ``gr.Blocks`` (typically within
    ``gr.Tabs``) context, because it creates a ``gr.TabItem``.

    The tab has three stages:

    1. Image importing / auto-captioning: ``process_images`` takes the uploaded
       images, caption model, context text and token string. It fills the
       gallery, the editable captions textbox and the zip file output.
    2. Caption editing: "Update zip file" calls ``create_zip`` with the
       uploaded images, the (possibly hand-edited) captions and the token
       string, then refreshes the zip file output.
    3. Training on Replicate: "Train" calls ``traning_function`` with the zip
       file, training model, destination, seed, token string and max steps.
       Its two outputs are shown as training logs and the final result.

    Editing the Replicate API key field calls ``update_replicate_api_key`` and
    shows its result in the status box.
    """
    with gr.TabItem("Model Trainer"):
        gr.Markdown("# Image Importing & Auto captions")
        with gr.Row():
            input_images = gr.File(file_count="multiple", type="filepath", label="Upload Images")
            label_model = gr.Dropdown(
                ["None", "blip", "llava-16", "img2prompt"],
                value="None",
                label="Caption model",
                info="Auto caption model",
            )
            token_string = gr.Textbox(
                label="Token string",
                value="TOK",
                interactive=True,
                info="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well.",
            )
            context_text = gr.Textbox(
                label="Context Text",
                info="Context Text for auto caption",
                value=" I want a description caption for this image",
            )

            # Pre-filled from the environment so users who already exported
            # the token don't have to paste it again.
            replicate_api_key = gr.Textbox(
                label="Replicate API Key",
                info="API key for Replicate",
                value=os.environ.get("REPLICATE_API_TOKEN", ""),
                type="password",
            )
            api_key_status = gr.Textbox(label="API Key Status", interactive=False)

        with gr.Row():
            process_button = gr.Button("Process Images")

        with gr.Row():
            gr.Markdown("# Training Captions Data")
        with gr.Row():
            with gr.Column():
                image_output = gr.Gallery(type="pil", object_fit="fill")
            with gr.Column():
                # Editable so users can fix auto-generated captions before
                # rebuilding the zip with "Update zip file".
                text_output = gr.Textbox(label="Captions", interactive=True)
        with gr.Row():
            zip_output = gr.File(label="Zip file")
            btn_update_zip = gr.Button("Update zip file")

        with gr.Row():
            gr.Markdown("# Training on replicate")
        with gr.Row():
            # "flux" is the only supported choice. It is preselected so the
            # training function never receives an unset (None) model.
            training_model = gr.Dropdown(
                ["flux"],
                value="flux",
                label="Training model",
                info="Base model to fine-tune on Replicate",
            )
            training_destination = gr.Textbox(
                label="destination",
                info="add in replicate model destination, format [user]/[model_name]",
            )
            seed = gr.Number(
                label="Seed",
                value=42,
                info="Random seed integer for reproducible training. Leave empty to use a random seed.",
            )
            max_train_steps = gr.Number(
                label="max_train_steps",
                value=1000,
                info="Number of individual training steps. Takes precedence over num_train_epochs.",
            )

        with gr.Row():
            train_button = gr.Button("Train")
        with gr.Row():
            training_logs = gr.Textbox(label="Training logs")
            training_result = gr.Textbox(label="Training final")

        train_button.click(
            fn=traning_function,
            inputs=[zip_output, training_model, training_destination, seed, token_string, max_train_steps],
            outputs=[training_logs, training_result],
            queue=True,
        )

        process_button.click(
            fn=process_images,
            inputs=[input_images, label_model, context_text, token_string],
            outputs=[image_output, text_output, zip_output],
            queue=True,
        )

        btn_update_zip.click(
            fn=create_zip,
            inputs=[input_images, text_output, token_string],
            outputs=zip_output,
        )

        # TODO: hook `update_dropdown` to `training_result.change` once the
        # target style dropdown is reachable from this tab (it is not defined
        # in this function).

        replicate_api_key.change(
            fn=update_replicate_api_key,
            inputs=[replicate_api_key],
            outputs=[api_key_status],
        )