JarvisLabs commited on
Commit
1de6488
1 Parent(s): 743d32b

Upload rep_api.py

Browse files
Files changed (1) hide show
  1. src/rep_api.py +212 -210
src/rep_api.py CHANGED
@@ -1,210 +1,212 @@
1
- import replicate
2
- import os
3
- from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile,numpy_to_base64
4
- import json
5
- import time
6
- style_json="model_dict.json"
7
- model_dict=json.load(open(style_json,"r"))
8
-
9
-
10
-
11
- def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model,lora_scale=1,num_outputs=1,guidance_scale=3.5,seed=None):
12
- print(prompt,lora_model,api_path,aspect_ratio)
13
- #if model=="dev":
14
- num_inference_steps=30
15
- if model=="schnell":
16
- num_inference_steps=5
17
-
18
- if lora_model is not None:
19
- api_path=model_dict[lora_model]
20
-
21
- inputs={
22
- "model": model,
23
- "prompt": prompt,
24
- "lora_scale":lora_scale,
25
- "aspect_ratio": aspect_ratio,
26
- "num_outputs":num_outputs,
27
- "num_inference_steps":num_inference_steps,
28
- "guidance_scale":guidance_scale,
29
- }
30
- if seed is not None:
31
- inputs["seed"]=seed
32
- output = replicate.run(
33
- api_path,
34
- input=inputs
35
- )
36
- print(output)
37
- if gallery is None:
38
- gallery=[]
39
- gallery.append(output[0])
40
- return output[0],gallery
41
-
42
-
43
- def replicate_caption_api(image,model,context_text):
44
- base64_image = image_to_base64(image)
45
- if model=="blip":
46
- output = replicate.run(
47
- "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
48
- input={
49
- "image": base64_image,
50
- "caption": True,
51
- "question": context_text,
52
- "temperature": 1,
53
- "use_nucleus_sampling": False
54
- }
55
- )
56
- print(output)
57
-
58
- elif model=="llava-16":
59
- output = replicate.run(
60
- # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
61
- "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
62
- input={
63
- "image": base64_image,
64
- "top_p": 1,
65
- "prompt": context_text,
66
- "max_tokens": 1024,
67
- "temperature": 0.2
68
- }
69
- )
70
- print(output)
71
- output = "".join(output)
72
-
73
- elif model=="img2prompt":
74
- output = replicate.run(
75
- "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
76
- input={
77
- "image":base64_image
78
- }
79
- )
80
- print(output)
81
- return output
82
-
83
- def update_replicate_api_key(api_key):
84
- os.environ["REPLICATE_API_TOKEN"] = api_key
85
- return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
86
-
87
-
88
- def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
89
- output = replicate.run(
90
- "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
91
- input={
92
- "crop": crop,
93
- "seed": seed,
94
- "steps": steps,
95
- "category": category,
96
- # "force_dc": force_dc,
97
- "garm_img": numpy_to_base64( garm_img),
98
- "human_img": numpy_to_base64(human_img),
99
- #"mask_only": mask_only,
100
- "garment_des": garment_des
101
- }
102
- )
103
- print(output)
104
- return output
105
-
106
-
107
- from src.utils import create_zip
108
- from PIL import Image
109
-
110
-
111
- def process_images(files,model,context_text,token_string):
112
- images = []
113
- textbox =""
114
- for file in files:
115
- print(file)
116
- image = Image.open(file)
117
- if model=="None":
118
- caption="[Insert cap here]"
119
- else:
120
- caption = replicate_caption_api(image,model,context_text)
121
- textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
122
- images.append(image)
123
- #texts.append(textbox)
124
- zip_path=create_zip(files,textbox,token_string)
125
-
126
- return images, textbox,zip_path
127
-
128
- def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
129
- try:
130
- model = replicate.models.create(
131
- owner=owner,
132
- name=name,
133
- visibility=visibility,
134
- hardware=hardware,
135
- )
136
- print(model)
137
- return True
138
- except Exception as e:
139
- print(e)
140
- if "A model with that name and owner already exists" in str(e):
141
- return True
142
- return False
143
-
144
-
145
-
146
- def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
147
- ##Place holder for now
148
- BB_bucket_name="jarvisdataset"
149
- BB_defult="https://f005.backblazeb2.com/file/"
150
- if BB_defult not in zip_path:
151
- zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
152
- print(zip_path)
153
- training_logs = f"Using zip traning file at: {zip_path}\n"
154
- yield training_logs, None
155
- input={
156
- "steps": max_train_steps,
157
- "lora_rank": 16,
158
- "batch_size": 1,
159
- "autocaption": True,
160
- "trigger_word": token_string,
161
- "learning_rate": 0.0004,
162
- "seed": seed,
163
- "input_images": zip_path
164
- }
165
- print(training_destination)
166
- username,model_name=training_destination.split("/")
167
- assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
168
-
169
- print(input)
170
- try:
171
- training = replicate.trainings.create(
172
- destination=training_destination,
173
- version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
174
- input=input,
175
- )
176
-
177
- training_logs = f"Training started with model: {training_model}\n"
178
- training_logs += f"Destination: {training_destination}\n"
179
- training_logs += f"Seed: {seed}\n"
180
- training_logs += f"Token string: {token_string}\n"
181
- training_logs += f"Max train steps: {max_train_steps}\n"
182
-
183
- # Poll the training status
184
- while training.status != "succeeded":
185
- training.reload()
186
- training_logs += f"Training status: {training.status}\n"
187
- training_logs += f"{training.logs}\n"
188
- if training.status == "failed":
189
- training_logs += "Training failed!\n"
190
- return training_logs, training
191
-
192
- yield training_logs, None
193
- time.sleep(10) # Wait for 10 seconds before checking again
194
-
195
- training_logs += "Training completed!\n"
196
- if hf_repo_id and hf_token:
197
- training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
198
- # Here you would implement the logic to upload to Hugging Face
199
-
200
- traning_finnal=training.output
201
-
202
- # In a real scenario, you might want to download and display some result images
203
- # For now, we'll just return the original images
204
- #images = [Image.open(file) for file in files]
205
- _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
206
- traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
207
- yield training_logs, traning_finnal
208
-
209
- except Exception as e:
210
- yield f"An error occurred: {str(e)}", None
 
 
 
1
+ import replicate
2
+ import os
3
+ from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile,numpy_to_base64
4
+ from src.deepl import detect_and_translate
5
+ import json
6
+ import time
7
+ style_json="model_dict.json"
8
+ model_dict=json.load(open(style_json,"r"))
9
+
10
+
11
+
12
+ def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model,lora_scale,num_outputs=1,guidance_scale=3.5,seed=None):
13
+ print(prompt,lora_model,api_path,aspect_ratio)
14
+
15
+ #if model=="dev":
16
+ num_inference_steps=30
17
+ if model=="schnell":
18
+ num_inference_steps=5
19
+
20
+ if lora_model is not None:
21
+ api_path=model_dict[lora_model]
22
+
23
+ inputs={
24
+ "model": model,
25
+ "prompt": detect_and_translate(prompt),
26
+ "lora_scale":lora_scale,
27
+ "aspect_ratio": aspect_ratio,
28
+ "num_outputs":num_outputs,
29
+ "num_inference_steps":num_inference_steps,
30
+ "guidance_scale":guidance_scale,
31
+ }
32
+ if seed is not None:
33
+ inputs["seed"]=seed
34
+ output = replicate.run(
35
+ api_path,
36
+ input=inputs
37
+ )
38
+ print(output)
39
+ if gallery is None:
40
+ gallery=[]
41
+ gallery.append(output[0])
42
+ return output[0],gallery
43
+
44
+
45
+ def replicate_caption_api(image,model,context_text):
46
+ base64_image = image_to_base64(image)
47
+ if model=="blip":
48
+ output = replicate.run(
49
+ "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
50
+ input={
51
+ "image": base64_image,
52
+ "caption": True,
53
+ "question": context_text,
54
+ "temperature": 1,
55
+ "use_nucleus_sampling": False
56
+ }
57
+ )
58
+ print(output)
59
+
60
+ elif model=="llava-16":
61
+ output = replicate.run(
62
+ # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
63
+ "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
64
+ input={
65
+ "image": base64_image,
66
+ "top_p": 1,
67
+ "prompt": context_text,
68
+ "max_tokens": 1024,
69
+ "temperature": 0.2
70
+ }
71
+ )
72
+ print(output)
73
+ output = "".join(output)
74
+
75
+ elif model=="img2prompt":
76
+ output = replicate.run(
77
+ "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
78
+ input={
79
+ "image":base64_image
80
+ }
81
+ )
82
+ print(output)
83
+ return output
84
+
85
+ def update_replicate_api_key(api_key):
86
+ os.environ["REPLICATE_API_TOKEN"] = api_key
87
+ return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
88
+
89
+
90
+ def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
91
+ output = replicate.run(
92
+ "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
93
+ input={
94
+ "crop": crop,
95
+ "seed": seed,
96
+ "steps": steps,
97
+ "category": category,
98
+ # "force_dc": force_dc,
99
+ "garm_img": numpy_to_base64( garm_img),
100
+ "human_img": numpy_to_base64(human_img),
101
+ #"mask_only": mask_only,
102
+ "garment_des": garment_des
103
+ }
104
+ )
105
+ print(output)
106
+ return output
107
+
108
+
109
+ from src.utils import create_zip
110
+ from PIL import Image
111
+
112
+
113
+ def process_images(files,model,context_text,token_string):
114
+ images = []
115
+ textbox =""
116
+ for file in files:
117
+ print(file)
118
+ image = Image.open(file)
119
+ if model=="None":
120
+ caption="[Insert cap here]"
121
+ else:
122
+ caption = replicate_caption_api(image,model,context_text)
123
+ textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
124
+ images.append(image)
125
+ #texts.append(textbox)
126
+ zip_path=create_zip(files,textbox,token_string)
127
+
128
+ return images, textbox,zip_path
129
+
130
+ def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
131
+ try:
132
+ model = replicate.models.create(
133
+ owner=owner,
134
+ name=name,
135
+ visibility=visibility,
136
+ hardware=hardware,
137
+ )
138
+ print(model)
139
+ return True
140
+ except Exception as e:
141
+ print(e)
142
+ if "A model with that name and owner already exists" in str(e):
143
+ return True
144
+ return False
145
+
146
+
147
+
148
+ def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
149
+ ##Place holder for now
150
+ BB_bucket_name="jarvisdataset"
151
+ BB_defult="https://f005.backblazeb2.com/file/"
152
+ if BB_defult not in zip_path:
153
+ zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
154
+ print(zip_path)
155
+ training_logs = f"Using zip traning file at: {zip_path}\n"
156
+ yield training_logs, None
157
+ input={
158
+ "steps": max_train_steps,
159
+ "lora_rank": 16,
160
+ "batch_size": 1,
161
+ "autocaption": True,
162
+ "trigger_word": token_string,
163
+ "learning_rate": 0.0004,
164
+ "seed": seed,
165
+ "input_images": zip_path
166
+ }
167
+ print(training_destination)
168
+ username,model_name=training_destination.split("/")
169
+ assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
170
+
171
+ print(input)
172
+ try:
173
+ training = replicate.trainings.create(
174
+ destination=training_destination,
175
+ version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
176
+ input=input,
177
+ )
178
+
179
+ training_logs = f"Training started with model: {training_model}\n"
180
+ training_logs += f"Destination: {training_destination}\n"
181
+ training_logs += f"Seed: {seed}\n"
182
+ training_logs += f"Token string: {token_string}\n"
183
+ training_logs += f"Max train steps: {max_train_steps}\n"
184
+
185
+ # Poll the training status
186
+ while training.status != "succeeded":
187
+ training.reload()
188
+ training_logs += f"Training status: {training.status}\n"
189
+ training_logs += f"{training.logs}\n"
190
+ if training.status == "failed":
191
+ training_logs += "Training failed!\n"
192
+ return training_logs, training
193
+
194
+ yield training_logs, None
195
+ time.sleep(10) # Wait for 10 seconds before checking again
196
+
197
+ training_logs += "Training completed!\n"
198
+ if hf_repo_id and hf_token:
199
+ training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
200
+ # Here you would implement the logic to upload to Hugging Face
201
+
202
+ traning_finnal=training.output
203
+
204
+ # In a real scenario, you might want to download and display some result images
205
+ # For now, we'll just return the original images
206
+ #images = [Image.open(file) for file in files]
207
+ _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
208
+ traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
209
+ yield training_logs, traning_finnal
210
+
211
+ except Exception as e:
212
+ yield f"An error occurred: {str(e)}", None