JarvisLabs commited on
Commit
c331d2f
1 Parent(s): 16786bc

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/rep_api.py +32 -10
  2. src/utils.py +23 -16
src/rep_api.py CHANGED
@@ -1,13 +1,11 @@
1
  import replicate
2
  import os
3
- from src.utils import image_to_base64
4
- from src.utils import BB_uploadfile
5
  import json
6
  import time
7
  style_json="model_dict.json"
8
  model_dict=json.load(open(style_json,"r"))
9
- from dotenv import load_dotenv, find_dotenv
10
- _ = load_dotenv(find_dotenv())
11
 
12
 
13
  def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model="dev",lora_scale=1,num_outputs=1,guidance_scale=3.5,seed=None):
@@ -19,8 +17,7 @@ def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,mod
19
 
20
  if lora_model is not None:
21
  api_path=model_dict[lora_model]
22
-
23
-
24
  inputs={
25
  "model": model,
26
  "prompt": prompt,
@@ -91,26 +88,46 @@ from src.utils import create_zip
91
  from PIL import Image
92
 
93
 
94
- def process_images(files,model,context_text):
95
  images = []
96
  textbox =""
97
  for file in files:
98
  print(file)
99
  image = Image.open(file)
100
- caption = replicate_caption_api(image,model,context_text)
 
 
 
101
  textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
102
  images.append(image)
103
  #texts.append(textbox)
104
- zip_path=create_zip(files,textbox,"TOK")
105
 
106
  return images, textbox,zip_path
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
109
  ##Place holder for now
110
  BB_bucket_name="jarvisdataset"
111
  BB_defult="https://f005.backblazeb2.com/file/"
112
  if BB_defult not in zip_path:
113
-
114
  zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
115
  print(zip_path)
116
  training_logs = f"Using zip traning file at: {zip_path}\n"
@@ -126,6 +143,9 @@ def traning_function(zip_path,training_model,training_destination,seed,token_str
126
  "input_images": zip_path
127
  }
128
  print(training_destination)
 
 
 
129
  print(input)
130
  try:
131
  training = replicate.trainings.create(
@@ -162,6 +182,8 @@ def traning_function(zip_path,training_model,training_destination,seed,token_str
162
  # In a real scenario, you might want to download and display some result images
163
  # For now, we'll just return the original images
164
  #images = [Image.open(file) for file in files]
 
 
165
  yield training_logs, traning_finnal
166
 
167
  except Exception as e:
 
1
  import replicate
2
  import os
3
+ from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile
 
4
  import json
5
  import time
6
  style_json="model_dict.json"
7
  model_dict=json.load(open(style_json,"r"))
8
+
 
9
 
10
 
11
  def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model="dev",lora_scale=1,num_outputs=1,guidance_scale=3.5,seed=None):
 
17
 
18
  if lora_model is not None:
19
  api_path=model_dict[lora_model]
20
+
 
21
  inputs={
22
  "model": model,
23
  "prompt": prompt,
 
88
  from PIL import Image
89
 
90
 
91
+ def process_images(files,model,context_text,token_string):
92
  images = []
93
  textbox =""
94
  for file in files:
95
  print(file)
96
  image = Image.open(file)
97
+ if model=="None":
98
+ caption="[Insert cap here]"
99
+ else:
100
+ caption = replicate_caption_api(image,model,context_text)
101
  textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
102
  images.append(image)
103
  #texts.append(textbox)
104
+ zip_path=create_zip(files,textbox,token_string)
105
 
106
  return images, textbox,zip_path
107
 
108
+ def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
109
+ try:
110
+ model = replicate.models.create(
111
+ owner=owner,
112
+ name=name,
113
+ visibility=visibility,
114
+ hardware=hardware,
115
+ )
116
+ print(model)
117
+ return True
118
+ except Exception as e:
119
+ print(e)
120
+ if "A model with that name and owner already exists" in str(e):
121
+ return True
122
+ return False
123
+
124
+
125
+
126
  def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
127
  ##Place holder for now
128
  BB_bucket_name="jarvisdataset"
129
  BB_defult="https://f005.backblazeb2.com/file/"
130
  if BB_defult not in zip_path:
 
131
  zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
132
  print(zip_path)
133
  training_logs = f"Using zip traning file at: {zip_path}\n"
 
143
  "input_images": zip_path
144
  }
145
  print(training_destination)
146
+ username,model_name=training_destination.split("/")
147
+ assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
148
+
149
  print(input)
150
  try:
151
  training = replicate.trainings.create(
 
182
  # In a real scenario, you might want to download and display some result images
183
  # For now, we'll just return the original images
184
  #images = [Image.open(file) for file in files]
185
+ _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
186
+ traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
187
  yield training_logs, traning_finnal
188
 
189
  except Exception as e:
src/utils.py CHANGED
@@ -36,7 +36,7 @@ def add_to_prompt(existing_prompt, new_prompt):
36
  def image_to_base64(img):
37
  buffered = io.BytesIO()
38
  img.save(buffered, format="PNG")
39
- img_str = base64.b64encode(buffered.getvalue()).decode('utexf-8')
40
  return "data:image/png;base64,"+img_str
41
 
42
  def create_zip(files,captions,trigger):
@@ -44,26 +44,33 @@ def create_zip(files,captions,trigger):
44
  captions=captions.split("\n")
45
  #cute files and "tags:"
46
  captions= [cap.split("file:")[0][5:] for cap in captions]
47
- with tempfile.TemporaryDirectory() as temp_dir:
48
- # Create a zip file
49
- zip_path = os.path.join(temp_dir, "training_data.zip")
50
- with zipfile.ZipFile(zip_path, "w") as zip_file:
51
- for i, file in enumerate(files):
52
- # Add image to zip
53
- image_name = f"image_{i}.jpg"
54
- zip_file.write(file, image_name)
55
- # Add caption to zip
56
- caption_name = f"image_{i}.txt"
57
- caption_content = captions[i] +f", {trigger}"
58
- zip_file.writestr(caption_name, caption_content)
59
- return zip_path
60
-
61
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- def BB_uploadfile(local_file,file_name,BB_bucket_name,FRIENDLY_URL=True,application_key_id = os.getenv("BB_KeyID"),application_key = os.getenv("BB_AppKey"),):
64
  info = b2.InMemoryAccountInfo()
65
  b2_api = b2.B2Api(info)
66
  #print(application_key_id,application_key)
 
 
67
  b2_api.authorize_account("production", application_key_id, application_key)
68
  BB_bucket=b2_api.get_bucket_by_name(BB_bucket_name)
69
  BB_defurl="https://f005.backblazeb2.com/file/"
 
36
  def image_to_base64(img):
37
  buffered = io.BytesIO()
38
  img.save(buffered, format="PNG")
39
+ img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
40
  return "data:image/png;base64,"+img_str
41
 
42
  def create_zip(files,captions,trigger):
 
44
  captions=captions.split("\n")
45
  #cute files and "tags:"
46
  captions= [cap.split("file:")[0][5:] for cap in captions]
47
+ print("files",len(files),"captions",len(captions))
48
+ #assert len(files)==len(captions) , "File amount does not equal the captions amount please check"
49
+ temp_dir="./datasets/"
50
+ os.makedirs(temp_dir,exist_ok=True)
51
+
52
+ zip_path = os.path.join(temp_dir, f"training_data_{trigger}.zip")
53
+ if os.path.exists(zip_path):
54
+ os.remove(zip_path)
 
 
 
 
 
 
55
 
56
+ with zipfile.ZipFile(zip_path, "w") as zip_file:
57
+ for i, file in enumerate(files):
58
+ # Add image to zip
59
+ image_name = f"image_{i}.jpg"
60
+ print(file)
61
+ zip_file.write(file, image_name)
62
+ # Add caption to zip
63
+ caption_name = f"image_{i}.txt"
64
+ caption_content = captions[i] +f", {trigger}"
65
+ zip_file.writestr(caption_name, caption_content)
66
+ return zip_path
67
 
68
+ def BB_uploadfile(local_file,file_name,BB_bucket_name,FRIENDLY_URL=True):
69
  info = b2.InMemoryAccountInfo()
70
  b2_api = b2.B2Api(info)
71
  #print(application_key_id,application_key)
72
+ application_key_id = os.getenv("BB_KeyID")
73
+ application_key = os.getenv("BB_AppKey")
74
  b2_api.authorize_account("production", application_key_id, application_key)
75
  BB_bucket=b2_api.get_bucket_by_name(BB_bucket_name)
76
  BB_defurl="https://f005.backblazeb2.com/file/"