import json
import os
import time

import replicate
from PIL import Image

from src.deepl import detect_and_translate
from src.utils import (
    BB_uploadfile,
    create_zip,
    image_to_base64,
    numpy_to_base64,
    update_model_dicts,
)

# Mapping of LoRA style names to their Replicate model versions.
style_json = "model_dict.json"
with open(style_json, "r") as f:
    model_dict = json.load(f)


def generate_image_replicate(prompt, lora_model, api_path, aspect_ratio, gallery, model,
                             lora_scale, num_outputs=1, guidance_scale=3.5, seed=None):
    """Generate image(s) on Replicate and append the first result to the gallery."""
    print(prompt, lora_model, api_path, aspect_ratio)
    # "dev" uses the default 30 steps; "schnell" only needs a few.
    num_inference_steps = 30
    if model == "schnell":
        num_inference_steps = 5
    if lora_model is not None:
        api_path = model_dict[lora_model]
    inputs = {
        "model": model,
        "prompt": detect_and_translate(prompt),
        "lora_scale": lora_scale,
        "aspect_ratio": aspect_ratio,
        "num_outputs": num_outputs,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    if seed is not None:
        inputs["seed"] = seed
    output = replicate.run(api_path, input=inputs)
    print(output)
    if gallery is None:
        gallery = []
    gallery.append(output[0])
    return output[0], gallery


def replicate_caption_api(image, model, context_text):
    """Caption an image with one of the supported captioning models on Replicate."""
    base64_image = image_to_base64(image)
    if model == "blip":
        output = replicate.run(
            "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
            input={
                "image": base64_image,
                "caption": True,
                "question": context_text,
                "temperature": 1,
                "use_nucleus_sampling": False,
            },
        )
        print(output)
    elif model == "llava-16":
        output = replicate.run(
            # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
            "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
            input={
                "image": base64_image,
                "top_p": 1,
                "prompt": context_text,
                "max_tokens": 1024,
                "temperature": 0.2,
            },
        )
        print(output)
        # LLaVA streams tokens, so join them into a single string.
        output = "".join(output)
    elif model == "img2prompt":
        output = replicate.run(
            "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
            input={"image": base64_image},
        )
        print(output)
    return output


def update_replicate_api_key(api_key):
    """Set (or clear) the Replicate API token for the current process."""
    os.environ["REPLICATE_API_TOKEN"] = api_key
    return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
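
# Illustrative usage sketch (kept as comments so nothing runs on import). The
# api_path below is a hypothetical placeholder; in practice it comes from
# model_dict when a lora_model key is selected, and REPLICATE_API_TOKEN must be set.
#
#   update_replicate_api_key(os.environ.get("REPLICATE_API_TOKEN", ""))
#   image_url, gallery = generate_image_replicate(
#       prompt="a watercolor fox in a forest",
#       lora_model=None,                                   # or a key present in model_dict
#       api_path="your-username/your-flux-lora:versionhash",  # placeholder
#       aspect_ratio="1:1",
#       gallery=None,
#       model="dev",
#       lora_scale=0.8,
#   )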
if api_key else "Replicate API key cleared" def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des): output = replicate.run( "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f", input={ "crop": crop, "seed": seed, "steps": steps, "category": category, # "force_dc": force_dc, "garm_img": numpy_to_base64( garm_img), "human_img": numpy_to_base64(human_img), #"mask_only": mask_only, "garment_des": garment_des } ) print(output) return output from src.utils import create_zip from PIL import Image def process_images(files,model,context_text,token_string): images = [] textbox ="" for file in files: print(file) image = Image.open(file) if model=="None": caption="[Insert cap here]" else: caption = replicate_caption_api(image,model,context_text) textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n" images.append(image) #texts.append(textbox) zip_path=create_zip(files,textbox,token_string) return images, textbox,zip_path def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"): try: model = replicate.models.create( owner=owner, name=name, visibility=visibility, hardware=hardware, ) print(model) return True except Exception as e: print(e) if "A model with that name and owner already exists" in str(e): return True return False def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None): ##Place holder for now BB_bucket_name="jarvisdataset" BB_defult="https://f005.backblazeb2.com/file/" if BB_defult not in zip_path: zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name) print(zip_path) training_logs = f"Using zip traning file at: {zip_path}\n" yield training_logs, None input={ "steps": max_train_steps, "lora_rank": 16, "batch_size": 1, "autocaption": True, "trigger_word": token_string, "learning_rate": 0.0004, "seed": seed, "input_images": zip_path } print(training_destination) username,model_name=training_destination.split("/") assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct " print(input) try: training = replicate.trainings.create( destination=training_destination, version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02", input=input, ) training_logs = f"Training started with model: {training_model}\n" training_logs += f"Destination: {training_destination}\n" training_logs += f"Seed: {seed}\n" training_logs += f"Token string: {token_string}\n" training_logs += f"Max train steps: {max_train_steps}\n" # Poll the training status while training.status != "succeeded": training.reload() training_logs += f"Training status: {training.status}\n" training_logs += f"{training.logs}\n" if training.status == "failed": training_logs += "Training failed!\n" return training_logs, training yield training_logs, None time.sleep(10) # Wait for 10 seconds before checking again training_logs += "Training completed!\n" if hf_repo_id and hf_token: training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n" # Here you would implement the logic to upload to Hugging Face traning_finnal=training.output # In a real scenario, you might want to download and display some result images # For now, we'll just return the original images #images = [Image.open(file) for file in files] _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json") 
traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/") yield training_logs, traning_finnal except Exception as e: yield f"An error occurred: {str(e)}", None