# Spaces: Sleeping  (page-scrape residue; not part of the module's code)
import gradio as gr | |
import os | |
from src.utils import create_zip, update_dropdown | |
from src.rep_api import process_images, traning_function, update_replicate_api_key | |
def create_train_tab() -> None:
    """Build the "Model Trainer" tab and wire up its event handlers.

    The tab is laid out in three sections:

    1. "Image Importing & Auto captions" -- image upload, caption-model
       choice, token string, caption context text and the Replicate API key
       (pre-filled from the ``REPLICATE_API_TOKEN`` environment variable).
    2. "Training Captions Data" -- a gallery and an editable text box fed by
       ``process_images``, plus a zip file that can be rebuilt on demand via
       ``create_zip``.
    3. "Training on replicate" -- training parameters and a button that calls
       ``traning_function`` with the current zip file and settings.

    The function creates components in the currently active Gradio context
    and returns nothing.
    # NOTE(review): presumably called inside a ``gr.Blocks``/``gr.Tabs``
    # context by the app entry point -- confirm against the caller.
    """
    with gr.TabItem("Model Trainer"):
        # NOTE(review): the row/column grouping below was inferred from
        # statement order; confirm it matches the intended layout.

        # --- Section 1: upload + captioning options -----------------------
        gr.Markdown("# Image Importing & Auto captions")
        with gr.Row():
            input_images = gr.File(
                file_count="multiple",
                type="filepath",
                label="Upload Images",
            )
            label_model = gr.Dropdown(
                ["None", "blip", "llava-16", "img2prompt"],
                value="None",
                label="Caption model",
                info="Auto caption model",
            )
            token_string = gr.Textbox(
                label="Token string",
                value="TOK",
                interactive=True,
                info="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well.",
            )
            context_text = gr.Textbox(
                label="Context Text",
                info="Context Text for auto caption",
                value=" I want a description caption for this image",
            )
            replicate_api_key = gr.Textbox(
                label="Replicate API Key",
                info="API key for Replicate",
                value=os.environ.get("REPLICATE_API_TOKEN", ""),
                type="password",
            )
            api_key_status = gr.Textbox(label="API Key Status", interactive=False)
        with gr.Row():
            process_button = gr.Button("Process Images")

        # --- Section 2: captioned data preview + zip archive --------------
        with gr.Row():
            gr.Markdown("# Training Captions Data")
        with gr.Row():
            with gr.Column():
                image_output = gr.Gallery(type="pil", object_fit="fill")
            with gr.Column():
                text_output = gr.Textbox(interactive=True)
        with gr.Row():
            zip_output = gr.File(label="Zip file")
            btn_update_zip = gr.Button("Update zip file")

        # --- Section 3: training parameters --------------------------------
        with gr.Row():
            gr.Markdown("# Training on replicate")
        with gr.Row():
            # Label/info previously duplicated the caption dropdown's text;
            # an explicit default avoids handing ``None`` to the training
            # handler when the user never touches the dropdown.
            traning_model = gr.Dropdown(
                ["flux"],
                value="flux",
                label="Training model",
                info="Base model to fine-tune",
            )
            traning_destination = gr.Textbox(
                label="destination",
                info="add in replicate model destination, format [user]/[model_name]",
            )
            # precision=0 makes Gradio deliver whole numbers as ``int``;
            # an empty field is still passed through as ``None``.
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                info="Random seed integer for reproducible training. Leave empty to use a random seed.",
            )
            max_train_steps = gr.Number(
                label="max_train_steps",
                value=1000,
                precision=0,
                info="Number of individual training steps. Takes precedence over num_train_epochs.",
            )
        with gr.Row():
            train_button = gr.Button("Train")
        with gr.Row():
            training_logs = gr.Textbox(label="Training logs")
            traning_finnal = gr.Textbox(label="Traning final")

        # --- Event wiring ----------------------------------------------------
        train_button.click(
            fn=traning_function,
            inputs=[
                zip_output,
                traning_model,
                traning_destination,
                seed,
                token_string,
                max_train_steps,
            ],
            outputs=[training_logs, traning_finnal],
            queue=True,
        )
        process_button.click(
            fn=process_images,
            inputs=[input_images, label_model, context_text, token_string],
            outputs=[image_output, text_output, zip_output],
            queue=True,
        )
        btn_update_zip.click(
            fn=create_zip,
            inputs=[input_images, text_output, token_string],
            outputs=zip_output,
        )
        # NOTE: a former (disabled) ``traning_finnal.change`` hook called
        # ``update_dropdown`` with output ``style_mode``; that component is
        # not defined in this tab, so the hook cannot be enabled here as-is.
        replicate_api_key.change(
            fn=update_replicate_api_key,
            inputs=[replicate_api_key],
            outputs=[api_key_status],
        )