import os

import gradio as gr

from src.rep_api import process_images, traning_function, update_replicate_api_key
from src.utils import create_zip, update_dropdown


def create_train_tab():
    """Build the "Model Trainer" tab and wire up its event handlers.

    The tab has three sections:

    1. Image upload plus optional auto-captioning (``process_images``).
    2. Review and editing of the generated captions, with re-zipping
       (``create_zip``).
    3. Launching a training run on Replicate (``traning_function``).

    Intended to be called inside an enclosing ``gr.Blocks``/``gr.Tabs``
    context, because the ``gr.TabItem`` and its components attach to whatever
    layout context is currently open.

    Returns:
        None. All components are created and wired as side effects.
    """
    with gr.TabItem("Model Trainer"):
        # ---- Section 1: image import & auto-captioning inputs -------------
        gr.Markdown("# Image Importing & Auto captions")
        with gr.Row():
            input_images = gr.File(
                file_count="multiple", type="filepath", label="Upload Images"
            )
            label_model = gr.Dropdown(
                ["None", "blip", "llava-16", "img2prompt"],
                value="None",
                label="Caption model",
                info="Auto caption model",
            )
            token_string = gr.Textbox(
                label="Token string",
                value="TOK",
                interactive=True,
                info=(
                    "A unique string that will be trained to refer to the concept "
                    "in the input images. Can be anything, but TOK works well."
                ),
            )
            context_text = gr.Textbox(
                label="Context Text",
                info="Context Text for auto caption",
                value=" I want a description caption for this image",
            )
            replicate_api_key = gr.Textbox(
                label="Replicate API Key",
                info="API key for Replicate",
                # Pre-fill from the environment so a configured token need not be pasted.
                value=os.environ.get("REPLICATE_API_TOKEN", ""),
                type="password",
            )
            api_key_status = gr.Textbox(label="API Key Status", interactive=False)
        with gr.Row():
            process_button = gr.Button("Process Images")

        # ---- Section 2: caption review / editing ---------------------------
        with gr.Row():
            gr.Markdown("# Training Captions Data")
        with gr.Row():
            with gr.Column():
                image_output = gr.Gallery(type="pil", object_fit="fill")
            with gr.Column():
                # Editable so users can correct captions before re-zipping.
                text_output = gr.Textbox(interactive=True)
        with gr.Row():
            zip_output = gr.File(label="Zip file")
            btn_update_zip = gr.Button("Update zip file")

        # ---- Section 3: training on Replicate -------------------------------
        with gr.Row():
            gr.Markdown("# Training on replicate")
        with gr.Row():
            training_model = gr.Dropdown(
                ["flux"],
                # Explicit default so the training handler never receives None.
                value="flux",
                label="Training model",
                info="Base model to train on Replicate",
            )
            training_destination = gr.Textbox(
                label="destination",
                info="add in replicate model destination, format [user]/[model_name]",
            )
            # precision=0 makes Gradio deliver ints rather than floats, matching
            # the "integer" semantics described in the info text.
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                info=(
                    "Random seed integer for reproducible training. "
                    "Leave empty to use a random seed."
                ),
            )
            max_train_steps = gr.Number(
                label="max_train_steps",
                value=1000,
                precision=0,
                info=(
                    "Number of individual training steps. "
                    "Takes precedence over num_train_epochs."
                ),
            )
        with gr.Row():
            train_button = gr.Button("Train")
        with gr.Row():
            training_logs = gr.Textbox(label="Training logs")
            training_final = gr.Textbox(label="Training final")

        # ---- Event wiring ---------------------------------------------------
        # Processing and training can be slow, so both go through Gradio's queue.
        train_button.click(
            fn=traning_function,
            inputs=[
                zip_output,
                training_model,
                training_destination,
                seed,
                token_string,
                max_train_steps,
            ],
            outputs=[training_logs, training_final],
            queue=True,
        )
        process_button.click(
            fn=process_images,
            inputs=[input_images, label_model, context_text, token_string],
            outputs=[image_output, text_output, zip_output],
            queue=True,
        )
        btn_update_zip.click(
            fn=create_zip,
            inputs=[input_images, text_output, token_string],
            outputs=zip_output,
        )
        # TODO: re-enable once a `style_mode` component is available here:
        #   training_final.change(fn=update_dropdown,
        #                         inputs=[training_final, token_string],
        #                         outputs=style_mode)
        replicate_api_key.change(
            fn=update_replicate_api_key,
            inputs=[replicate_api_key],
            outputs=[api_key_status],
        )