# FAW-test-app / train_tab.py
# Source: Hugging Face file view — uploaded by JarvisLabs ("Upload 3 files"),
# commit 16786bc (verified), 3.59 kB, malware scan: no virus.
import gradio as gr
import os
from src.utils import create_zip, update_dropdown
from src.rep_api import process_images, traning_function, update_replicate_api_key
def create_train_tab():
    """Build the "Model Trainer" tab and wire up its event handlers.

    The tab has three sections:

    1. Image importing & auto captions: upload images, choose a caption
       model, and run ``process_images`` to produce captioned training data.
    2. Training captions data: review/edit the generated captions and
       rebuild the training zip with ``create_zip``.
    3. Training on Replicate: call ``traning_function`` with the zip,
       model, destination, seed, trigger token and step count.

    A Replicate API key field is also provided. Every change to it is passed
    to ``update_replicate_api_key`` and the result is shown in a read-only
    status box.

    Must be called inside an active ``gr.Blocks`` context, because Gradio
    event listeners cannot be attached outside one. The components are
    created for their effect on the enclosing layout; nothing is returned.
    """
    with gr.TabItem("Model Trainer"):
        # --- Section 1: image import and automatic captioning ---------------
        gr.Markdown("# Image Importing & Auto captions")
        with gr.Row():
            input_images = gr.File(file_count="multiple", type="filepath", label="Upload Images")
            label_model = gr.Dropdown(
                ["None", "blip", "llava-16", "img2prompt"],
                value="None",
                label="Caption model",
                info="Auto caption model",
            )
            token_string = gr.Textbox(
                label="Token string",
                value="TOK",
                interactive=True,
                info="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well.",
            )
            context_text = gr.Textbox(
                label="Context Text",
                info="Context Text for auto caption",
                value=" I want a description caption for this image",
            )
            replicate_api_key = gr.Textbox(
                label="Replicate API Key",
                info="API key for Replicate",
                # Pre-fill from the environment so an already-configured key is shown.
                value=os.environ.get("REPLICATE_API_TOKEN", ""),
                type="password",
            )
            api_key_status = gr.Textbox(label="API Key Status", interactive=False)
        with gr.Row():
            process_button = gr.Button("Process Images")

        # --- Section 2: review captions and build the training zip ---------
        with gr.Row():
            gr.Markdown("# Training Captions Data")
        with gr.Row():
            with gr.Column():
                image_output = gr.Gallery(type="pil", object_fit="fill")
            with gr.Column():
                # Editable so captions can be corrected before re-zipping.
                text_output = gr.Textbox(interactive=True)
        with gr.Row():
            zip_output = gr.File(label="Zip file")
            btn_update_zip = gr.Button("Update zip file")

        # --- Section 3: launch training on Replicate ------------------------
        with gr.Row():
            gr.Markdown("# Training on replicate")
        with gr.Row():
            # This selects the model to train (not a caption model). An explicit
            # default ensures the Train action never submits an empty selection.
            training_model = gr.Dropdown(
                ["flux"],
                value="flux",
                label="Training model",
                info="Model to train on Replicate",
            )
            training_destination = gr.Textbox(
                label="destination",
                info="add in replicate model destination, format [user]/[model_name]",
            )
            # precision=0 makes Gradio round the value and pass an int
            # (not a float such as 42.0) to the training function.
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                info="Random seed integer for reproducible training. Leave empty to use a random seed.",
            )
            max_train_steps = gr.Number(
                label="max_train_steps",
                value=1000,
                precision=0,
                info="Number of individual training steps. Takes precedence over num_train_epochs.",
            )
        with gr.Row():
            train_button = gr.Button("Train")
        with gr.Row():
            training_logs = gr.Textbox(label="Training logs")
            training_final = gr.Textbox(label="Training final")

        # --- Event wiring ---------------------------------------------------
        train_button.click(
            fn=traning_function,
            inputs=[zip_output, training_model, training_destination, seed, token_string, max_train_steps],
            outputs=[training_logs, training_final],
            queue=True,
        )
        process_button.click(
            fn=process_images,
            inputs=[input_images, label_model, context_text, token_string],
            outputs=[image_output, text_output, zip_output],
            queue=True,
        )
        # Re-zip from the uploaded images and the current (possibly edited) captions.
        btn_update_zip.click(
            fn=create_zip,
            inputs=[input_images, text_output, token_string],
            outputs=zip_output,
        )
        # TODO: restore a training_final.change -> update_dropdown hook once its
        # target dropdown (named `style_mode` in an earlier draft) is in scope here.
        replicate_api_key.change(
            fn=update_replicate_api_key,
            inputs=[replicate_api_key],
            outputs=[api_key_status],
        )