JarvisLabs committed on
Commit
d6f10f4
1 Parent(s): 353926b

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +144 -237
  2. model_dict.json +7 -0
  3. requirments.txt +6 -0
app.py CHANGED
@@ -1,238 +1,145 @@
1
- import gradio as gr
2
- from PIL import Image
3
- import os
4
- import replicate
5
- import time
6
- import base64
7
- import numpy as npzipfile
8
- import tempfile
9
- import zipfile
10
- import b2sdk.v2 as b2 #Backblaze img2img upload bucket
11
- import shutil
12
- import requests
13
- import io
14
-
15
-
16
-
17
- info = b2.InMemoryAccountInfo()
18
- b2_api = b2.B2Api(info)
19
- application_key_id = os.getenv("BB_KeyID")
20
- application_key = os.getenv("BB_AppKey")
21
- #print(application_key_id,application_key)
22
- b2_api.authorize_account("production", application_key_id, application_key)
23
- BB_bucket_name=os.getenv("BB_bucket")
24
- BB_bucket=b2_api.get_bucket_by_name(os.getenv("BB_bucket"))
25
- BB_defurl="https://f005.backblazeb2.com/file/"
26
-
27
- def process_images(files,model,context_text):
28
- images = []
29
- textbox =""
30
- for file in files:
31
- print(file)
32
- image = Image.open(file)
33
- caption = replicate_caption_api(image,model,context_text)
34
- textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
35
- images.append(image)
36
- #texts.append(textbox)
37
- zip_path=create_zip(files,textbox,"TOK")
38
- print(zip_path)
39
- return images, textbox
40
-
41
- def BB_uploadfile(b2_api,local_file,file_name,BB_bucket_name,FRIENDLY_URL=True):
42
- metadata = {"key": "value"}
43
- uploaded_file = BB_bucket.upload_local_file(
44
- local_file=local_file,
45
- file_name=file_name,
46
- file_infos=metadata,
47
- )
48
- img_url=b2_api.get_download_url_for_fileid(uploaded_file.id_)
49
- if FRIENDLY_URL: #Get friendly URP
50
- img_url=BB_defurl+BB_bucket_name+"/"+file_name
51
- print("backblaze", img_url)
52
- return img_url
53
-
54
- def image_to_base64(img):
55
- buffered = io.BytesIO()
56
- img.save(buffered, format="PNG")
57
- img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
58
- return "data:image/png;base64,"+img_str
59
-
60
-
61
-
62
- def create_zip(files,captions,trigger):
63
- #Caption processing
64
- captions=captions.split("\n")
65
- #cute files and "tags:"
66
- captions= [cap.split("file:")[0][5:] for cap in captions]
67
- #temp_dir="/content"
68
- #os.makedirs(temp_dir, exist_ok=True)
69
- # Create a zip file
70
- #os.makedirs(temp_dir, exist_ok=True)
71
- zip_path = "training_data.zip" #os.path.join(temp_dir, "training_data.zip")
72
- with zipfile.ZipFile(zip_path, "w") as zip_file:
73
- for i, file in enumerate(files):
74
- # Add image to zip
75
- image_name = f"image_{i}.jpg"
76
- zip_file.write(file, image_name)
77
- # Add caption to zip
78
- caption_name = f"image_{i}.txt"
79
- caption_content = captions[i] +f", {trigger}"
80
- zip_file.writestr(caption_name, caption_content)
81
-
82
- file_url= BB_uploadfile(b2_api,zip_path,f"training_data_{trigger}.zip",BB_bucket_name)
83
- return file_url
84
-
85
-
86
- def replicate_caption_api(image,model,context_text):
87
- base64_image = image_to_base64(image)
88
- if model=="blip":
89
- output = replicate.run(
90
- "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
91
- input={
92
- "image": base64_image,
93
- "caption": True,
94
- "question": context_text,
95
- "temperature": 1,
96
- "use_nucleus_sampling": False
97
- }
98
- )
99
- print(output)
100
-
101
- elif model=="llava-16":
102
- output = replicate.run(
103
- # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
104
- "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
105
- input={
106
- "image": base64_image,
107
- "top_p": 1,
108
- "prompt": context_text,
109
- "max_tokens": 1024,
110
- "temperature": 0.2
111
- }
112
- )
113
- print(output)
114
- output = "".join(output)
115
-
116
- elif model=="img2prompt":
117
- output = replicate.run(
118
- "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
119
- input={
120
- "image":base64_image
121
- }
122
- )
123
- print(output)
124
- return output
125
-
126
- def update_replicate_api_key(api_key):
127
- os.environ["REPLICATE_API_TOKEN"] = api_key
128
- return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
129
-
130
-
131
- def traning_function(files,text_output,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id,hf_token):
132
- print(files,text_output)
133
- zip_path = create_zip(files,text_output,token_stringn)
134
- print(zip_path)
135
- training_logs = f"Created zip file at: {zip_path}\n"
136
- yield training_logs, None
137
-
138
- try:
139
- training = replicate.trainings.create(
140
- destination=training_destination,
141
- version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
142
- input={
143
- "steps": max_train_steps,
144
- "lora_rank": 16,
145
- "batch_size": 1,
146
- "autocaption": True,
147
- "trigger_word": token_string,
148
- "learning_rate": 0.0004,
149
- "seed": seed,
150
- "input_images": zip_path
151
- },
152
- )
153
-
154
- training_logs = f"Training started with model: {training_model}\n"
155
- training_logs += f"Destination: {training_destination}\n"
156
- training_logs += f"Seed: {seed}\n"
157
- training_logs += f"Token string: {token_string}\n"
158
- training_logs += f"Max train steps: {max_train_steps}\n"
159
-
160
- # Poll the training status
161
- while training.status != "succeeded":
162
- training.reload()
163
- training_logs += f"Training status: {training.status}\n"
164
- training_logs += f"{training.logs}\n"
165
- yield training_logs, None
166
- time.sleep(10) # Wait for 10 seconds before checking again
167
-
168
- training_logs += "Training completed!\n"
169
- if hf_repo_id and hf_token:
170
- training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
171
- # Here you would implement the logic to upload to Hugging Face
172
-
173
- # In a real scenario, you might want to download and display some result images
174
- # For now, we'll just return the original images
175
- images = [Image.open(file) for file in files]
176
-
177
- yield training_logs, images
178
-
179
- except Exception as e:
180
- yield f"An error occurred: {str(e)}", None
181
-
182
- with gr.Blocks() as demo:
183
- gr.Markdown("# Image Captioning")
184
- with gr.Row():
185
- input_images = gr.File(file_count="multiple", type="filepath", label="Upload Images")
186
- label_model = gr.Dropdown(["blip", "llava-16","img2prompt"], label="Caption model", info="Auto caption model")
187
- context_text = gr.Textbox(label="Context Text", info="Context Text for auto catpion",value=" I want a description caption for this image")
188
- # Replicate API Key input
189
- replicate_api_key = gr.Textbox(
190
- label="Replicate API Key",
191
- info="API key for Replicate",
192
- type="password"
193
- )
194
- api_key_status = gr.Textbox(label="API Key Status", interactive=False)
195
-
196
- with gr.Row():
197
- process_button = gr.Button("Process Images")
198
- #Image outputs
199
- with gr.Row():
200
- gr.Markdown("# Captions")
201
- with gr.Row():
202
- with gr.Column():
203
- image_output = gr.Gallery(type="pil",object_fit="fill")
204
- with gr.Column():
205
- text_output = gr.Textbox( interactive=True)
206
- #Traning options
207
- with gr.Row():
208
- gr.Markdown("# Training on replicate")
209
- with gr.Row():
210
- traning_model = gr.Dropdown(["flux", "SDXL",""], label="Caption model", info="Auto caption model")
211
- traning_destination = gr.Textbox(label="destination",info="add in replicate model destination")
212
- seed = gr.Number(label="Seed", value=42,info="Random seed integer for reproducible training. Leave empty to use a random seed.")
213
- token_stringn = "TOK"# gr.Textbox(label="Token string",value="TOK",info="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well.")
214
- max_train_steps =gr.Number(label="max_train_steps", value= 1000, info="Number of individual training steps. Takes precedence over num_train_epochs.")
215
- with gr.Row():
216
- hf_repo_id = gr.Textbox(label="Hugging face repo id",info="Hugging Face repository ID, if you'd like to upload the trained LoRA to Hugging Face. For example, lucataco/flux-dev-lora.")
217
- hf_token = gr.Textbox(label="Hugging face write token",info="Hugging Face token, if you'd like to upload the trained LoRA to Hugging Face.")
218
- with gr.Row():
219
- train_button = gr.Button("Train")
220
- with gr.Row():
221
- training_logs = gr.Textbox(label="Training logs")
222
- training_images = gr.Gallery(label="Training images")
223
-
224
-
225
-
226
- train_button.click(fn=traning_function, inputs=[input_images,text_output,traning_model,traning_destination,seed,token_stringn,max_train_steps,hf_repo_id,hf_token],
227
- outputs=[image_output,text_output])
228
-
229
- process_button.click(fn=process_images, inputs=[input_images,label_model,context_text], outputs=[image_output,text_output])
230
- # Add event listener for API key changes
231
- replicate_api_key.change(
232
- fn=update_replicate_api_key,
233
- inputs=[replicate_api_key],
234
- outputs=[api_key_status]
235
- )
236
-
237
-
238
  demo.launch(debug=True)
 
1
+
2
+ from dotenv import load_dotenv, find_dotenv
3
+ _ = load_dotenv(find_dotenv())
4
+ from src.utils import create_zip,add_to_prompt,update_dropdown
5
+ from src.rep_api import replicate_caption_api,generate_image_replicate,traning_function,update_replicate_api_key
6
+ import gradio as gr
7
+ from PIL import Image
8
+ import os
9
+ import time
10
+ import json
11
+
12
+
13
+ # The dictionary data
14
+ prompt_dict = {
15
+ "Character": ["Asian girl with black hair", "A man with blond hair", "A Cat girl anime character with purple hair", "A Green Alien with big black eyes"],
16
+ "Clothes": ["Wearing a blue jacket", "Wearing a black business suit", "Wearing a purple jumpsuit", "Wearing shorts and a white T-shirt"],
17
+ "Pose": ["Close up portrait", "Standing doing a peace sign", "Folding arms", "holding a phone"],
18
+ "Style": ["Simple white background", "Fashion runway", "Inside a business conference", "Inside a spaceship"]
19
+ }
20
+ style_json="model_dict.json"
21
+ model_dict=json.load(open(style_json,"r"))
22
+
23
+ def process_images(files,model,context_text):
24
+ images = []
25
+ textbox =""
26
+ for file in files:
27
+ print(file)
28
+ image = Image.open(file)
29
+ caption = replicate_caption_api(image,model,context_text)
30
+ textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
31
+ images.append(image)
32
+ #texts.append(textbox)
33
+ zip_path=create_zip(files,textbox,"TOK")
34
+
35
+ return images, textbox,zip_path
36
+
37
+
38
+
39
+
40
+ with gr.Blocks( theme="NoCrypt/miku") as demo:
41
+
42
+ with gr.Tabs() as tabs:
43
+ with gr.TabItem("Image Generator"):
44
+ gr.Markdown(" #Image Generator")
45
+ with gr.Row():
46
+ with gr.Column():
47
+ inp = gr.Textbox(label="Prompt")
48
+
49
+
50
+ btn = gr.Button("Generate")
51
+ with gr.Column():
52
+ ar = gr.Dropdown(["1:1","16:9","9:16","5:3"], label="Aspect Ratio", info="Aspect Ratio")
53
+ style_mode = gr.Dropdown(model_dict.keys(),label="Style lore")
54
+ api_path = gr.Textbox(label="API_route",info="replicate api route goes here")
55
+
56
+
57
+
58
+ with gr.Accordion("Prompt Support", open=False):
59
+ for key, values in prompt_dict.items():
60
+ with gr.Row():
61
+ #gr.Markdown(f"**{key}**")
62
+ gr.Button(key,interactive=False)
63
+ for value in values:
64
+ gr.Button(value).click(add_to_prompt, inputs=[inp, gr.Textbox(value,visible=False)], outputs=inp)
65
+
66
+ with gr.Row():
67
+ gen_out = gr.Image(label="Generated Image",type="filepath")
68
+
69
+
70
+ btn.click(generate_image_replicate, inputs=[inp,api_path], outputs=gen_out,queue=True)
71
+
72
+
73
+
74
+ with gr.TabItem("Model Trainner"):
75
+ gr.Markdown("# Image Importing & Auto captions")
76
+ with gr.Row():
77
+ input_images = gr.File(file_count="multiple", type="filepath", label="Upload Images")
78
+ label_model = gr.Dropdown(["blip", "llava-16","img2prompt"], label="Caption model", info="Auto caption model")
79
+ token_string= gr.Textbox(label="Token string",value="TOK",interactive=True,
80
+ info="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well.")
81
+ context_text = gr.Textbox(label="Context Text", info="Context Text for auto catpion",value=" I want a description caption for this image")
82
+ # Replicate API Key input
83
+ replicate_api_key = gr.Textbox(
84
+ label="Replicate API Key",
85
+ info="API key for Replicate",
86
+ value=os.environ.get("REPLICATE_API_TOKEN", ""),
87
+ type="password"
88
+ )
89
+ api_key_status = gr.Textbox(label="API Key Status", interactive=False)
90
+
91
+ with gr.Row():
92
+ process_button = gr.Button("Process Images")
93
+ #Image outputs
94
+ with gr.Row():
95
+ gr.Markdown("# Traning Captions Data")
96
+ with gr.Row():
97
+ with gr.Column():
98
+ image_output = gr.Gallery(type="pil",object_fit="fill")
99
+ with gr.Column():
100
+ text_output = gr.Textbox( interactive=True)
101
+ with gr.Row():
102
+ zip_output = gr.File(label="Zip file")
103
+ btn_update_zip = gr.Button("Update zip file")
104
+
105
+
106
+
107
+ #Traning options
108
+ with gr.Row():
109
+ gr.Markdown("# Training on replicate")
110
+ with gr.Row():
111
+ traning_model = gr.Dropdown(["flux"], label="Caption model", info="Auto caption model")
112
+ traning_destination = gr.Textbox(label="destination",info="add in replicate model destination")
113
+ seed = gr.Number(label="Seed", value=42,info="Random seed integer for reproducible training. Leave empty to use a random seed.")
114
+
115
+ max_train_steps =gr.Number(label="max_train_steps", value= 1000, info="Number of individual training steps. Takes precedence over num_train_epochs.")
116
+ #with gr.Row():
117
+ # hf_repo_id = gr.Textbox(label="Hugging face repo id",info="Hugging Face repository ID, if you'd like to upload the trained LoRA to Hugging Face. For example, lucataco/flux-dev-lora.")
118
+ # hf_token = gr.Textbox(label="Hugging face write token",info="Hugging Face token, if you'd like to upload the trained LoRA to Hugging Face.")
119
+ with gr.Row():
120
+ train_button = gr.Button("Train")
121
+ with gr.Row():
122
+ training_logs = gr.Textbox(label="Training logs")
123
+ traning_finnal = gr.Textbox(label="Traning finnal")
124
+ #training_images = gr.Gallery(label="Training images")
125
+
126
+
127
+ #gr.Textbox("TOK",visible=False) added to deal with oddities of the token string being a gradio class
128
+ train_button.click(fn=traning_function, inputs=[zip_output,traning_model,traning_destination,seed,token_string,max_train_steps], #,hf_repo_id,hf_token
129
+ outputs=[training_logs,traning_finnal],queue=True)
130
+ process_button.click(fn=process_images, inputs=[input_images,label_model,context_text,token_string], outputs=[image_output,text_output,zip_output],queue=True)
131
+ btn_update_zip.click(fn=create_zip, inputs=[image_output,text_output,token_string],outputs=zip_output)
132
+ # Add event listener for API key changes
133
+ traning_finnal.change(
134
+ fn=update_dropdown,
135
+ inputs=[traning_finnal,token_string],
136
+ outputs=style_mode
137
+ )
138
+ replicate_api_key.change(
139
+ fn=update_replicate_api_key,
140
+ inputs=[replicate_api_key],
141
+ outputs=[api_key_status]
142
+ )
143
+ #jarvis-labs2024/sioux-flux
144
+ demo.queue() # Queue for concurrent users
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  demo.launch(debug=True)
model_dict.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "Base": "black-forest-labs/flux-dev",
3
+ "Raylean": "jarvis-labs2024/flux-raylene:5574556226d11e0f10855a957d91f118a9178c8fc77e7e7b18830627ce3184f1",
4
+ "Alice": "jarvis-labs2024/flux-raylene:5574556226d11e0f10855a957d91f118a9178c8fc77e7e7b18830627ce3184f1",
5
+ "AppleSeed": "jarvis-labs2024/flux-appleseed:0aecb9fdfb17a2517112cc70b4a1898aa7791da84a010419782ce7043481edec",
6
+ "console_cowboy_flux": "jarvis-labs2024/console_cowboy_flux:53ff894d719f73dc11ca54fdb6ecf044d7d202aa30fce43236fbfda30b19ef62"
7
+ }
requirments.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ fal
3
+ fal-client
4
+ numpy
5
+ replicate
6
+ python-dotenv