JarvisLabs committed on
Commit 8276f79
1 Parent(s): f3299a1

Update src/rep_api.py

Files changed (1)
  1. src/rep_api.py +210 -190
src/rep_api.py CHANGED
@@ -1,190 +1,210 @@
-import replicate
-import os
-from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile
-import json
-import time
-style_json="model_dict.json"
-model_dict=json.load(open(style_json,"r"))
-
-
-
-def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model="dev",lora_scale=1,num_outputs=1,guidance_scale=3.5,seed=None):
-    print(prompt,lora_model,api_path,aspect_ratio)
-    #if model=="dev":
-    num_inference_steps=30
-    if model=="schnell":
-        num_inference_steps=5
-
-    if lora_model is not None:
-        api_path=model_dict[lora_model]
-
-    inputs={
-        "model": model,
-        "prompt": prompt,
-        "lora_scale":lora_scale,
-        "aspect_ratio": aspect_ratio,
-        "num_outputs":num_outputs,
-        "num_inference_steps":num_inference_steps,
-        "guidance_scale":guidance_scale,
-    }
-    if seed is not None:
-        inputs["seed"]=seed
-    output = replicate.run(
-        api_path,
-        input=inputs
-    )
-    print(output)
-    if gallery is None:
-        gallery=[]
-    gallery.append(output[0])
-    return output[0],gallery
-
-
-def replicate_caption_api(image,model,context_text):
-    base64_image = image_to_base64(image)
-    if model=="blip":
-        output = replicate.run(
-            "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
-            input={
-                "image": base64_image,
-                "caption": True,
-                "question": context_text,
-                "temperature": 1,
-                "use_nucleus_sampling": False
-            }
-        )
-        print(output)
-
-    elif model=="llava-16":
-        output = replicate.run(
-            # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
-            "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
-            input={
-                "image": base64_image,
-                "top_p": 1,
-                "prompt": context_text,
-                "max_tokens": 1024,
-                "temperature": 0.2
-            }
-        )
-        print(output)
-        output = "".join(output)
-
-    elif model=="img2prompt":
-        output = replicate.run(
-            "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
-            input={
-                "image":base64_image
-            }
-        )
-        print(output)
-    return output
-
-def update_replicate_api_key(api_key):
-    os.environ["REPLICATE_API_TOKEN"] = api_key
-    return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
-
-from src.utils import create_zip
-from PIL import Image
-
-
-def process_images(files,model,context_text,token_string):
-    images = []
-    textbox =""
-    for file in files:
-        print(file)
-        image = Image.open(file)
-        if model=="None":
-            caption="[Insert cap here]"
-        else:
-            caption = replicate_caption_api(image,model,context_text)
-        textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
-        images.append(image)
-        #texts.append(textbox)
-    zip_path=create_zip(files,textbox,token_string)
-
-    return images, textbox,zip_path
-
-def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
-    try:
-        model = replicate.models.create(
-            owner=owner,
-            name=name,
-            visibility=visibility,
-            hardware=hardware,
-        )
-        print(model)
-        return True
-    except Exception as e:
-        print(e)
-        if "A model with that name and owner already exists" in str(e):
-            return True
-        return False
-
-
-
-def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
-    ##Place holder for now
-    BB_bucket_name="jarvisdataset"
-    BB_defult="https://f005.backblazeb2.com/file/"
-    if BB_defult not in zip_path:
-        zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
-    print(zip_path)
-    training_logs = f"Using zip traning file at: {zip_path}\n"
-    yield training_logs, None
-    input={
-        "steps": max_train_steps,
-        "lora_rank": 16,
-        "batch_size": 1,
-        "autocaption": True,
-        "trigger_word": token_string,
-        "learning_rate": 0.0004,
-        "seed": seed,
-        "input_images": zip_path
-    }
-    print(training_destination)
-    username,model_name=training_destination.split("/")
-    assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
-
-    print(input)
-    try:
-        training = replicate.trainings.create(
-            destination=training_destination,
-            version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
-            input=input,
-        )
-
-        training_logs = f"Training started with model: {training_model}\n"
-        training_logs += f"Destination: {training_destination}\n"
-        training_logs += f"Seed: {seed}\n"
-        training_logs += f"Token string: {token_string}\n"
-        training_logs += f"Max train steps: {max_train_steps}\n"
-
-        # Poll the training status
-        while training.status != "succeeded":
-            training.reload()
-            training_logs += f"Training status: {training.status}\n"
-            training_logs += f"{training.logs}\n"
-            if training.status == "failed":
-                training_logs += "Training failed!\n"
-                return training_logs, training
-
-            yield training_logs, None
-            time.sleep(10)  # Wait for 10 seconds before checking again
-
-        training_logs += "Training completed!\n"
-        if hf_repo_id and hf_token:
-            training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
-            # Here you would implement the logic to upload to Hugging Face
-
-        traning_finnal=training.output
-
-        # In a real scenario, you might want to download and display some result images
-        # For now, we'll just return the original images
-        #images = [Image.open(file) for file in files]
-        _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
-        traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
-        yield training_logs, traning_finnal
-
-    except Exception as e:
-        yield f"An error occurred: {str(e)}", None
+import replicate
+import os
+from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile,numpy_to_base64
+import json
+import time
+style_json="model_dict.json"
+model_dict=json.load(open(style_json,"r"))
+
+
+
+def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model,lora_scale=1,num_outputs=1,guidance_scale=3.5,seed=None):
+    print(prompt,lora_model,api_path,aspect_ratio)
+    #if model=="dev":
+    num_inference_steps=30
+    if model=="schnell":
+        num_inference_steps=5
+
+    if lora_model is not None:
+        api_path=model_dict[lora_model]
+
+    inputs={
+        "model": model,
+        "prompt": prompt,
+        "lora_scale":lora_scale,
+        "aspect_ratio": aspect_ratio,
+        "num_outputs":num_outputs,
+        "num_inference_steps":num_inference_steps,
+        "guidance_scale":guidance_scale,
+    }
+    if seed is not None:
+        inputs["seed"]=seed
+    output = replicate.run(
+        api_path,
+        input=inputs
+    )
+    print(output)
+    if gallery is None:
+        gallery=[]
+    gallery.append(output[0])
+    return output[0],gallery
+
+
+def replicate_caption_api(image,model,context_text):
+    base64_image = image_to_base64(image)
+    if model=="blip":
+        output = replicate.run(
+            "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
+            input={
+                "image": base64_image,
+                "caption": True,
+                "question": context_text,
+                "temperature": 1,
+                "use_nucleus_sampling": False
+            }
+        )
+        print(output)
+
+    elif model=="llava-16":
+        output = replicate.run(
+            # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
+            "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
+            input={
+                "image": base64_image,
+                "top_p": 1,
+                "prompt": context_text,
+                "max_tokens": 1024,
+                "temperature": 0.2
+            }
+        )
+        print(output)
+        output = "".join(output)
+
+    elif model=="img2prompt":
+        output = replicate.run(
+            "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
+            input={
+                "image":base64_image
+            }
+        )
+        print(output)
+    return output
+
+def update_replicate_api_key(api_key):
+    os.environ["REPLICATE_API_TOKEN"] = api_key
+    return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
+
+
+def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
+    output = replicate.run(
+        "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
+        input={
+            "crop": crop,
+            "seed": seed,
+            "steps": steps,
+            "category": category,
+            # "force_dc": force_dc,
+            "garm_img": numpy_to_base64(garm_img),
+            "human_img": numpy_to_base64(human_img),
+            #"mask_only": mask_only,
+            "garment_des": garment_des
+        }
+    )
+    print(output)
+    return output
+
+
+from src.utils import create_zip
+from PIL import Image
+
+
+def process_images(files,model,context_text,token_string):
+    images = []
+    textbox =""
+    for file in files:
+        print(file)
+        image = Image.open(file)
+        if model=="None":
+            caption="[Insert cap here]"
+        else:
+            caption = replicate_caption_api(image,model,context_text)
+        textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
+        images.append(image)
+        #texts.append(textbox)
+    zip_path=create_zip(files,textbox,token_string)
+
+    return images, textbox,zip_path
+
+def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
+    try:
+        model = replicate.models.create(
+            owner=owner,
+            name=name,
+            visibility=visibility,
+            hardware=hardware,
+        )
+        print(model)
+        return True
+    except Exception as e:
+        print(e)
+        if "A model with that name and owner already exists" in str(e):
+            return True
+        return False
+
+
+
+def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
+    ##Place holder for now
+    BB_bucket_name="jarvisdataset"
+    BB_defult="https://f005.backblazeb2.com/file/"
+    if BB_defult not in zip_path:
+        zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
+    print(zip_path)
+    training_logs = f"Using zip traning file at: {zip_path}\n"
+    yield training_logs, None
+    input={
+        "steps": max_train_steps,
+        "lora_rank": 16,
+        "batch_size": 1,
+        "autocaption": True,
+        "trigger_word": token_string,
+        "learning_rate": 0.0004,
+        "seed": seed,
+        "input_images": zip_path
+    }
+    print(training_destination)
+    username,model_name=training_destination.split("/")
+    assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
+
+    print(input)
+    try:
+        training = replicate.trainings.create(
+            destination=training_destination,
+            version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
+            input=input,
+        )
+
+        training_logs = f"Training started with model: {training_model}\n"
+        training_logs += f"Destination: {training_destination}\n"
+        training_logs += f"Seed: {seed}\n"
+        training_logs += f"Token string: {token_string}\n"
+        training_logs += f"Max train steps: {max_train_steps}\n"
+
+        # Poll the training status
+        while training.status != "succeeded":
+            training.reload()
+            training_logs += f"Training status: {training.status}\n"
+            training_logs += f"{training.logs}\n"
+            if training.status == "failed":
+                training_logs += "Training failed!\n"
+                return training_logs, training
+
+            yield training_logs, None
+            time.sleep(10)  # Wait for 10 seconds before checking again
+
+        training_logs += "Training completed!\n"
+        if hf_repo_id and hf_token:
+            training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
+            # Here you would implement the logic to upload to Hugging Face
+
+        traning_finnal=training.output
+
+        # In a real scenario, you might want to download and display some result images
+        # For now, we'll just return the original images
+        #images = [Image.open(file) for file in files]
+        _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
+        traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
+        yield training_logs, traning_finnal
+
+    except Exception as e:
+        yield f"An error occurred: {str(e)}", None
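The main addition in this commit is the virtual_try_on helper, which wraps the cuuupid/idm-vton model on Replicate. Below is a minimal usage sketch, not part of the commit: the file names, the crop flag value, and the "upper_body" category are assumptions about typical idm-vton inputs, and it presumes REPLICATE_API_TOKEN is already set (e.g. via update_replicate_api_key).

# sketch_virtual_try_on.py -- illustrative only, not from the repository
import numpy as np
from PIL import Image

from src.rep_api import virtual_try_on

# Placeholder file names: a garment photo and a photo of the person wearing it.
garm_img = np.array(Image.open("garment.jpg"))
human_img = np.array(Image.open("person.jpg"))

result = virtual_try_on(
    crop=False,              # assumed boolean flag, passed straight through to the model
    seed=42,
    steps=30,
    category="upper_body",   # assumed category label accepted by idm-vton
    garm_img=garm_img,       # numpy arrays; the helper converts them with numpy_to_base64
    human_img=human_img,
    garment_des="short-sleeved cotton t-shirt",
)
print(result)  # replicate.run output for the try-on image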
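The other behavioral change is that generate_image_replicate no longer defaults model to "dev", so callers must pass it explicitly. A hedged call sketch under assumed values; the "my_style" key is a placeholder that would have to exist in model_dict.json:

# sketch_generate.py -- illustrative only, not from the repository
from src.rep_api import generate_image_replicate

image_url, gallery = generate_image_replicate(
    prompt="portrait photo in my_style",
    lora_model="my_style",   # placeholder key looked up in model_dict.json
    api_path=None,           # overwritten by the model_dict lookup when lora_model is set
    aspect_ratio="1:1",
    gallery=None,
    model="dev",             # now required: "dev" uses 30 inference steps, "schnell" uses 5
    seed=1234,
)
print(image_url)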