feat: Choosable CLI, Custom Output Shard Size, LORA extraction

#30
Files changed (2)
  1. app.py +164 -30
  2. requirements.txt +2 -0
app.py CHANGED
@@ -11,10 +11,13 @@ import gradio as gr
 import huggingface_hub
 import torch
 import yaml
+import bitsandbytes
 from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
 from mergekit.config import MergeConfiguration
 
 from clean_community_org import garbage_collect_empty_models
+from apscheduler.schedulers.background import BackgroundScheduler
+from datetime import datetime, timezone
 
 has_gpu = torch.cuda.is_available()
 
@@ -42,8 +45,8 @@ has_gpu = torch.cuda.is_available()
 # )
 # )
 
-cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
-    " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --out-shard-size 1B --lazy-unpickle"
+cli = "config.yaml merge --copy-tokenizer" + (
+    " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --lazy-unpickle"
 )
 
 MARKDOWN_DESCRIPTION = """
@@ -111,17 +114,19 @@ examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
 COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
 
 
-def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
+def merge(program: str, yaml_config: str, out_shard_size: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
     runner = LogsViewRunner()
 
     if not yaml_config:
         yield runner.log("Empty yaml, pick an example below", level="ERROR")
         return
-    try:
-        merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
-    except Exception as e:
-        yield runner.log(f"Invalid yaml {e}", level="ERROR")
-        return
+    # TODO: validate moe config and mega config?
+    if program not in ("mergekit-moe", "mergekit-mega"):
+        try:
+            merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
+        except Exception as e:
+            yield runner.log(f"Invalid yaml {e}", level="ERROR")
+            return
 
     is_community_model = False
     if not hf_token:
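Note: per the TODO above, `MergeConfiguration.model_validate` only understands classic `mergekit-yaml` configs, which is why validation is skipped for `mergekit-moe` and `mergekit-mega`. A minimal sketch of why the guard is needed; the MoE YAML below is an illustrative example of that config family (field names follow mergekit's MoE docs, not this PR) and is expected to fail validation:

```python
import yaml
from mergekit.config import MergeConfiguration

# Illustrative mergekit-moe style config (not from this PR).
moe_yaml = """
base_model: some-org/Base-7B
gate_mode: hidden
experts:
  - source_model: some-user/Expert-A-7B
    positive_prompts: ["math"]
  - source_model: some-user/Expert-B-7B
    positive_prompts: ["code"]
"""

# MergeConfiguration models mergekit-yaml configs, so an MoE config like this
# should raise a validation error -- hence the `program not in (...)` guard.
try:
    MergeConfiguration.model_validate(yaml.safe_load(moe_yaml))
except Exception as e:
    print(f"Invalid yaml {e}")
```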
@@ -170,7 +175,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
         # Set tmp HF_HOME to avoid filling up disk Space
         tmp_env = os.environ.copy()  # taken from https://stackoverflow.com/a/4453495
         tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
-        full_cli = cli + f" --lora-merge-cache {tmpdirname}/.lora_cache"
+        full_cli = f"{program} {cli} --lora-merge-cache {tmpdirname}/.lora_cache --out-shard-size {out_shard_size}"
         yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
 
         if runner.exit_code != 0:
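Together with the `cli` change earlier in this file, the executable name, the shard size, and the LoRA cache are now all appended at call time. A sketch of the string that ends up being run on a GPU machine (the dropdown values and temp directory name are illustrative):

```python
has_gpu = True                 # illustrative; the app uses torch.cuda.is_available()
program = "mergekit-yaml"      # from the new "Mergekit Command" dropdown
out_shard_size = "1B"          # from the new "Output Shard Size" dropdown
tmpdirname = "/tmp/tmpabc123"  # illustrative TemporaryDirectory path

cli = "config.yaml merge --copy-tokenizer" + (
    " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --lazy-unpickle"
)
full_cli = f"{program} {cli} --lora-merge-cache {tmpdirname}/.lora_cache --out-shard-size {out_shard_size}"
print(full_cli)
# mergekit-yaml config.yaml merge --copy-tokenizer --cuda --low-cpu-memory --allow-crimes
#   --lora-merge-cache /tmp/tmpabc123/.lora_cache --out-shard-size 1B  (one line)
```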
@@ -187,27 +192,158 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
     yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
 
 
-with gr.Blocks() as demo:
-    gr.Markdown(MARKDOWN_DESCRIPTION)
-
-    with gr.Row():
-        filename = gr.Textbox(visible=False, label="filename")
-        config = gr.Code(language="yaml", lines=10, label="config.yaml")
-        with gr.Column():
-            token = gr.Textbox(
-                lines=1,
-                label="HF Write Token",
-                info="https://hf.co/settings/token",
-                type="password",
-                placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
-            )
-            repo_name = gr.Textbox(
-                lines=1,
-                label="Repo name",
-                placeholder="Optional. Will create a random name if empty.",
-            )
-    button = gr.Button("Merge", variant="primary")
-    logs = LogsView(label="Terminal output")
+def extract(finetuned_model: str, base_model: str, rank: int, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
+    runner = LogsViewRunner()
+    if not finetuned_model or not base_model:
+        yield runner.log("All fields should be filled")
+
+    is_community_model = False
+    if not hf_token:
+        if "/" in repo_name and not repo_name.startswith("mergekit-community/"):
+            yield runner.log(
+                f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
+                level="ERROR",
+            )
+            return
+        yield runner.log(
+            "No HF token provided. Your lora will be uploaded to the https://huggingface.co/mergekit-community organization."
+        )
+        is_community_model = True
+        if not COMMUNITY_HF_TOKEN:
+            raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
+        hf_token = COMMUNITY_HF_TOKEN
+
+    api = huggingface_hub.HfApi(token=hf_token)
+
+    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+        tmpdir = pathlib.Path(tmpdirname)
+        merged_path = tmpdir / "merged"
+        merged_path.mkdir(parents=True, exist_ok=True)
+
+        if not repo_name:
+            yield runner.log("No repo name provided. Generating a random one.")
+            repo_name = "lora"
+        # Make repo_name "unique" (no need to be extra careful on uniqueness)
+        repo_name += "-" + "".join(random.choices(string.ascii_lowercase, k=7))
+        repo_name = repo_name.replace("/", "-").strip("-")
+
+        if is_community_model and not repo_name.startswith("mergekit-community/"):
+            repo_name = f"mergekit-community/{repo_name}"
+
+        try:
+            yield runner.log(f"Creating repo {repo_name}")
+            repo_url = api.create_repo(repo_name, exist_ok=True)
+            yield runner.log(f"Repo created: {repo_url}")
+        except Exception as e:
+            yield runner.log(f"Error creating repo {e}", level="ERROR")
+            return
+
+        # Set tmp HF_HOME to avoid filling up disk Space
+        tmp_env = os.environ.copy()  # taken from https://stackoverflow.com/a/4453495
+        tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
+        full_cli = f"mergekit-extract-lora {finetuned_model} {base_model} lora --rank={rank}"
+        yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
+
+        if runner.exit_code != 0:
+            yield runner.log("Lora extraction failed. Deleting repo as no lora is uploaded.", level="ERROR")
+            api.delete_repo(repo_url.repo_id)
+            return
+
+        yield runner.log("Lora extracted successfully. Uploading to HF.")
+        yield from runner.run_python(
+            api.upload_folder,
+            repo_id=repo_url.repo_id,
+            folder_path=merged_path / "lora",
+        )
+        yield runner.log(f"Lora successfully uploaded to HF: {repo_url.repo_id}")
+
+
+# This is a workaround, as the Space keeps getting stuck.
+def _restart_space():
+    huggingface_hub.HfApi().restart_space(repo_id="arcee-ai/mergekit-gui", token=COMMUNITY_HF_TOKEN, factory_reboot=False)
+# Run garbage collection every hour to keep the community org clean.
+# Empty models might exist if the merge fails abruptly (e.g. if user leaves the Space).
+def _garbage_remover():
+    try:
+        garbage_collect_empty_models(token=COMMUNITY_HF_TOKEN)
+    except Exception as e:
+        print("Error running garbage collection", e)
+
+scheduler = BackgroundScheduler()
+restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=21600)
+garbage_remover_job = scheduler.add_job(_garbage_remover, "interval", seconds=3600)
+scheduler.start()
+next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
+
+NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC)"
+
+with gr.Blocks() as demo:
+    gr.Markdown(MARKDOWN_DESCRIPTION)
+    gr.Markdown(NEXT_RESTART)
+
+    with gr.Tabs():
+        with gr.TabItem("Merge Model"):
+            with gr.Row():
+                filename = gr.Textbox(visible=False, label="filename")
+                config = gr.Code(language="yaml", lines=10, label="config.yaml")
+                with gr.Column():
+                    program = gr.Dropdown(
+                        ["mergekit-yaml", "mergekit-mega", "mergekit-moe"],
+                        label="Mergekit Command",
+                        info="Choose CLI",
+                    )
+                    out_shard_size = gr.Dropdown(
+                        ["500M", "1B", "2B", "3B", "4B", "5B"],
+                        label="Output Shard Size",
+                        value="500M",
+                    )
+                    token = gr.Textbox(
+                        lines=1,
+                        label="HF Write Token",
+                        info="https://hf.co/settings/token",
+                        type="password",
+                        placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
+                    )
+                    repo_name = gr.Textbox(
+                        lines=1,
+                        label="Repo name",
+                        placeholder="Optional. Will create a random name if empty.",
+                    )
+            button = gr.Button("Merge", variant="primary")
+            logs = LogsView(label="Terminal output")
+            button.click(fn=merge, inputs=[program, config, out_shard_size, token, repo_name], outputs=[logs])
+
+        with gr.TabItem("LORA Extraction"):
+            with gr.Row():
+                with gr.Column():
+                    finetuned_model = gr.Textbox(
+                        lines=1,
+                        label="Finetuned Model",
+                    )
+                    base_model = gr.Textbox(
+                        lines=1,
+                        label="Base Model",
+                    )
+                    rank = gr.Dropdown(
+                        [32, 64, 128],
+                        label="Rank level",
+                        value=32,
+                    )
+                with gr.Column():
+                    token = gr.Textbox(
+                        lines=1,
+                        label="HF Write Token",
+                        info="https://hf.co/settings/token",
+                        type="password",
+                        placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
+                    )
+                    repo_name = gr.Textbox(
+                        lines=1,
+                        label="Repo name",
+                        placeholder="Optional. Will create a random name if empty.",
+                    )
+            button = gr.Button("Extract LORA", variant="primary")
+            logs = LogsView(label="Terminal output")
+            button.click(fn=extract, inputs=[finetuned_model, base_model, rank, token, repo_name], outputs=[logs])
     gr.Examples(
         examples,
         fn=lambda s: (s,),
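The new `extract` handler shells out to `mergekit-extract-lora`, writing the adapter into a `lora` directory under its working directory (`merged_path`), which is why the upload step targets `merged_path / "lora"`. A sketch of the command it builds (the model IDs are illustrative placeholders):

```python
finetuned_model = "some-user/My-Finetuned-7B"  # illustrative repo id
base_model = "some-org/Base-7B"                # illustrative repo id
rank = 32                                      # one of 32/64/128 from the dropdown

full_cli = f"mergekit-extract-lora {finetuned_model} {base_model} lora --rank={rank}"
print(full_cli)
# mergekit-extract-lora some-user/My-Finetuned-7B some-org/Base-7B lora --rank=32
```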
@@ -218,8 +354,6 @@ with gr.Blocks() as demo:
     )
     gr.Markdown(MARKDOWN_ARTICLE)
 
-    button.click(fn=merge, inputs=[config, token, repo_name], outputs=[logs])
-
 
 # Run garbage collection every hour to keep the community org clean.
 # Empty models might exist if the merge fails abruptly (e.g. if user leaves the Space).
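For reference, the two background jobs added in app.py run on fixed intervals: the Space restart every 21600 s (6 h) and garbage collection every 3600 s (1 h). A condensed, standalone sketch of the APScheduler pattern used, with stub job bodies in place of `_restart_space` and `_garbage_remover`:

```python
from datetime import timezone

from apscheduler.schedulers.background import BackgroundScheduler

scheduler = BackgroundScheduler()
# 21600 seconds = 6 hours between restarts; 3600 seconds = 1 hour between GC runs.
restart_job = scheduler.add_job(lambda: print("restart"), "interval", seconds=21600)
scheduler.add_job(lambda: print("gc"), "interval", seconds=3600)
scheduler.start()

# next_run_time is timezone-aware; the app renders it in UTC for the banner.
print(restart_job.next_run_time.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
```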
requirements.txt CHANGED
@@ -1,4 +1,6 @@
+apscheduler
 torch
+bitsandbytes
 git+https://github.com/arcee-ai/mergekit.git
 # see https://huggingface.co/spaces/Wauplin/gradio_logsview
 gradio_logsview@https://huggingface.co/spaces/Wauplin/gradio_logsview/resolve/main/gradio_logsview-0.0.5-py3-none-any.whl