JarvisLabs committed on
Commit
481ccb1
1 Parent(s): d97c68b

Update src/rep_api.py

Browse files
Files changed (1) hide show
  1. src/rep_api.py +213 -212
src/rep_api.py CHANGED
@@ -1,212 +1,213 @@
1
- import replicate
2
- import os
3
- from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile,numpy_to_base64
4
- from src.deepl import detect_and_translate
5
- import json
6
- import time
7
- style_json="model_dict.json"
8
- model_dict=json.load(open(style_json,"r"))
9
-
10
-
11
-
12
- def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model,lora_scale,num_outputs=1,guidance_scale=3.5,seed=None):
13
- print(prompt,lora_model,api_path,aspect_ratio)
14
-
15
- #if model=="dev":
16
- num_inference_steps=30
17
- if model=="schnell":
18
- num_inference_steps=5
19
-
20
- if lora_model is not None:
21
- api_path=model_dict[lora_model]
22
-
23
- inputs={
24
- "model": model,
25
- "prompt": detect_and_translate(prompt),
26
- "lora_scale":lora_scale,
27
- "aspect_ratio": aspect_ratio,
28
- "num_outputs":num_outputs,
29
- "num_inference_steps":num_inference_steps,
30
- "guidance_scale":guidance_scale,
31
- }
32
- if seed is not None:
33
- inputs["seed"]=seed
34
- output = replicate.run(
35
- api_path,
36
- input=inputs
37
- )
38
- print(output)
39
- if gallery is None:
40
- gallery=[]
41
- gallery.append(output[0])
42
- return output[0],gallery
43
-
44
-
45
- def replicate_caption_api(image,model,context_text):
46
- base64_image = image_to_base64(image)
47
- if model=="blip":
48
- output = replicate.run(
49
- "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
50
- input={
51
- "image": base64_image,
52
- "caption": True,
53
- "question": context_text,
54
- "temperature": 1,
55
- "use_nucleus_sampling": False
56
- }
57
- )
58
- print(output)
59
-
60
- elif model=="llava-16":
61
- output = replicate.run(
62
- # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
63
- "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
64
- input={
65
- "image": base64_image,
66
- "top_p": 1,
67
- "prompt": context_text,
68
- "max_tokens": 1024,
69
- "temperature": 0.2
70
- }
71
- )
72
- print(output)
73
- output = "".join(output)
74
-
75
- elif model=="img2prompt":
76
- output = replicate.run(
77
- "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
78
- input={
79
- "image":base64_image
80
- }
81
- )
82
- print(output)
83
- return output
84
-
85
- def update_replicate_api_key(api_key):
86
- os.environ["REPLICATE_API_TOKEN"] = api_key
87
- return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
88
-
89
-
90
- def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
91
- output = replicate.run(
92
- "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
93
- input={
94
- "crop": crop,
95
- "seed": seed,
96
- "steps": steps,
97
- "category": category,
98
- # "force_dc": force_dc,
99
- "garm_img": numpy_to_base64( garm_img),
100
- "human_img": numpy_to_base64(human_img),
101
- #"mask_only": mask_only,
102
- "garment_des": garment_des
103
- }
104
- )
105
- print(output)
106
- return output
107
-
108
-
109
- from src.utils import create_zip
110
- from PIL import Image
111
-
112
-
113
- def process_images(files,model,context_text,token_string):
114
- images = []
115
- textbox =""
116
- for file in files:
117
- print(file)
118
- image = Image.open(file)
119
- if model=="None":
120
- caption="[Insert cap here]"
121
- else:
122
- caption = replicate_caption_api(image,model,context_text)
123
- textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
124
- images.append(image)
125
- #texts.append(textbox)
126
- zip_path=create_zip(files,textbox,token_string)
127
-
128
- return images, textbox,zip_path
129
-
130
- def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
131
- try:
132
- model = replicate.models.create(
133
- owner=owner,
134
- name=name,
135
- visibility=visibility,
136
- hardware=hardware,
137
- )
138
- print(model)
139
- return True
140
- except Exception as e:
141
- print(e)
142
- if "A model with that name and owner already exists" in str(e):
143
- return True
144
- return False
145
-
146
-
147
-
148
- def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
149
- ##Place holder for now
150
- BB_bucket_name="jarvisdataset"
151
- BB_defult="https://f005.backblazeb2.com/file/"
152
- if BB_defult not in zip_path:
153
- zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
154
- print(zip_path)
155
- training_logs = f"Using zip traning file at: {zip_path}\n"
156
- yield training_logs, None
157
- input={
158
- "steps": max_train_steps,
159
- "lora_rank": 16,
160
- "batch_size": 1,
161
- "autocaption": True,
162
- "trigger_word": token_string,
163
- "learning_rate": 0.0004,
164
- "seed": seed,
165
- "input_images": zip_path
166
- }
167
- print(training_destination)
168
- username,model_name=training_destination.split("/")
169
- assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
170
-
171
- print(input)
172
- try:
173
- training = replicate.trainings.create(
174
- destination=training_destination,
175
- version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
176
- input=input,
177
- )
178
-
179
- training_logs = f"Training started with model: {training_model}\n"
180
- training_logs += f"Destination: {training_destination}\n"
181
- training_logs += f"Seed: {seed}\n"
182
- training_logs += f"Token string: {token_string}\n"
183
- training_logs += f"Max train steps: {max_train_steps}\n"
184
-
185
- # Poll the training status
186
- while training.status != "succeeded":
187
- training.reload()
188
- training_logs += f"Training status: {training.status}\n"
189
- training_logs += f"{training.logs}\n"
190
- if training.status == "failed":
191
- training_logs += "Training failed!\n"
192
- return training_logs, training
193
-
194
- yield training_logs, None
195
- time.sleep(10) # Wait for 10 seconds before checking again
196
-
197
- training_logs += "Training completed!\n"
198
- if hf_repo_id and hf_token:
199
- training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
200
- # Here you would implement the logic to upload to Hugging Face
201
-
202
- traning_finnal=training.output
203
-
204
- # In a real scenario, you might want to download and display some result images
205
- # For now, we'll just return the original images
206
- #images = [Image.open(file) for file in files]
207
- _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
208
- traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
209
- yield training_logs, traning_finnal
210
-
211
- except Exception as e:
212
- yield f"An error occurred: {str(e)}", None
 
 
1
+ import replicate
2
+ import os
3
+ from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile,numpy_to_base64
4
+ from src.deepl import detect_and_translate
5
+ import json
6
+ import time
7
# Path of the JSON file mapping LoRA style names to Replicate model versions.
style_json = "model_dict.json"
# Load the mapping once at import time. Use a context manager so the file
# handle is closed deterministically (the original `json.load(open(...))`
# left it to the garbage collector).
with open(style_json, "r") as _style_file:
    model_dict = json.load(_style_file)
9
+
10
+
11
+
12
def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model,lora_scale,num_outputs=1,guidance_scale=3.5,seed=None):
    """Generate image(s) with a Flux model on Replicate and extend the gallery.

    Returns a tuple of (first generated image, updated gallery list). The
    prompt is passed through detect_and_translate before being sent.
    """
    print(prompt, lora_model, api_path, aspect_ratio)

    # "schnell" is the fast model variant and needs far fewer denoising steps.
    steps = 5 if model == "schnell" else 30

    # A named LoRA style takes precedence over the explicitly supplied path.
    if lora_model is not None:
        api_path = model_dict[lora_model]

    payload = {
        "model": model,
        "prompt": detect_and_translate(prompt),
        "lora_scale": lora_scale,
        "aspect_ratio": aspect_ratio,
        "num_outputs": num_outputs,
        "num_inference_steps": steps,
        "guidance_scale": guidance_scale,
        "output_format": "png",
    }
    # Only forward the seed when the caller pinned one.
    if seed is not None:
        payload["seed"] = seed

    output = replicate.run(api_path, input=payload)
    print(output)

    gallery = [] if gallery is None else gallery
    gallery.append(output[0])
    return output[0], gallery
44
+
45
+
46
def replicate_caption_api(image, model, context_text):
    """Caption an image using one of the supported Replicate models.

    Supported model keys: "blip", "llava-16", "img2prompt". The image is
    serialized to base64 before being sent. Returns the model's caption
    output (a string for "llava-16"; whatever the model returns otherwise).

    Raises:
        ValueError: if `model` is not one of the supported keys. (The
        original code fell through and crashed with a NameError on the
        final `return output` instead.)
    """
    base64_image = image_to_base64(image)
    if model == "blip":
        output = replicate.run(
            "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
            input={
                "image": base64_image,
                "caption": True,
                "question": context_text,
                "temperature": 1,
                "use_nucleus_sampling": False
            }
        )
        print(output)

    elif model == "llava-16":
        output = replicate.run(
            # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
            "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
            input={
                "image": base64_image,
                "top_p": 1,
                "prompt": context_text,
                "max_tokens": 1024,
                "temperature": 0.2
            }
        )
        print(output)
        # llava streams tokens; join them into a single string.
        output = "".join(output)

    elif model == "img2prompt":
        output = replicate.run(
            "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
            input={
                "image": base64_image
            }
        )
        print(output)
    else:
        raise ValueError(f"Unsupported captioning model: {model!r}")
    return output
85
+
86
def update_replicate_api_key(api_key):
    """Store the Replicate API token in the environment.

    Returns a short status message; only the first five characters of the
    key are echoed back, to avoid leaking the full secret into the UI.
    """
    os.environ["REPLICATE_API_TOKEN"] = api_key
    if not api_key:
        return "Replicate API key cleared"
    return f"Replicate API key updated: {api_key[:5]}..."
89
+
90
+
91
def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
    """Run the IDM-VTON virtual try-on model on Replicate.

    garm_img and human_img are numpy images; both are serialized to base64
    before being sent. Returns the raw model output.
    """
    payload = {
        "crop": crop,
        "seed": seed,
        "steps": steps,
        "category": category,
        # "force_dc" and "mask_only" are intentionally not forwarded.
        "garm_img": numpy_to_base64(garm_img),
        "human_img": numpy_to_base64(human_img),
        "garment_des": garment_des,
    }
    result = replicate.run(
        "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
        input=payload,
    )
    print(result)
    return result
108
+
109
+
110
+ from src.utils import create_zip
111
+ from PIL import Image
112
+
113
+
114
def process_images(files, model, context_text, token_string):
    """Open each uploaded file, caption it, and bundle everything into a zip.

    Returns a tuple of (list of PIL images, caption text block, zip path).
    """
    images = []
    caption_lines = []
    for path in files:
        print(path)
        img = Image.open(path)
        # The literal string "None" disables captioning and inserts a placeholder.
        if model == "None":
            caption = "[Insert cap here]"
        else:
            caption = replicate_caption_api(img, model, context_text)
        caption_lines.append(f"Tags: {caption}, file: {os.path.basename(path)}\n")
        images.append(img)
    textbox = "".join(caption_lines)
    zip_path = create_zip(files, textbox, token_string)
    return images, textbox, zip_path
130
+
131
def replicate_create_model(owner, name, visibility="private", hardware="gpu-a40-large"):
    """Create a model on Replicate; return True if the model exists afterwards.

    An "already exists" error counts as success; any other failure is
    printed and reported as False.
    """
    try:
        created = replicate.models.create(
            owner=owner,
            name=name,
            visibility=visibility,
            hardware=hardware,
        )
    except Exception as err:
        print(err)
        # A duplicate-name error means the destination is usable anyway.
        return "A model with that name and owner already exists" in str(err)
    print(created)
    return True
146
+
147
+
148
+
149
def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
    """Launch a Flux LoRA training job on Replicate and stream progress logs.

    Generator yielding (training_logs, result) tuples: result is None while
    the job is running, and the training output dict (augmented with a
    "replicate_link" key) once it succeeds. On failure/cancellation the
    final yield carries the training object; on any exception the error
    text is yielded with None.
    """
    # Replicate needs a publicly reachable zip; upload to Backblaze unless
    # the path is already a hosted URL.
    BB_bucket_name = "jarvisdataset"
    BB_defult = "https://f005.backblazeb2.com/file/"
    if BB_defult not in zip_path:
        zip_path = BB_uploadfile(zip_path, os.path.basename(zip_path), BB_bucket_name)
    print(zip_path)
    training_logs = f"Using zip traning file at: {zip_path}\n"
    yield training_logs, None

    # Renamed from `input` to avoid shadowing the builtin.
    train_input = {
        "steps": max_train_steps,
        "lora_rank": 16,
        "batch_size": 1,
        "autocaption": True,
        "trigger_word": token_string,
        "learning_rate": 0.0004,
        "seed": seed,
        "input_images": zip_path
    }
    print(training_destination)
    username, model_name = training_destination.split("/")
    assert replicate_create_model(username, model_name, visibility="private", hardware="gpu-a40-large"), "Error in creating model on replicate, check API key and username is correct "

    print(train_input)
    try:
        training = replicate.trainings.create(
            destination=training_destination,
            version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
            input=train_input,
        )

        training_logs = f"Training started with model: {training_model}\n"
        training_logs += f"Destination: {training_destination}\n"
        training_logs += f"Seed: {seed}\n"
        training_logs += f"Token string: {token_string}\n"
        training_logs += f"Max train steps: {max_train_steps}\n"

        # Poll until the job reaches a terminal state.
        while training.status != "succeeded":
            training.reload()
            training_logs += f"Training status: {training.status}\n"
            training_logs += f"{training.logs}\n"
            # BUG FIX: the original did `return training_logs, training` here.
            # Inside a generator that value is swallowed by StopIteration, so
            # consumers never saw the failure — yield it, then stop. Also
            # treat "canceled" as terminal so the loop cannot spin forever.
            if training.status in ("failed", "canceled"):
                training_logs += f"Training {training.status}!\n"
                yield training_logs, training
                return

            yield training_logs, None
            time.sleep(10)  # Wait for 10 seconds before checking again

        training_logs += "Training completed!\n"
        if hf_repo_id and hf_token:
            training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
            # Placeholder: Hugging Face upload is not implemented yet.

        traning_finnal = training.output

        # Register the new version under its trigger word and expose a
        # human-friendly link to the trained model.
        _ = update_model_dicts(traning_finnal["version"], token_string, style_json="model_dict.json")
        traning_finnal["replicate_link"] = "https://replicate.com/" + traning_finnal["version"].replace(":", "/")
        yield training_logs, traning_finnal

    except Exception as e:
        yield f"An error occurred: {str(e)}", None