Spaces:
Running
Running
The ML.ENERGY Colosseum (#22)
Browse filesCo-authored-by: AmberLJC <[email protected]>
- .gitignore +9 -0
- README.md +7 -2
- app.py +288 -15
- Dockerfile → deployment/benchmark.Dockerfile +0 -0
- deployment/controller-container.sh +11 -0
- deployment/controller.Dockerfile +30 -0
- deployment/docker-compose-0.yaml +74 -0
- deployment/docker-compose-1.yaml +40 -0
- docs/colosseum_bottom.md +14 -0
- docs/colosseum_top.md +8 -0
- LEADERBOARD.md → docs/leaderboard.md +2 -2
- requirements.txt +1 -2
- setup.py +21 -0
- spitfight/__init__.py +0 -0
- spitfight/colosseum/__init__.py +0 -0
- spitfight/colosseum/client.py +106 -0
- spitfight/colosseum/common.py +35 -0
- spitfight/colosseum/controller/__init__.py +0 -0
- spitfight/colosseum/controller/controller.py +266 -0
- spitfight/colosseum/controller/router.py +125 -0
- spitfight/colosseum/controller/worker.py +151 -0
- spitfight/log.py +76 -0
- spitfight/prompt.py +69 -0
- spitfight/utils.py +305 -0
.gitignore
CHANGED
@@ -7,3 +7,12 @@
|
|
7 |
# Editor
|
8 |
pyrightconfig.json
|
9 |
.idea
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
# Editor
|
8 |
pyrightconfig.json
|
9 |
.idea
|
10 |
+
|
11 |
+
# Python
|
12 |
+
*.egg-info
|
13 |
+
**/__pycache__
|
14 |
+
build/
|
15 |
+
|
16 |
+
# Data files
|
17 |
+
*.log
|
18 |
+
pegasus/consumed.yaml
|
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: "⚡"
|
|
4 |
python_version: "3.9"
|
5 |
app_file: "app.py"
|
6 |
sdk: "gradio"
|
7 |
-
sdk_version: "3.
|
8 |
pinned: true
|
9 |
tags: ["energy", "leaderboard"]
|
10 |
colorFrom: "black"
|
@@ -22,7 +22,12 @@ How much energy do LLMs consume?
|
|
22 |
This README focuses on explaining how to run the benchmark yourself.
|
23 |
The actual leaderboard is here: https://ml.energy/leaderboard.
|
24 |
|
25 |
-
##
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
### Model weights
|
28 |
|
|
|
4 |
python_version: "3.9"
|
5 |
app_file: "app.py"
|
6 |
sdk: "gradio"
|
7 |
+
sdk_version: "3.39.0"
|
8 |
pinned: true
|
9 |
tags: ["energy", "leaderboard"]
|
10 |
colorFrom: "black"
|
|
|
22 |
This README focuses on explaining how to run the benchmark yourself.
|
23 |
The actual leaderboard is here: https://ml.energy/leaderboard.
|
24 |
|
25 |
+
## Colosseum
|
26 |
+
|
27 |
+
We instrumented [Hugging Face TGI](https://github.com/huggingface/text-generation-inference) so that it measures and returns GPU energy consumption.
|
28 |
+
Then, our [controller](/spitfight/colosseum/controller) server receives user prompts from the [Gradio app](/app.py), selects two models randomly, and streams model responses back with energy consumption.
|
29 |
+
|
30 |
+
## Setup for benchmarking
|
31 |
|
32 |
### Model weights
|
33 |
|
app.py
CHANGED
@@ -5,6 +5,9 @@ import yaml
|
|
5 |
import requests
|
6 |
import itertools
|
7 |
import contextlib
|
|
|
|
|
|
|
8 |
from dateutil import parser, tz
|
9 |
|
10 |
import numpy as np
|
@@ -13,9 +16,10 @@ import pandas as pd
|
|
13 |
import plotly.io as pio
|
14 |
import plotly.express as px
|
15 |
from pandas.api.types import is_numeric_dtype, is_float_dtype
|
16 |
-
|
17 |
pio.templates.default = "plotly_white"
|
18 |
|
|
|
|
|
19 |
|
20 |
class TableManager:
|
21 |
def __init__(self, data_dir: str) -> None:
|
@@ -215,7 +219,6 @@ class TableManager:
|
|
215 |
|
216 |
return fig, width, height, ""
|
217 |
|
218 |
-
|
219 |
# The global instance of the TableManager should only be used when
|
220 |
# initializing components in the Gradio interface. If the global instance
|
221 |
# is mutated while handling user sessions, the change will be reflected
|
@@ -280,7 +283,7 @@ function format_model_link() {{
|
|
280 |
"""
|
281 |
|
282 |
# Custom CSS.
|
283 |
-
|
284 |
/* Make ML.ENERGY look like a clickable logo. */
|
285 |
.text-logo {
|
286 |
color: #23d175 !important;
|
@@ -311,6 +314,14 @@ table th:first-child {
|
|
311 |
.tab-nav > button {
|
312 |
font-size: 18px !important;
|
313 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
"""
|
315 |
|
316 |
intro_text = """
|
@@ -324,13 +335,262 @@ including the ARC Challenge (reasoning), HellaSwag (common sense), and TruthfulQ
|
|
324 |
Every benchmark is limited in some sense -- Before you interpret the results, please take a look at the <b>Limitations</b> section there, too.</p>
|
325 |
"""
|
326 |
|
327 |
-
|
328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
tbm = gr.State(global_tbm) # type: ignore
|
330 |
with gr.Box():
|
331 |
gr.HTML("<h1><a href='https://ml.energy' class='text-logo'>ML.ENERGY</a> Leaderboard</h1>")
|
332 |
|
333 |
with gr.Tabs():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
# Tab: Leaderboard.
|
335 |
with gr.Tab("Leaderboard"):
|
336 |
with gr.Box():
|
@@ -340,7 +600,7 @@ with block:
|
|
340 |
with gr.Row():
|
341 |
with gr.Box():
|
342 |
gr.Markdown("### Benchmark results to show")
|
343 |
-
checkboxes = []
|
344 |
for key, choices in global_tbm.schema.items():
|
345 |
# Specifying `value` makes everything checked by default.
|
346 |
checkboxes.append(gr.CheckboxGroup(choices=choices, value=choices[:1], label=key))
|
@@ -349,10 +609,10 @@ with block:
|
|
349 |
with gr.Row():
|
350 |
dataframe = gr.Dataframe(type="pandas", elem_id="tab-leaderboard")
|
351 |
# Make sure the models have clickable links.
|
352 |
-
dataframe.change(None, None, None, _js=dataframe_update_js)
|
353 |
# Table automatically updates when users check or uncheck any checkbox.
|
354 |
for checkbox in checkboxes:
|
355 |
-
checkbox.change(TableManager.set_filter_get_df, inputs=[tbm, *checkboxes], outputs=dataframe)
|
356 |
|
357 |
# Block: Allow users to add new columns.
|
358 |
with gr.Box():
|
@@ -381,21 +641,25 @@ with block:
|
|
381 |
TableManager.add_column,
|
382 |
inputs=[tbm, colname_input, formula_input],
|
383 |
outputs=[dataframe, add_col_message],
|
|
|
384 |
)
|
385 |
formula_input.submit(
|
386 |
TableManager.add_column,
|
387 |
inputs=[tbm, colname_input, formula_input],
|
388 |
outputs=[dataframe, add_col_message],
|
|
|
389 |
)
|
390 |
add_col_btn.click(
|
391 |
TableManager.add_column,
|
392 |
inputs=[tbm, colname_input, formula_input],
|
393 |
outputs=[dataframe, add_col_message],
|
|
|
394 |
)
|
395 |
clear_input_btn.click(
|
396 |
lambda: (None, None, None),
|
397 |
inputs=None,
|
398 |
outputs=[colname_input, formula_input, add_col_message],
|
|
|
399 |
)
|
400 |
|
401 |
# Block: Allow users to plot 2D and 3D scatter plots.
|
@@ -425,42 +689,51 @@ with block:
|
|
425 |
)[0]) # type: ignore
|
426 |
with gr.Row():
|
427 |
plot_message = gr.HTML("")
|
428 |
-
add_col_btn.click(TableManager.update_dropdown, inputs=tbm, outputs=axis_dropdowns) # type: ignore
|
429 |
plot_width_input.submit(
|
430 |
TableManager.plot_scatter,
|
431 |
inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
|
432 |
outputs=[plot, plot_width_input, plot_height_input, plot_message],
|
|
|
433 |
)
|
434 |
plot_height_input.submit(
|
435 |
TableManager.plot_scatter,
|
436 |
inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
|
437 |
outputs=[plot, plot_width_input, plot_height_input, plot_message],
|
|
|
438 |
)
|
439 |
plot_btn.click(
|
440 |
TableManager.plot_scatter,
|
441 |
inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
|
442 |
outputs=[plot, plot_width_input, plot_height_input, plot_message],
|
|
|
443 |
)
|
444 |
clear_plot_btn.click(
|
445 |
lambda: (None,) * 7,
|
446 |
None,
|
447 |
outputs=[*axis_dropdowns, plot, plot_width_input, plot_height_input, plot_message],
|
|
|
448 |
)
|
449 |
|
450 |
# Block: Leaderboard date.
|
451 |
with gr.Row():
|
452 |
gr.HTML(f"<h3 style='color: gray'>Last updated: {current_date}</h3>")
|
453 |
|
454 |
-
# Tab: Online demo.
|
455 |
-
with gr.Tab("Online demo (Coming in August!)"):
|
456 |
-
gr.Markdown("# Online demo with real time energy measurements\n\nComing soon in August!")
|
457 |
-
|
458 |
# Tab: About page.
|
459 |
with gr.Tab("About"):
|
460 |
# Read in LEADERBOARD.md
|
461 |
-
gr.Markdown(open("
|
462 |
|
463 |
# Load the table on page load.
|
464 |
block.load(lambda: global_tbm.set_filter_get_df(), outputs=dataframe)
|
465 |
|
466 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import requests
|
6 |
import itertools
|
7 |
import contextlib
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
from typing import Literal
|
11 |
from dateutil import parser, tz
|
12 |
|
13 |
import numpy as np
|
|
|
16 |
import plotly.io as pio
|
17 |
import plotly.express as px
|
18 |
from pandas.api.types import is_numeric_dtype, is_float_dtype
|
|
|
19 |
pio.templates.default = "plotly_white"
|
20 |
|
21 |
+
from spitfight.colosseum.client import ControllerClient
|
22 |
+
|
23 |
|
24 |
class TableManager:
|
25 |
def __init__(self, data_dir: str) -> None:
|
|
|
219 |
|
220 |
return fig, width, height, ""
|
221 |
|
|
|
222 |
# The global instance of the TableManager should only be used when
|
223 |
# initializing components in the Gradio interface. If the global instance
|
224 |
# is mutated while handling user sessions, the change will be reflected
|
|
|
283 |
"""
|
284 |
|
285 |
# Custom CSS.
|
286 |
+
custom_css = """
|
287 |
/* Make ML.ENERGY look like a clickable logo. */
|
288 |
.text-logo {
|
289 |
color: #23d175 !important;
|
|
|
314 |
.tab-nav > button {
|
315 |
font-size: 18px !important;
|
316 |
}
|
317 |
+
|
318 |
+
/* Color texts. */
|
319 |
+
.green-text {
|
320 |
+
color: #23d175 !important;
|
321 |
+
}
|
322 |
+
.red-text {
|
323 |
+
color: #ff3860 !important;
|
324 |
+
}
|
325 |
"""
|
326 |
|
327 |
intro_text = """
|
|
|
335 |
Every benchmark is limited in some sense -- Before you interpret the results, please take a look at the <b>Limitations</b> section there, too.</p>
|
336 |
"""
|
337 |
|
338 |
+
# The app will not start without a controller address set.
|
339 |
+
controller_addr = os.environ["COLOSSEUM_CONTROLLER_ADDR"]
|
340 |
+
global_controller_client = ControllerClient(controller_addr=controller_addr, timeout=15)
|
341 |
+
|
342 |
+
ANONYMOUS_MODEL_TEXT = "## Anonymous 🤫"
|
343 |
+
|
344 |
+
# Colosseum helper functions.
|
345 |
+
def enable_interact():
|
346 |
+
return [gr.update(interactive=True)] * 2
|
347 |
+
|
348 |
+
def disable_interact():
|
349 |
+
return [gr.update(interactive=False)] * 2
|
350 |
+
|
351 |
+
def consumed_less_energy_message(energy_a, energy_b):
|
352 |
+
"""Return a message that indicates that the user chose the model that consumed less energy.
|
353 |
+
|
354 |
+
By default report in "%f %" but if the difference is larger than 2 times, report in "%f X".
|
355 |
+
"""
|
356 |
+
less_energy = min(energy_a, energy_b)
|
357 |
+
more_energy = max(energy_a, energy_b)
|
358 |
+
factor = less_energy / more_energy
|
359 |
+
if factor <= 0.5:
|
360 |
+
message = f"<h2>That response also <span class='green-text'>consumed {1/factor:.1f}X less energy</span>!</h2>"
|
361 |
+
else:
|
362 |
+
message = f"<h2>That response also <span class='green-text'>consumed {100 - factor * 100:.1f}% less energy</span>!</h2>"
|
363 |
+
return message
|
364 |
+
|
365 |
+
def consumed_more_energy_message(energy_a, energy_b):
|
366 |
+
"""Return a message that indicates that the user chose the model that consumed more energy.
|
367 |
+
|
368 |
+
By default report in "%f %" but if the difference is larger than 2 times, report in "%f X".
|
369 |
+
"""
|
370 |
+
less_energy = min(energy_a, energy_b)
|
371 |
+
more_energy = max(energy_a, energy_b)
|
372 |
+
factor = more_energy / less_energy
|
373 |
+
if factor >= 2.0:
|
374 |
+
message = f"<h2>That response <span class='red-text'>consumed {factor:.1f}x more energy</span>.</h2>"
|
375 |
+
else:
|
376 |
+
message = f"<h2>That response <span class='red-text'>consumed {factor * 100 - 100:.1f}% more energy</span>.</h2>"
|
377 |
+
return message
|
378 |
+
|
379 |
+
# Colosseum event handlers
|
380 |
+
def add_prompt_disable_submit(prompt, history_a, history_b):
|
381 |
+
"""Add the user's prompt to the two model's history and disable the submit button."""
|
382 |
+
client = global_controller_client.fork()
|
383 |
+
return [
|
384 |
+
gr.Textbox.update(value=" ", interactive=False),
|
385 |
+
gr.Button.update(interactive=False),
|
386 |
+
history_a + [[prompt, ""]],
|
387 |
+
history_b + [[prompt, ""]],
|
388 |
+
client,
|
389 |
+
]
|
390 |
+
|
391 |
+
def generate_responses(client: ControllerClient, history_a, history_b):
|
392 |
+
"""Generate responses for the two models."""
|
393 |
+
for resp_a, resp_b in itertools.zip_longest(
|
394 |
+
client.prompt(prompt=history_a[-1][0], index=0),
|
395 |
+
client.prompt(prompt=history_b[-1][0], index=1),
|
396 |
+
):
|
397 |
+
if resp_a is not None:
|
398 |
+
history_a[-1][1] += resp_a
|
399 |
+
if resp_b is not None:
|
400 |
+
history_b[-1][1] += resp_b
|
401 |
+
yield [history_a, history_b]
|
402 |
+
|
403 |
+
def make_resp_vote_func(victory_index: Literal[0, 1]):
|
404 |
+
"""Return a function that will be called when the user clicks on response preference vote buttons."""
|
405 |
+
def resp_vote_func(client: ControllerClient):
|
406 |
+
vote_response = client.response_vote(victory_index=victory_index)
|
407 |
+
model_name_a, model_name_b = map(lambda n: f"## {n}", vote_response.model_names)
|
408 |
+
energy_a, energy_b = vote_response.energy_consumptions
|
409 |
+
# User liked the model that also consumed less energy.
|
410 |
+
if (victory_index == 0 and energy_a <= energy_b) or (victory_index == 1 and energy_a >= energy_b):
|
411 |
+
energy_message = consumed_less_energy_message(energy_a, energy_b)
|
412 |
+
return [
|
413 |
+
# Disable response vote buttons
|
414 |
+
gr.Button.update(interactive=False), gr.Button.update(interactive=False),
|
415 |
+
# Reveal model names
|
416 |
+
gr.Markdown.update(model_name_a), gr.Markdown.update(model_name_b),
|
417 |
+
# Display energy consumption comparison message
|
418 |
+
gr.Markdown.update(energy_message, visible=True),
|
419 |
+
# Keep energy vote buttons hidden
|
420 |
+
gr.Button.update(visible=False, interactive=False), gr.Button.update(visible=False, interactive=False),
|
421 |
+
# Enable reset button
|
422 |
+
gr.Button.update(visible=True, interactive=True),
|
423 |
+
]
|
424 |
+
# User liked the model that consumed more energy.
|
425 |
+
else:
|
426 |
+
energy_message = consumed_more_energy_message(energy_a, energy_b)
|
427 |
+
return [
|
428 |
+
# Disable response vote buttons
|
429 |
+
gr.Button.update(interactive=False), gr.Button.update(interactive=False),
|
430 |
+
# Leave model names hidden
|
431 |
+
gr.Markdown.update(ANONYMOUS_MODEL_TEXT), gr.Markdown.update(ANONYMOUS_MODEL_TEXT),
|
432 |
+
# Display energy consumption comparison message
|
433 |
+
gr.Markdown.update(energy_message, visible=True),
|
434 |
+
# Reveal and enable energy vote buttons
|
435 |
+
gr.Button.update(visible=True, interactive=True), gr.Button.update(visible=True, interactive=True),
|
436 |
+
# Keep the reset button disabled
|
437 |
+
gr.Button.update(visible=False, interactive=False),
|
438 |
+
]
|
439 |
+
return resp_vote_func
|
440 |
+
|
441 |
+
def make_energy_vote_func(is_worth: bool):
|
442 |
+
"""Return a function that will be called when the user clicks on energy vote buttons."""
|
443 |
+
def energy_vote_func(client: ControllerClient, energy_message: str):
|
444 |
+
vote_response = client.energy_vote(is_worth=is_worth)
|
445 |
+
model_name_a, model_name_b = map(lambda n: f"## {n}", vote_response.model_names)
|
446 |
+
return [
|
447 |
+
# Reveal model names
|
448 |
+
gr.Markdown.update(model_name_a), gr.Markdown.update(model_name_b),
|
449 |
+
# Disable energy vote buttons
|
450 |
+
gr.Button.update(interactive=False), gr.Button.update(interactive=False),
|
451 |
+
# Enable reset button
|
452 |
+
gr.Button.update(interactive=True, visible=True),
|
453 |
+
# Append to the energy comparison message
|
454 |
+
energy_message[:-5] + (" Fair enough.</h2>" if is_worth else " Wasn't worth it.</h2>"),
|
455 |
+
]
|
456 |
+
return energy_vote_func
|
457 |
+
|
458 |
+
def play_again():
|
459 |
+
return [
|
460 |
+
# Clear chatbot history
|
461 |
+
None, None,
|
462 |
+
# Turn on prompt textbox and submit button
|
463 |
+
gr.Textbox.update(value="", interactive=True), gr.Button.update(interactive=True),
|
464 |
+
# Mask model names
|
465 |
+
gr.Markdown.update(ANONYMOUS_MODEL_TEXT),
|
466 |
+
gr.Markdown.update(ANONYMOUS_MODEL_TEXT),
|
467 |
+
# Hide energy vote buttons and message
|
468 |
+
gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Markdown.update(visible=False),
|
469 |
+
# Disable reset button
|
470 |
+
gr.Button.update(interactive=False, visible=False),
|
471 |
+
]
|
472 |
+
|
473 |
+
focus_prompt_input_js = """
|
474 |
+
function() {
|
475 |
+
for (let textarea of document.getElementsByTagName("textarea")) {
|
476 |
+
if (textarea.hasAttribute("autofocus")) {
|
477 |
+
textarea.focus();
|
478 |
+
return;
|
479 |
+
}
|
480 |
+
}
|
481 |
+
}
|
482 |
+
"""
|
483 |
+
|
484 |
+
with gr.Blocks(css=custom_css) as block:
|
485 |
tbm = gr.State(global_tbm) # type: ignore
|
486 |
with gr.Box():
|
487 |
gr.HTML("<h1><a href='https://ml.energy' class='text-logo'>ML.ENERGY</a> Leaderboard</h1>")
|
488 |
|
489 |
with gr.Tabs():
|
490 |
+
# Tab: Colosseum.
|
491 |
+
with gr.TabItem("Colosseum ⚔️️"):
|
492 |
+
gr.Markdown(open("docs/colosseum_top.md").read())
|
493 |
+
|
494 |
+
with gr.Group():
|
495 |
+
with gr.Row():
|
496 |
+
prompt_input = gr.Textbox(
|
497 |
+
show_label=False,
|
498 |
+
placeholder="Type your prompt and press ENTER",
|
499 |
+
autofocus=True,
|
500 |
+
container=False,
|
501 |
+
scale=20,
|
502 |
+
elem_id="prompt-textarea",
|
503 |
+
)
|
504 |
+
prompt_submit_btn = gr.Button(
|
505 |
+
value="⚔️️ Fight!",
|
506 |
+
elem_classes=["btn-submit"],
|
507 |
+
min_width=60,
|
508 |
+
scale=1,
|
509 |
+
)
|
510 |
+
|
511 |
+
with gr.Row():
|
512 |
+
masked_model_names = []
|
513 |
+
chatbots = []
|
514 |
+
resp_vote_btn_list: list[gr.component.Component] = []
|
515 |
+
with gr.Column():
|
516 |
+
with gr.Row():
|
517 |
+
masked_model_names.append(gr.Markdown(ANONYMOUS_MODEL_TEXT))
|
518 |
+
with gr.Row():
|
519 |
+
chatbots.append(gr.Chatbot(label="Model A", elem_id="chatbot", height=600))
|
520 |
+
with gr.Row():
|
521 |
+
left_resp_vote_btn = gr.Button(value="👈 Model A is better", interactive=False)
|
522 |
+
resp_vote_btn_list.append(left_resp_vote_btn)
|
523 |
+
|
524 |
+
with gr.Column():
|
525 |
+
with gr.Row():
|
526 |
+
masked_model_names.append(gr.Markdown(ANONYMOUS_MODEL_TEXT))
|
527 |
+
with gr.Row():
|
528 |
+
chatbots.append(gr.Chatbot(label="Model B", elem_id="chatbot", height=600))
|
529 |
+
with gr.Row():
|
530 |
+
right_resp_vote_btn = gr.Button(value="👉 Model B is better", interactive=False)
|
531 |
+
resp_vote_btn_list.append(right_resp_vote_btn)
|
532 |
+
|
533 |
+
with gr.Row():
|
534 |
+
energy_comparison_message = gr.HTML(visible=False)
|
535 |
+
|
536 |
+
with gr.Row():
|
537 |
+
worth_energy_vote_btn = gr.Button(value="The better response was worth the extra energy.", visible=False)
|
538 |
+
notworth_energy_vote_btn = gr.Button(value="Not really worth it.", visible=False)
|
539 |
+
energy_vote_btn_list: list[gr.component.Component] = [worth_energy_vote_btn, notworth_energy_vote_btn]
|
540 |
+
|
541 |
+
with gr.Row():
|
542 |
+
play_again_btn = gr.Button("Play again!", visible=False)
|
543 |
+
|
544 |
+
gr.Markdown(open("docs/colosseum_bottom.md").read())
|
545 |
+
|
546 |
+
controller_client = gr.State()
|
547 |
+
|
548 |
+
|
549 |
+
(prompt_input
|
550 |
+
.submit(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, *chatbots, controller_client], queue=False)
|
551 |
+
.then(generate_responses, [controller_client, *chatbots], [*chatbots], queue=True)
|
552 |
+
.then(enable_interact, None, resp_vote_btn_list, queue=False))
|
553 |
+
(prompt_submit_btn
|
554 |
+
.click(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, *chatbots, controller_client], queue=False)
|
555 |
+
.then(generate_responses, [controller_client, *chatbots], [*chatbots], queue=True)
|
556 |
+
.then(enable_interact, None, resp_vote_btn_list, queue=False))
|
557 |
+
|
558 |
+
left_resp_vote_btn.click(
|
559 |
+
make_resp_vote_func(victory_index=0),
|
560 |
+
[controller_client],
|
561 |
+
[*resp_vote_btn_list, *masked_model_names, energy_comparison_message, *energy_vote_btn_list, play_again_btn],
|
562 |
+
queue=False,
|
563 |
+
)
|
564 |
+
right_resp_vote_btn.click(
|
565 |
+
make_resp_vote_func(victory_index=1),
|
566 |
+
[controller_client],
|
567 |
+
[*resp_vote_btn_list, *masked_model_names, energy_comparison_message, *energy_vote_btn_list, play_again_btn],
|
568 |
+
queue=False,
|
569 |
+
)
|
570 |
+
|
571 |
+
worth_energy_vote_btn.click(
|
572 |
+
make_energy_vote_func(is_worth=True),
|
573 |
+
[controller_client, energy_comparison_message],
|
574 |
+
[*masked_model_names, *energy_vote_btn_list, play_again_btn, energy_comparison_message],
|
575 |
+
queue=False,
|
576 |
+
)
|
577 |
+
notworth_energy_vote_btn.click(
|
578 |
+
make_energy_vote_func(is_worth=False),
|
579 |
+
[controller_client, energy_comparison_message],
|
580 |
+
[*masked_model_names, *energy_vote_btn_list, play_again_btn, energy_comparison_message],
|
581 |
+
queue=False,
|
582 |
+
)
|
583 |
+
|
584 |
+
(play_again_btn
|
585 |
+
.click(
|
586 |
+
play_again,
|
587 |
+
None,
|
588 |
+
[*chatbots, prompt_input, prompt_submit_btn, *masked_model_names, *energy_vote_btn_list, energy_comparison_message, play_again_btn],
|
589 |
+
queue=False,
|
590 |
+
)
|
591 |
+
.then(None, _js=focus_prompt_input_js, queue=False))
|
592 |
+
|
593 |
+
|
594 |
# Tab: Leaderboard.
|
595 |
with gr.Tab("Leaderboard"):
|
596 |
with gr.Box():
|
|
|
600 |
with gr.Row():
|
601 |
with gr.Box():
|
602 |
gr.Markdown("### Benchmark results to show")
|
603 |
+
checkboxes: list[gr.CheckboxGroup] = []
|
604 |
for key, choices in global_tbm.schema.items():
|
605 |
# Specifying `value` makes everything checked by default.
|
606 |
checkboxes.append(gr.CheckboxGroup(choices=choices, value=choices[:1], label=key))
|
|
|
609 |
with gr.Row():
|
610 |
dataframe = gr.Dataframe(type="pandas", elem_id="tab-leaderboard")
|
611 |
# Make sure the models have clickable links.
|
612 |
+
dataframe.change(None, None, None, _js=dataframe_update_js, queue=False)
|
613 |
# Table automatically updates when users check or uncheck any checkbox.
|
614 |
for checkbox in checkboxes:
|
615 |
+
checkbox.change(TableManager.set_filter_get_df, inputs=[tbm, *checkboxes], outputs=dataframe, queue=False)
|
616 |
|
617 |
# Block: Allow users to add new columns.
|
618 |
with gr.Box():
|
|
|
641 |
TableManager.add_column,
|
642 |
inputs=[tbm, colname_input, formula_input],
|
643 |
outputs=[dataframe, add_col_message],
|
644 |
+
queue=False,
|
645 |
)
|
646 |
formula_input.submit(
|
647 |
TableManager.add_column,
|
648 |
inputs=[tbm, colname_input, formula_input],
|
649 |
outputs=[dataframe, add_col_message],
|
650 |
+
queue=False,
|
651 |
)
|
652 |
add_col_btn.click(
|
653 |
TableManager.add_column,
|
654 |
inputs=[tbm, colname_input, formula_input],
|
655 |
outputs=[dataframe, add_col_message],
|
656 |
+
queue=False,
|
657 |
)
|
658 |
clear_input_btn.click(
|
659 |
lambda: (None, None, None),
|
660 |
inputs=None,
|
661 |
outputs=[colname_input, formula_input, add_col_message],
|
662 |
+
queue=False,
|
663 |
)
|
664 |
|
665 |
# Block: Allow users to plot 2D and 3D scatter plots.
|
|
|
689 |
)[0]) # type: ignore
|
690 |
with gr.Row():
|
691 |
plot_message = gr.HTML("")
|
692 |
+
add_col_btn.click(TableManager.update_dropdown, inputs=tbm, outputs=axis_dropdowns, queue=False) # type: ignore
|
693 |
plot_width_input.submit(
|
694 |
TableManager.plot_scatter,
|
695 |
inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
|
696 |
outputs=[plot, plot_width_input, plot_height_input, plot_message],
|
697 |
+
queue=False,
|
698 |
)
|
699 |
plot_height_input.submit(
|
700 |
TableManager.plot_scatter,
|
701 |
inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
|
702 |
outputs=[plot, plot_width_input, plot_height_input, plot_message],
|
703 |
+
queue=False,
|
704 |
)
|
705 |
plot_btn.click(
|
706 |
TableManager.plot_scatter,
|
707 |
inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
|
708 |
outputs=[plot, plot_width_input, plot_height_input, plot_message],
|
709 |
+
queue=False,
|
710 |
)
|
711 |
clear_plot_btn.click(
|
712 |
lambda: (None,) * 7,
|
713 |
None,
|
714 |
outputs=[*axis_dropdowns, plot, plot_width_input, plot_height_input, plot_message],
|
715 |
+
queue=False,
|
716 |
)
|
717 |
|
718 |
# Block: Leaderboard date.
|
719 |
with gr.Row():
|
720 |
gr.HTML(f"<h3 style='color: gray'>Last updated: {current_date}</h3>")
|
721 |
|
|
|
|
|
|
|
|
|
722 |
# Tab: About page.
|
723 |
with gr.Tab("About"):
|
724 |
# Read in LEADERBOARD.md
|
725 |
+
gr.Markdown(open("docs/leaderboard.md").read())
|
726 |
|
727 |
# Load the table on page load.
|
728 |
block.load(lambda: global_tbm.set_filter_get_df(), outputs=dataframe)
|
729 |
|
730 |
+
|
731 |
+
if __name__ == "__main__":
|
732 |
+
parser = argparse.ArgumentParser()
|
733 |
+
parser.add_argument("--share", action="store_true", help="Specify if sharing is enabled")
|
734 |
+
parser.add_argument("--concurrency", type=int, default=10)
|
735 |
+
|
736 |
+
args = parser.parse_args()
|
737 |
+
block.queue(
|
738 |
+
concurrency_count=args.concurrency, status_update_rate=10, api_open=False
|
739 |
+
).launch(share=args.share, show_error=True)
|
Dockerfile → deployment/benchmark.Dockerfile
RENAMED
File without changes
|
deployment/controller-container.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
docker run \
|
4 |
+
--name controller \
|
5 |
+
--net leaderboard \
|
6 |
+
-v $HOME/workspace/leaderboard:/workspace/leaderboard \
|
7 |
+
-v $HOME/workspace/text-generation-inference/deployment:/workspace/text-generation-inference/deployment:ro \
|
8 |
+
-v /data/leaderboard/colosseum-controller-logs:/logs \
|
9 |
+
-p 7778:8000 \
|
10 |
+
-e LOG_DIR=/logs \
|
11 |
+
mlenergy/colosseum-controller:latest
|
deployment/controller.Dockerfile
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM ubuntu:22.04
|
2 |
+
|
3 |
+
# Basic installs
|
4 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
5 |
+
ENV TZ='America/Detroit'
|
6 |
+
RUN apt-get update -qq \
|
7 |
+
&& apt-get -y --no-install-recommends install \
|
8 |
+
tzdata software-properties-common wget git \
|
9 |
+
&& apt-get clean all \
|
10 |
+
&& rm -r /var/lib/apt/lists/* \
|
11 |
+
&& ln -fs /usr/share/zoneinfo/America/Detroit /etc/localtime \
|
12 |
+
&& dpkg-reconfigure -f noninteractive tzdata
|
13 |
+
|
14 |
+
# Install Miniconda3 23.3.1
|
15 |
+
ENV PATH="/root/.local/miniconda3/bin:$PATH"
|
16 |
+
RUN mkdir -p /root/.local \
|
17 |
+
&& wget https://repo.anaconda.com/miniconda/Miniconda3-py39_23.3.1-0-Linux-x86_64.sh \
|
18 |
+
&& mkdir /root/.conda \
|
19 |
+
&& bash Miniconda3-py39_23.3.1-0-Linux-x86_64.sh -b -p /root/.local/miniconda3 \
|
20 |
+
&& rm -f Miniconda3-py39_23.3.1-0-Linux-x86_64.sh \
|
21 |
+
&& ln -sf /root/.local/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh
|
22 |
+
|
23 |
+
# Install spitfight
|
24 |
+
ADD . /workspace/leaderboard
|
25 |
+
RUN cd /workspace/leaderboard \
|
26 |
+
&& pip install -e .[colosseum-controller]
|
27 |
+
|
28 |
+
WORKDIR /workspace/leaderboard
|
29 |
+
|
30 |
+
CMD ["python", "spitfight/colosseum/controller/router.py"]
|
deployment/docker-compose-0.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
Falcon-7B:
|
3 |
+
container_name: worker0
|
4 |
+
image: mlenergy/tgi:latest
|
5 |
+
command: ["--model-id", "tiiuae/falcon-7b-instruct", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
|
6 |
+
shm_size: 1g
|
7 |
+
networks:
|
8 |
+
- leaderboard
|
9 |
+
volumes:
|
10 |
+
- /data/leaderboard/tgi-data:/data
|
11 |
+
deploy:
|
12 |
+
resources:
|
13 |
+
reservations:
|
14 |
+
devices:
|
15 |
+
- driver: nvidia
|
16 |
+
device_ids: ["0"]
|
17 |
+
capabilities: [gpu]
|
18 |
+
Llama2-7B:
|
19 |
+
container_name: worker1
|
20 |
+
image: mlenergy/tgi:latest
|
21 |
+
command: ["--model-id", "/weights/metaai/Llama-2-7b-chat-hf", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
|
22 |
+
shm_size: 1g
|
23 |
+
networks:
|
24 |
+
- leaderboard
|
25 |
+
volumes:
|
26 |
+
- /data/leaderboard/tgi-data:/data
|
27 |
+
- /data/leaderboard/weights:/weights
|
28 |
+
deploy:
|
29 |
+
resources:
|
30 |
+
reservations:
|
31 |
+
devices:
|
32 |
+
- driver: nvidia
|
33 |
+
device_ids: ["1"]
|
34 |
+
capabilities: [gpu]
|
35 |
+
FastChat-T5-3B:
|
36 |
+
container_name: worker2
|
37 |
+
image: mlenergy/tgi:latest
|
38 |
+
command: ["--model-id", "lmsys/fastchat-t5-3b-v1.0", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
|
39 |
+
environment:
|
40 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python
|
41 |
+
shm_size: 1g
|
42 |
+
networks:
|
43 |
+
- leaderboard
|
44 |
+
volumes:
|
45 |
+
- /data/leaderboard/tgi-data:/data
|
46 |
+
deploy:
|
47 |
+
resources:
|
48 |
+
reservations:
|
49 |
+
devices:
|
50 |
+
- driver: nvidia
|
51 |
+
device_ids: ["2"]
|
52 |
+
capabilities: [gpu]
|
53 |
+
Llama2-13B:
|
54 |
+
container_name: worker3
|
55 |
+
image: mlenergy/tgi:latest
|
56 |
+
command: ["--model-id", "/weights/metaai/Llama-2-13b-chat-hf", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
|
57 |
+
shm_size: 1g
|
58 |
+
networks:
|
59 |
+
- leaderboard
|
60 |
+
volumes:
|
61 |
+
- /data/leaderboard/tgi-data:/data
|
62 |
+
- /data/leaderboard/weights:/weights
|
63 |
+
deploy:
|
64 |
+
resources:
|
65 |
+
reservations:
|
66 |
+
devices:
|
67 |
+
- driver: nvidia
|
68 |
+
device_ids: ["3"]
|
69 |
+
capabilities: [gpu]
|
70 |
+
|
71 |
+
networks:
|
72 |
+
leaderboard:
|
73 |
+
name: leaderboard
|
74 |
+
external: true
|
deployment/docker-compose-1.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
Llama2-70B-INT8:
|
3 |
+
container_name: worker4
|
4 |
+
image: mlenergy/tgi:latest
|
5 |
+
command: ["--model-id", "meta-llama/Llama-2-70b-chat-hf", "--num-shard", "2", "--otlp-endpoint", "http://jaeger:4317", "--quantize", "bitsandbytes"]
|
6 |
+
shm_size: 1g
|
7 |
+
environment:
|
8 |
+
HUGGING_FACE_HUB_TOKEN: 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
|
9 |
+
networks:
|
10 |
+
- leaderboard
|
11 |
+
volumes:
|
12 |
+
- /data/leaderboard/tgi-data:/data
|
13 |
+
deploy:
|
14 |
+
resources:
|
15 |
+
reservations:
|
16 |
+
devices:
|
17 |
+
- driver: nvidia
|
18 |
+
device_ids: ["0", "1"]
|
19 |
+
capabilities: [gpu]
|
20 |
+
Falcon-40B:
|
21 |
+
container_name: worker5
|
22 |
+
image: mlenergy/tgi:latest
|
23 |
+
command: ["--model-id", "tiiuae/falcon-40b-instruct", "--num-shard", "2", "--otlp-endpoint", "http://jaeger:4317"]
|
24 |
+
shm_size: 1g
|
25 |
+
networks:
|
26 |
+
- leaderboard
|
27 |
+
volumes:
|
28 |
+
- /data/leaderboard/tgi-data:/data
|
29 |
+
deploy:
|
30 |
+
resources:
|
31 |
+
reservations:
|
32 |
+
devices:
|
33 |
+
- driver: nvidia
|
34 |
+
device_ids: ["2", "3"]
|
35 |
+
capabilities: [gpu]
|
36 |
+
|
37 |
+
networks:
|
38 |
+
leaderboard:
|
39 |
+
name: leaderboard
|
40 |
+
external: true
|
docs/colosseum_bottom.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Terms of use
|
2 |
+
|
3 |
+
By using our service, you agree to these Terms of Use and accept that the Service provides an approximate estimation of model inference energy usage for research purposes only. We are not liable for any damages or loss incurred by you or any third party arising from the use of the Service. It may generate offensive content and offers limited safety measures, thus should not be used for any illegal, harmful, violent, racist, or sexual purposes. The service collects user dialogue data and voting results. We reserve the right to distribute the dataset in the future.
|
4 |
+
|
5 |
+
### Technical details
|
6 |
+
|
7 |
+
- We allow models to generate only up to 512 new tokens. Due to this, some responses may be cut off in the middle.
|
8 |
+
- Tokens are sampled from the model output with `temperature` 1.0, `repetition_penalty` 1.0, `top_k` 50, and `top_p` 0.95.
|
9 |
+
- Large models (>= 30B) run on two NVIDIA A40 GPUs with tensor parallelism, whereas other models run on one NVIDIA A40 GPU. We directly measure the energy consumption of these GPUs.
|
10 |
+
|
11 |
+
### Contact
|
12 |
+
|
13 |
+
Please direct general questions and issues related to the Colosseum to our GitHub repository's [discussion board](https://github.com/ml-energy/leaderboard/discussions).
|
14 |
+
You can find the ML.ENERGY initiative members in [our homepage](https://ml.energy#members).
|
docs/colosseum_top.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> Enter the ML.ENERGY Colosseum, where language models duel with intellect, and your judgment tips the scales of victory.
|
2 |
+
|
3 |
+
### Rules of the Colosseum
|
4 |
+
|
5 |
+
- As the spectator, you'll decide the fates of two anonymous language models -- our gladiators.
|
6 |
+
- Your role is twofold: First, you vote for the model that delivered the best response to your prompt.
|
7 |
+
- Next, mighty [Zeus](https://ml.energy/zeus) will reveal which language model consumed more energy. Evaluate if its performance justified the energy consumption.
|
8 |
+
- Only after you cast votes will the models' identities be unveiled.
|
LEADERBOARD.md → docs/leaderboard.md
RENAMED
@@ -3,7 +3,7 @@ The goal of the ML.ENERGY Leaderboard is to give people a sense of how much **en
|
|
3 |
The code for the leaderboard, backing data, and scripts for benchmarking are all open-source in our [repository](https://github.com/ml-energy/leaderboard).
|
4 |
We'll see you at the [Discussion board](https://github.com/ml-energy/leaderboard/discussions), where you can ask questions, suggest improvement ideas, or just discuss leaderboard results!
|
5 |
|
6 |
-
## Columns
|
7 |
|
8 |
- `gpu`: NVIDIA GPU model name.
|
9 |
- `task`: Name of the task. See *Tasks* below for details.
|
@@ -113,7 +113,7 @@ By doing this, we can provide numbers for reasonable comparison without being ti
|
|
113 |
|
114 |
This leaderboard is a research preview intended for non-commercial use only.
|
115 |
Model weights were taken as is from the Hugging Face Hub if available and are subject to their licenses.
|
116 |
-
The use of
|
117 |
Please direct inquiries/reports of potential violation to Jae-Won Chung.
|
118 |
|
119 |
## Acknowledgements
|
|
|
3 |
The code for the leaderboard, backing data, and scripts for benchmarking are all open-source in our [repository](https://github.com/ml-energy/leaderboard).
|
4 |
We'll see you at the [Discussion board](https://github.com/ml-energy/leaderboard/discussions), where you can ask questions, suggest improvement ideas, or just discuss leaderboard results!
|
5 |
|
6 |
+
## Leaderboard Columns
|
7 |
|
8 |
- `gpu`: NVIDIA GPU model name.
|
9 |
- `task`: Name of the task. See *Tasks* below for details.
|
|
|
113 |
|
114 |
This leaderboard is a research preview intended for non-commercial use only.
|
115 |
Model weights were taken as is from the Hugging Face Hub if available and are subject to their licenses.
|
116 |
+
The use of Llama weights are subject to their [license](https://github.com/facebookresearch/llama/blob/main/LICENSE).
|
117 |
Please direct inquiries/reports of potential violation to Jae-Won Chung.
|
118 |
|
119 |
## Acknowledgements
|
requirements.txt
CHANGED
@@ -1,2 +1 @@
|
|
1 |
-
|
2 |
-
gradio==3.35.2
|
|
|
1 |
+
.[app]
|
|
setup.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
extras_require = {
|
4 |
+
"colosseum-controller": [
|
5 |
+
"fastapi",
|
6 |
+
"fschat==0.2.23",
|
7 |
+
"text_generation @ git+https://github.com/ml-energy/text_generation_energy@master",
|
8 |
+
],
|
9 |
+
"app": ["plotly==5.15.0", "gradio==3.39.0", "pydantic==1.10.9"],
|
10 |
+
"benchmark": ["zeus-ml", "fschat==0.2.23", "tyro", "rich"],
|
11 |
+
}
|
12 |
+
|
13 |
+
extras_require["all"] = list(set(sum(extras_require.values(), [])))
|
14 |
+
|
15 |
+
setup(
|
16 |
+
name="spitfight",
|
17 |
+
version="0.0.1",
|
18 |
+
url="https://github.com/ml-energy/leaderboard",
|
19 |
+
packages=find_packages("."),
|
20 |
+
extras_require=extras_require,
|
21 |
+
)
|
spitfight/__init__.py
ADDED
File without changes
|
spitfight/colosseum/__init__.py
ADDED
File without changes
|
spitfight/colosseum/client.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import unittest
|
5 |
+
import contextlib
|
6 |
+
from uuid import uuid4, UUID
|
7 |
+
from copy import deepcopy
|
8 |
+
from typing import Generator, Literal
|
9 |
+
|
10 |
+
import requests
|
11 |
+
import gradio as gr
|
12 |
+
|
13 |
+
from spitfight.colosseum.common import (
|
14 |
+
COLOSSEUM_PROMPT_ROUTE,
|
15 |
+
COLOSSEUM_RESP_VOTE_ROUTE,
|
16 |
+
COLOSSEUM_ENERGY_VOTE_ROUTE,
|
17 |
+
PromptRequest,
|
18 |
+
ResponseVoteRequest,
|
19 |
+
ResponseVoteResponse,
|
20 |
+
EnergyVoteRequest,
|
21 |
+
EnergyVoteResponse,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class ControllerClient:
|
26 |
+
"""Client for the Colosseum controller, to be used by Gradio."""
|
27 |
+
|
28 |
+
def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
|
29 |
+
"""Initialize the controller client."""
|
30 |
+
self.controller_addr = controller_addr
|
31 |
+
self.timeout = timeout
|
32 |
+
self.request_id = str(request_id) or str(uuid4())
|
33 |
+
|
34 |
+
def fork(self) -> ControllerClient:
|
35 |
+
"""Return a copy of the client with a new request ID."""
|
36 |
+
return ControllerClient(
|
37 |
+
controller_addr=self.controller_addr,
|
38 |
+
timeout=self.timeout,
|
39 |
+
request_id=uuid4(),
|
40 |
+
)
|
41 |
+
|
42 |
+
def prompt(self, prompt: str, index: Literal[0, 1]) -> Generator[str, None, None]:
|
43 |
+
"""Generate the response of the `index`th model with the prompt."""
|
44 |
+
prompt_request = PromptRequest(request_id=self.request_id, prompt=prompt, model_index=index)
|
45 |
+
with _catch_requests_exceptions():
|
46 |
+
resp = requests.post(
|
47 |
+
f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
|
48 |
+
json=prompt_request.dict(),
|
49 |
+
stream=True,
|
50 |
+
timeout=self.timeout,
|
51 |
+
)
|
52 |
+
_check_response(resp)
|
53 |
+
# XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
|
54 |
+
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
55 |
+
if chunk:
|
56 |
+
yield json.loads(chunk.decode("utf-8"))
|
57 |
+
|
58 |
+
def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
|
59 |
+
"""Notify the controller of the user's vote for the response."""
|
60 |
+
response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
|
61 |
+
with _catch_requests_exceptions():
|
62 |
+
resp = requests.post(
|
63 |
+
f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
|
64 |
+
json=response_vote_request.dict(),
|
65 |
+
)
|
66 |
+
_check_response(resp)
|
67 |
+
return ResponseVoteResponse(**resp.json())
|
68 |
+
|
69 |
+
def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
|
70 |
+
"""Notify the controller of the user's vote for energy."""
|
71 |
+
energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
|
72 |
+
with _catch_requests_exceptions():
|
73 |
+
resp = requests.post(
|
74 |
+
f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
|
75 |
+
json=energy_vote_request.dict(),
|
76 |
+
)
|
77 |
+
_check_response(resp)
|
78 |
+
return EnergyVoteResponse(**resp.json())
|
79 |
+
|
80 |
+
|
81 |
+
@contextlib.contextmanager
|
82 |
+
def _catch_requests_exceptions():
|
83 |
+
"""Catch requests exceptions and raise gr.Error instead."""
|
84 |
+
try:
|
85 |
+
yield
|
86 |
+
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
|
87 |
+
raise gr.Error("Failed to connect to our the backend server. Please try again later.")
|
88 |
+
|
89 |
+
|
90 |
+
def _check_response(response: requests.Response) -> None:
|
91 |
+
if 400 <= response.status_code < 500:
|
92 |
+
raise gr.Error(response.json()["detail"])
|
93 |
+
elif response.status_code >= 500:
|
94 |
+
raise gr.Error("Failed to talk to our backend server. Please try again later.")
|
95 |
+
|
96 |
+
|
97 |
+
class TestControllerClient(unittest.TestCase):
|
98 |
+
def test_new_uuid_on_deepcopy(self):
|
99 |
+
client = ControllerClient("http://localhost:8000")
|
100 |
+
clients = [client.fork() for _ in range(50)]
|
101 |
+
request_ids = [client.request_id for client in clients]
|
102 |
+
assert len(set(request_ids)) == len(request_ids)
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == "__main__":
|
106 |
+
unittest.main()
|
spitfight/colosseum/common.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
COLOSSEUM_PROMPT_ROUTE = "/prompt"
|
8 |
+
COLOSSEUM_RESP_VOTE_ROUTE = "/response_vote"
|
9 |
+
COLOSSEUM_ENERGY_VOTE_ROUTE = "/energy_vote"
|
10 |
+
COLOSSEUM_HEALTH_ROUTE = "/health"
|
11 |
+
|
12 |
+
|
13 |
+
class PromptRequest(BaseModel):
|
14 |
+
request_id: str
|
15 |
+
prompt: str
|
16 |
+
model_index: Literal[0, 1]
|
17 |
+
|
18 |
+
|
19 |
+
class ResponseVoteRequest(BaseModel):
|
20 |
+
request_id: str
|
21 |
+
victory_index: Literal[0, 1]
|
22 |
+
|
23 |
+
|
24 |
+
class ResponseVoteResponse(BaseModel):
|
25 |
+
model_names: list[str]
|
26 |
+
energy_consumptions: list[float]
|
27 |
+
|
28 |
+
|
29 |
+
class EnergyVoteRequest(BaseModel):
|
30 |
+
request_id: str
|
31 |
+
is_worth: bool
|
32 |
+
|
33 |
+
|
34 |
+
class EnergyVoteResponse(BaseModel):
|
35 |
+
model_names: list[str]
|
spitfight/colosseum/controller/__init__.py
ADDED
File without changes
|
spitfight/colosseum/controller/controller.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import asyncio
|
5 |
+
from datetime import datetime
|
6 |
+
from typing import AsyncGenerator, Literal, Optional, TYPE_CHECKING
|
7 |
+
|
8 |
+
import aiohttp
|
9 |
+
from pytz import timezone
|
10 |
+
from pydantic import BaseModel, Field
|
11 |
+
|
12 |
+
from spitfight.log import get_logger
|
13 |
+
from spitfight.utils import BoundedExpiringDict, TokenGenerationBuffer, create_task
|
14 |
+
from spitfight.colosseum.controller.worker import WorkerService
|
15 |
+
from spitfight.prompt import get_system_prompt, apply_model_characteristics
|
16 |
+
|
17 |
+
if TYPE_CHECKING:
|
18 |
+
from spitfight.colosseum.controller.router import ControllerConfig
|
19 |
+
|
20 |
+
controller_logger = get_logger(__name__)
|
21 |
+
request_logger = get_logger("colosseum_requests")
|
22 |
+
|
23 |
+
|
24 |
+
def now() -> datetime:
|
25 |
+
return datetime.now(tz=timezone("US/Eastern"))
|
26 |
+
|
27 |
+
|
28 |
+
# Internal states
|
29 |
+
# The two "chose_*" stages are both the result of voting on a response.
|
30 |
+
# A normal user will sequentially go through either
|
31 |
+
# "prompted" -> "chose_less_energy_response", or
|
32 |
+
# "prompted" -> "chose_more_energy_response" -> "voted_energy"
|
33 |
+
UserStage = Literal[
|
34 |
+
"prompted",
|
35 |
+
"chose_less_energy_response",
|
36 |
+
"chose_more_energy_response",
|
37 |
+
"voted_energy",
|
38 |
+
]
|
39 |
+
|
40 |
+
|
41 |
+
class RequestState(BaseModel):
|
42 |
+
"""Models the state of a Colosseum play.
|
43 |
+
|
44 |
+
This model is also serialized as is and logged.
|
45 |
+
"""
|
46 |
+
request_id: str
|
47 |
+
prompt: str
|
48 |
+
model_names: list[str]
|
49 |
+
responses: list[str] = ["EMPTY", "EMPTY"]
|
50 |
+
energy_consumptions: list[float] = [0.0, 0.0]
|
51 |
+
response_victory_index: Optional[Literal[0, 1]] = None
|
52 |
+
extra_energy_was_worth: Optional[bool] = None
|
53 |
+
|
54 |
+
# The time when the user's stage changed.
|
55 |
+
timestamp: datetime = Field(default_factory=now)
|
56 |
+
# The user's current stage.
|
57 |
+
user_stage: UserStage = "prompted"
|
58 |
+
# When the the user is not going through the aforementioned stages,
|
59 |
+
# the user's stage transition is recorded here.
|
60 |
+
abnormal_stage_change: list[tuple[UserStage, UserStage]] = []
|
61 |
+
|
62 |
+
def set_response_and_energy(self, model_index: Literal[0, 1], response: str, energy_consumption: float) -> None:
|
63 |
+
self.timestamp = now()
|
64 |
+
self.energy_consumptions[model_index] = energy_consumption
|
65 |
+
self.responses[model_index] = response
|
66 |
+
|
67 |
+
def set_response_vote(self, victory_index: Literal[0, 1]) -> None:
|
68 |
+
self.timestamp = now()
|
69 |
+
|
70 |
+
# Next stage depends on the user's vote.
|
71 |
+
energy_a, energy_b = self.energy_consumptions
|
72 |
+
if (victory_index == 0 and energy_a <= energy_b) or (victory_index == 1 and energy_a >= energy_b):
|
73 |
+
next_stage = "chose_less_energy_response"
|
74 |
+
else:
|
75 |
+
next_stage = "chose_more_energy_response"
|
76 |
+
|
77 |
+
# Detect abnormal stage change.
|
78 |
+
if self.user_stage != "prompted":
|
79 |
+
self.abnormal_stage_change.append((self.user_stage, next_stage))
|
80 |
+
|
81 |
+
self.user_stage = next_stage
|
82 |
+
self.response_victory_index = victory_index
|
83 |
+
|
84 |
+
def set_energy_vote(self, is_worth: bool) -> None:
|
85 |
+
self.timestamp = now()
|
86 |
+
|
87 |
+
# Detect abnormal stage change.
|
88 |
+
if self.user_stage != "chose_more_energy_response":
|
89 |
+
self.abnormal_stage_change.append((self.user_stage, "voted_energy"))
|
90 |
+
|
91 |
+
self.user_stage = "voted_energy"
|
92 |
+
self.extra_energy_was_worth = is_worth
|
93 |
+
|
94 |
+
|
95 |
+
class GenerationConfig(BaseModel):
|
96 |
+
"""Configuration for generation of prompts."""
|
97 |
+
max_new_tokens: int
|
98 |
+
do_sample: bool
|
99 |
+
temperature: float
|
100 |
+
repetition_penalty: float
|
101 |
+
top_k: int
|
102 |
+
top_p: float
|
103 |
+
|
104 |
+
|
105 |
+
class Controller:
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
background_task_interval: int,
|
109 |
+
max_num_req_states: int,
|
110 |
+
req_state_expiration_time: int,
|
111 |
+
worker_service: WorkerService,
|
112 |
+
generation_config: GenerationConfig,
|
113 |
+
):
|
114 |
+
self.request_states: BoundedExpiringDict[str, RequestState] = \
|
115 |
+
BoundedExpiringDict(max_num_req_states, req_state_expiration_time)
|
116 |
+
self.worker_service = worker_service
|
117 |
+
|
118 |
+
self.generation_config = generation_config
|
119 |
+
|
120 |
+
self.background_task_handle = create_task(
|
121 |
+
self._background_task(background_task_interval),
|
122 |
+
)
|
123 |
+
|
124 |
+
def shutdown(self) -> None:
|
125 |
+
"""Shutdown the controller."""
|
126 |
+
self.background_task_handle.cancel()
|
127 |
+
|
128 |
+
async def _background_task(self, heartbeat_interval: int) -> None:
|
129 |
+
"""Periodically check if dead workers are alive again and do request state GC."""
|
130 |
+
while True:
|
131 |
+
await asyncio.sleep(heartbeat_interval)
|
132 |
+
|
133 |
+
await self.worker_service.check_workers()
|
134 |
+
|
135 |
+
prev_num_req_states = len(self.request_states)
|
136 |
+
self.request_states.cleanup()
|
137 |
+
controller_logger.info(
|
138 |
+
"Request state garbage collection done: Removed %d reqeusts",
|
139 |
+
prev_num_req_states - len(self.request_states),
|
140 |
+
)
|
141 |
+
|
142 |
+
def response_vote(self, request_id: str, victory_index: Literal[0, 1]) -> RequestState | None:
|
143 |
+
"""Record the user's response vote and return the new state."""
|
144 |
+
if (state := self.request_states.get(request_id)) is not None:
|
145 |
+
state.set_response_vote(victory_index)
|
146 |
+
# Pop the state from the dict if the user has voted on energy.
|
147 |
+
if state.user_stage == "chose_less_energy_response":
|
148 |
+
self.request_states.pop(request_id)
|
149 |
+
request_logger.info(state.json())
|
150 |
+
return state
|
151 |
+
return None
|
152 |
+
|
153 |
+
def energy_vote(self, request_id: str, is_worth: bool) -> RequestState | None:
|
154 |
+
"""Record the user's energy vote and return the new state."""
|
155 |
+
# Pop the state from the dict, since this is the last step in any case.
|
156 |
+
if (state := self.request_states.pop(request_id)) is not None:
|
157 |
+
state.set_energy_vote(is_worth)
|
158 |
+
request_logger.info(state.json())
|
159 |
+
return state
|
160 |
+
return None
|
161 |
+
|
162 |
+
async def prompt(
|
163 |
+
self,
|
164 |
+
request_id: str,
|
165 |
+
prompt: str,
|
166 |
+
model_index: Literal[0, 1],
|
167 |
+
) -> AsyncGenerator[bytes, None]:
|
168 |
+
# This method is called twice for the same request, once for each model.
|
169 |
+
# If it's the first time this method is called, assign models to the request.
|
170 |
+
if request_id not in self.request_states:
|
171 |
+
workers = self.worker_service.choose_two()
|
172 |
+
model_names = [worker.model_name for worker in workers]
|
173 |
+
self.request_states[request_id] = RequestState(
|
174 |
+
request_id=request_id,
|
175 |
+
prompt=prompt,
|
176 |
+
model_names=model_names,
|
177 |
+
)
|
178 |
+
request_state = self.request_states[request_id]
|
179 |
+
model_name = request_state.model_names[model_index]
|
180 |
+
try:
|
181 |
+
worker = self.worker_service.get_worker(model_name)
|
182 |
+
except KeyError:
|
183 |
+
controller_logger.error("Worker %s not found.", model_name)
|
184 |
+
raise
|
185 |
+
except RuntimeError:
|
186 |
+
controller_logger.error("Worker %s is dead.", model_name)
|
187 |
+
raise
|
188 |
+
prompt, stop_str, stop_token_ids = apply_model_characteristics(
|
189 |
+
system_prompt=get_system_prompt("chat"),
|
190 |
+
prompt=prompt,
|
191 |
+
model_name=worker.model_id,
|
192 |
+
)
|
193 |
+
|
194 |
+
# Request the model worker to stream the response to the user's prompt.
|
195 |
+
response = ""
|
196 |
+
energy = 0.0
|
197 |
+
client = worker.get_client()
|
198 |
+
buffer = TokenGenerationBuffer(stop_str=stop_str)
|
199 |
+
try:
|
200 |
+
async for resp in client.generate_stream(
|
201 |
+
prompt=prompt,
|
202 |
+
stop_sequences=[stop_str] if stop_str is not None else None,
|
203 |
+
**self.generation_config.dict(),
|
204 |
+
):
|
205 |
+
# Even special tokens consume energy when they're generated.
|
206 |
+
energy += resp.token.energy
|
207 |
+
|
208 |
+
# Stop tokens usually don't overlap with (human-readable) stop sequences.
|
209 |
+
# if resp.token.special or resp.token.id in stop_token_ids:
|
210 |
+
if resp.token.id in stop_token_ids:
|
211 |
+
# If the buffer is not empty (i.e., we had partial stop_str matches),
|
212 |
+
# just yield it to the user.
|
213 |
+
if (chunk := buffer.token_buffer):
|
214 |
+
response += chunk
|
215 |
+
yield json.dumps(chunk).encode() + b"\0"
|
216 |
+
break
|
217 |
+
|
218 |
+
# Skip special tokens.
|
219 |
+
if resp.token.special:
|
220 |
+
continue
|
221 |
+
|
222 |
+
# The buffer automatically handles `stop_str` partial and full matches.
|
223 |
+
buffer.append(resp.token.text)
|
224 |
+
if (chunk := buffer.pop()) is not None:
|
225 |
+
response += chunk
|
226 |
+
yield json.dumps(chunk).encode() + b"\0"
|
227 |
+
elif buffer.matched_stop_str:
|
228 |
+
break
|
229 |
+
except aiohttp.ClientConnectorError:
|
230 |
+
worker.status = "down"
|
231 |
+
controller_logger.error(
|
232 |
+
"Problem talking to %s. Aborting and setting worker status to down",
|
233 |
+
repr(worker),
|
234 |
+
)
|
235 |
+
raise
|
236 |
+
except Exception:
|
237 |
+
yield json.dumps(buffer.token_buffer).encode() + b"\0"
|
238 |
+
raise
|
239 |
+
finally:
|
240 |
+
request_state.set_response_and_energy(model_index, response, energy)
|
241 |
+
request_logger.info(request_state.json())
|
242 |
+
|
243 |
+
|
244 |
+
CONTROLLER: Controller | None = None
|
245 |
+
|
246 |
+
def init_global_controller(config: ControllerConfig) -> None:
|
247 |
+
global CONTROLLER
|
248 |
+
CONTROLLER = Controller(
|
249 |
+
background_task_interval=config.background_task_interval,
|
250 |
+
max_num_req_states=config.max_num_req_states,
|
251 |
+
req_state_expiration_time=config.req_state_expiration_time,
|
252 |
+
worker_service=WorkerService(config.compose_files),
|
253 |
+
generation_config=GenerationConfig(
|
254 |
+
max_new_tokens=config.max_new_tokens,
|
255 |
+
do_sample=config.do_sample,
|
256 |
+
temperature=config.temperature,
|
257 |
+
repetition_penalty=config.repetition_penalty,
|
258 |
+
top_k=config.top_k,
|
259 |
+
top_p=config.top_p,
|
260 |
+
),
|
261 |
+
)
|
262 |
+
|
263 |
+
def get_global_controller() -> Controller:
|
264 |
+
global CONTROLLER
|
265 |
+
assert CONTROLLER is not None
|
266 |
+
return CONTROLLER
|
spitfight/colosseum/controller/router.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
import uvicorn
|
5 |
+
from pydantic import BaseSettings
|
6 |
+
from fastapi import FastAPI, Depends
|
7 |
+
from fastapi.responses import StreamingResponse
|
8 |
+
from fastapi.exceptions import HTTPException
|
9 |
+
from text_generation.errors import OverloadedError, UnknownError, ValidationError
|
10 |
+
|
11 |
+
from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
|
12 |
+
from spitfight.colosseum.common import (
|
13 |
+
COLOSSEUM_PROMPT_ROUTE,
|
14 |
+
COLOSSEUM_RESP_VOTE_ROUTE,
|
15 |
+
COLOSSEUM_ENERGY_VOTE_ROUTE,
|
16 |
+
COLOSSEUM_HEALTH_ROUTE,
|
17 |
+
PromptRequest,
|
18 |
+
ResponseVoteRequest,
|
19 |
+
ResponseVoteResponse,
|
20 |
+
EnergyVoteRequest,
|
21 |
+
EnergyVoteResponse,
|
22 |
+
)
|
23 |
+
from spitfight.colosseum.controller.controller import (
|
24 |
+
Controller,
|
25 |
+
init_global_controller,
|
26 |
+
get_global_controller,
|
27 |
+
)
|
28 |
+
from spitfight.utils import prepend_generator
|
29 |
+
|
30 |
+
|
31 |
+
class ControllerConfig(BaseSettings):
|
32 |
+
"""Controller settings automatically loaded from environment variables."""
|
33 |
+
# Controller
|
34 |
+
background_task_interval: int = 300
|
35 |
+
max_num_req_states: int = 10000
|
36 |
+
req_state_expiration_time: int = 600
|
37 |
+
compose_files: list[str] = ["deployment/docker-compose-0.yaml", "deployment/docker-compose-1.yaml"]
|
38 |
+
|
39 |
+
# Logging
|
40 |
+
log_dir: str = "/logs"
|
41 |
+
controller_log_file: str = "controller.log"
|
42 |
+
request_log_file: str = "requests.log"
|
43 |
+
uvicorn_log_file: str = "uvicorn.log"
|
44 |
+
|
45 |
+
# Generation
|
46 |
+
max_new_tokens: int = 512
|
47 |
+
do_sample: bool = True
|
48 |
+
temperature: float = 1.0
|
49 |
+
repetition_penalty: float = 1.0
|
50 |
+
top_k: int = 50
|
51 |
+
top_p: float = 0.95
|
52 |
+
|
53 |
+
|
54 |
+
app = FastAPI()
|
55 |
+
settings = ControllerConfig()
|
56 |
+
logger = get_logger("spitfight.colosseum.controller.router")
|
57 |
+
|
58 |
+
@app.on_event("startup")
|
59 |
+
async def startup_event():
|
60 |
+
init_queued_root_logger("uvicorn", os.path.join(settings.log_dir, settings.uvicorn_log_file))
|
61 |
+
init_queued_root_logger("spitfight.colosseum.controller", os.path.join(settings.log_dir, settings.controller_log_file))
|
62 |
+
init_queued_root_logger("colosseum_requests", os.path.join(settings.log_dir, settings.request_log_file))
|
63 |
+
init_global_controller(settings)
|
64 |
+
|
65 |
+
@app.on_event("shutdown")
|
66 |
+
async def shutdown_event():
|
67 |
+
get_global_controller().shutdown()
|
68 |
+
shutdown_queued_root_loggers()
|
69 |
+
|
70 |
+
@app.post(COLOSSEUM_PROMPT_ROUTE)
|
71 |
+
async def prompt(
|
72 |
+
request: PromptRequest,
|
73 |
+
controller: Controller = Depends(get_global_controller),
|
74 |
+
):
|
75 |
+
generator = controller.prompt(request.request_id, request.prompt, request.model_index)
|
76 |
+
|
77 |
+
# First try to get the first token in order to catch TGI errors.
|
78 |
+
try:
|
79 |
+
first_token = await generator.__anext__()
|
80 |
+
except OverloadedError:
|
81 |
+
name = controller.request_states[request.request_id].model_names[request.model_index]
|
82 |
+
logger.warning("Model %s is overloaded. Failed request: %s", name, repr(request))
|
83 |
+
raise HTTPException(status_code=429, detail="Model overloaded. Pleaes try again later.")
|
84 |
+
except ValidationError as e:
|
85 |
+
logger.info("TGI returned validation error: %s. Failed request: %s", str(e), repr(request))
|
86 |
+
raise HTTPException(status_code=422, detail=str(e))
|
87 |
+
except StopAsyncIteration:
|
88 |
+
logger.info("TGI returned empty response. Failed request: %s", repr(request))
|
89 |
+
return StreamingResponse(
|
90 |
+
iter([json.dumps("*The model generated an empty response.*").encode() + b"\0"]),
|
91 |
+
)
|
92 |
+
except UnknownError as e:
|
93 |
+
logger.error("TGI returned unknown error: %s. Failed request: %s", str(e), repr(request))
|
94 |
+
raise HTTPException(status_code=500, detail=str(e))
|
95 |
+
|
96 |
+
return StreamingResponse(prepend_generator(first_token, generator))
|
97 |
+
|
98 |
+
@app.post(COLOSSEUM_RESP_VOTE_ROUTE, response_model=ResponseVoteResponse)
|
99 |
+
async def response_vote(
|
100 |
+
request: ResponseVoteRequest,
|
101 |
+
controller: Controller = Depends(get_global_controller),
|
102 |
+
):
|
103 |
+
if (state := controller.response_vote(request.request_id, request.victory_index)) is None:
|
104 |
+
raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")
|
105 |
+
return ResponseVoteResponse(
|
106 |
+
energy_consumptions=state.energy_consumptions,
|
107 |
+
model_names=state.model_names,
|
108 |
+
)
|
109 |
+
|
110 |
+
@app.post(COLOSSEUM_ENERGY_VOTE_ROUTE, response_model=EnergyVoteResponse)
|
111 |
+
async def energy_vote(
|
112 |
+
request: EnergyVoteRequest,
|
113 |
+
controller: Controller = Depends(get_global_controller),
|
114 |
+
):
|
115 |
+
if (state := controller.energy_vote(request.request_id, request.is_worth)) is None:
|
116 |
+
raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")
|
117 |
+
return EnergyVoteResponse(model_names=state.model_names)
|
118 |
+
|
119 |
+
@app.get(COLOSSEUM_HEALTH_ROUTE)
|
120 |
+
async def health():
|
121 |
+
return "OK"
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
uvicorn.run(app, host="0.0.0.0", log_config=None)
|
spitfight/colosseum/controller/worker.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
import random
|
3 |
+
import asyncio
|
4 |
+
from typing import Literal
|
5 |
+
from functools import cached_property
|
6 |
+
|
7 |
+
import httpx
|
8 |
+
from pydantic import BaseModel
|
9 |
+
from text_generation import AsyncClient
|
10 |
+
|
11 |
+
from spitfight.log import get_logger
|
12 |
+
|
13 |
+
logger = get_logger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class Worker(BaseModel):
|
17 |
+
"""A worker that serves a model."""
|
18 |
+
# Worker's container name, since we're using Overlay networks.
|
19 |
+
hostname: str
|
20 |
+
# For TGI, this would always be 80.
|
21 |
+
port: int
|
22 |
+
# User-friendly model name, e.g. "metaai/llama2-13b-chat".
|
23 |
+
model_name: str
|
24 |
+
# Hugging Face model ID, e.g. "metaai/Llama-2-13b-chat-hf".
|
25 |
+
model_id: str
|
26 |
+
# Whether the model worker container is good.
|
27 |
+
status: Literal["up", "down"]
|
28 |
+
|
29 |
+
class Config:
|
30 |
+
keep_untouched = (cached_property,)
|
31 |
+
|
32 |
+
@cached_property
|
33 |
+
def url(self) -> str:
|
34 |
+
return f"http://{self.hostname}:{self.port}"
|
35 |
+
|
36 |
+
def get_client(self) -> AsyncClient:
|
37 |
+
return AsyncClient(base_url=self.url)
|
38 |
+
|
39 |
+
def audit(self) -> None:
|
40 |
+
"""Make sure the worker is running and information is as expected.
|
41 |
+
|
42 |
+
Assumed to be called on app startup when workers are initialized.
|
43 |
+
This method will just raise `ValueError`s if audit fails in order to
|
44 |
+
prevent the controller from starting if anything is wrong.
|
45 |
+
"""
|
46 |
+
try:
|
47 |
+
response = httpx.get(self.url + "/info")
|
48 |
+
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
49 |
+
raise ValueError(f"Could not connect to {self!r}: {e!r}")
|
50 |
+
if response.status_code != 200:
|
51 |
+
raise ValueError(f"Could not get /info from {self!r}.")
|
52 |
+
info = response.json()
|
53 |
+
if info["model_id"] != self.model_id:
|
54 |
+
raise ValueError(f"Model name mismatch: {info['model_id']} != {self.model_id}")
|
55 |
+
self.status = "up"
|
56 |
+
logger.info("%s is up.", repr(self))
|
57 |
+
|
58 |
+
async def check_status(self) -> None:
|
59 |
+
"""Check worker status and update `self.status` accordingly."""
|
60 |
+
async with httpx.AsyncClient() as client:
|
61 |
+
try:
|
62 |
+
response = await client.get(self.url + "/info")
|
63 |
+
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
64 |
+
self.status = "down"
|
65 |
+
logger.warning("%s is down: %s", repr(self), repr(e))
|
66 |
+
return
|
67 |
+
if response.status_code != 200:
|
68 |
+
self.status = "down"
|
69 |
+
logger.warning("GET /info from %s returned %s.", repr(self), response.json())
|
70 |
+
return
|
71 |
+
info = response.json()
|
72 |
+
if info["model_id"] != self.model_id:
|
73 |
+
self.status = "down"
|
74 |
+
logger.warning(
|
75 |
+
"Model name mismatch for %s: %s != %s",
|
76 |
+
repr(self),
|
77 |
+
info["model_id"],
|
78 |
+
self.model_id,
|
79 |
+
)
|
80 |
+
return
|
81 |
+
logger.info("%s is up.", repr(self))
|
82 |
+
self.status = "up"
|
83 |
+
|
84 |
+
|
85 |
+
class WorkerService:
|
86 |
+
"""A service that manages model serving workers.
|
87 |
+
|
88 |
+
Worker objects are only created once and shared across the
|
89 |
+
entire application. Especially, changing the status of a worker
|
90 |
+
will immediately take effect on the result of `choose_two`.
|
91 |
+
|
92 |
+
Attributes:
|
93 |
+
workers (list[Worker]): The list of workers.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self, compose_files: list[str]) -> None:
|
97 |
+
"""Initialize the worker service."""
|
98 |
+
self.workers: list[Worker] = []
|
99 |
+
worker_model_names = set()
|
100 |
+
for compose_file in compose_files:
|
101 |
+
spec = yaml.safe_load(open(compose_file))
|
102 |
+
for model_name, service_spec in spec["services"].items():
|
103 |
+
command = service_spec["command"]
|
104 |
+
for i, cmd in enumerate(command):
|
105 |
+
if cmd == "--model-id":
|
106 |
+
model_id = command[i + 1]
|
107 |
+
break
|
108 |
+
else:
|
109 |
+
raise ValueError(f"Could not find model ID in {command!r}")
|
110 |
+
worker_model_names.add(model_name)
|
111 |
+
worker = Worker(
|
112 |
+
hostname=service_spec["container_name"],
|
113 |
+
port=80,
|
114 |
+
model_name=model_name,
|
115 |
+
model_id=model_id,
|
116 |
+
status="down",
|
117 |
+
)
|
118 |
+
worker.audit()
|
119 |
+
self.workers.append(worker)
|
120 |
+
|
121 |
+
if len(worker_model_names) != len(self.workers):
|
122 |
+
raise ValueError("Model names must be unique.")
|
123 |
+
|
124 |
+
def get_worker(self, model_name: str) -> Worker:
|
125 |
+
"""Get a worker by model name."""
|
126 |
+
for worker in self.workers:
|
127 |
+
if worker.model_name == model_name:
|
128 |
+
if worker.status == "down":
|
129 |
+
# This is an unfortunate case where, when the two models were chosen,
|
130 |
+
# the worker was up, but after that went down before the request
|
131 |
+
# completed. We'll just raise a 500 internal error and have the user
|
132 |
+
# try again. This won't be common.
|
133 |
+
raise RuntimeError(f"The worker with model name {model_name} is down.")
|
134 |
+
return worker
|
135 |
+
raise ValueError(f"Worker with model name {model_name} does not exist.")
|
136 |
+
|
137 |
+
def choose_two(self) -> tuple[Worker, Worker]:
|
138 |
+
"""Choose two different workers.
|
139 |
+
|
140 |
+
Good place to use the Strategy Pattern when we want to
|
141 |
+
implement different strategies for choosing workers.
|
142 |
+
"""
|
143 |
+
live_workers = [worker for worker in self.workers if worker.status == "up"]
|
144 |
+
if len(live_workers) < 2:
|
145 |
+
raise ValueError("Not enough live workers to choose from.")
|
146 |
+
worker_a, worker_b = random.sample(live_workers, 2)
|
147 |
+
return worker_a, worker_b
|
148 |
+
|
149 |
+
async def check_workers(self) -> None:
|
150 |
+
"""Check the status of all workers."""
|
151 |
+
await asyncio.gather(*[worker.check_status() for worker in self.workers])
|
spitfight/log.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import queue
|
4 |
+
import logging
|
5 |
+
from logging.handlers import QueueHandler, QueueListener
|
6 |
+
|
7 |
+
ROOT_LOGGER_NAMES: list[str | None] = []
|
8 |
+
ROOT_LOGGER_QUEUE_LISTENERS: list[QueueListener] = []
|
9 |
+
|
10 |
+
|
11 |
+
def init_queued_root_logger(
|
12 |
+
name: str | None,
|
13 |
+
filepath: str,
|
14 |
+
level: int = logging.INFO,
|
15 |
+
) -> None:
|
16 |
+
"""Initialize a queue-based pseudo-root logger.
|
17 |
+
|
18 |
+
The pseudo-root logger will aggregate log messages from children
|
19 |
+
loggers under its namespace and send them to a queue. A QueueListener,
|
20 |
+
running in a separate thread, will then process the messages in the
|
21 |
+
queue and send them to the configured handlers.
|
22 |
+
"""
|
23 |
+
global ROOT_LOGGER_NAMES, ROOT_LOGGER_QUEUE_LISTENERS
|
24 |
+
|
25 |
+
# Make this function idempotent.
|
26 |
+
if name in ROOT_LOGGER_NAMES:
|
27 |
+
return
|
28 |
+
|
29 |
+
logger = logging.getLogger(name)
|
30 |
+
logger.setLevel(level)
|
31 |
+
logger.propagate = False
|
32 |
+
|
33 |
+
shared_queue = queue.SimpleQueue()
|
34 |
+
queue_handler = QueueHandler(shared_queue)
|
35 |
+
logger.addHandler(queue_handler)
|
36 |
+
|
37 |
+
formatter = logging.Formatter(
|
38 |
+
"[%(asctime)s] [%(levelname)s] [%(name)s](%(filename)s:%(lineno)d) %(message)s"
|
39 |
+
)
|
40 |
+
|
41 |
+
stderr_handler = logging.StreamHandler()
|
42 |
+
stderr_handler.setLevel(level)
|
43 |
+
stderr_handler.setFormatter(formatter)
|
44 |
+
|
45 |
+
file_handler = logging.FileHandler(filepath, encoding="utf-8")
|
46 |
+
file_handler.setLevel(level)
|
47 |
+
file_handler.setFormatter(formatter)
|
48 |
+
|
49 |
+
queue_listener = QueueListener(shared_queue, file_handler, stderr_handler)
|
50 |
+
queue_listener.start()
|
51 |
+
|
52 |
+
ROOT_LOGGER_NAMES.append(name)
|
53 |
+
ROOT_LOGGER_QUEUE_LISTENERS.append(queue_listener)
|
54 |
+
|
55 |
+
|
56 |
+
def shutdown_queued_root_loggers() -> None:
|
57 |
+
"""Shutdown all queue-based pseudo-root loggers.
|
58 |
+
|
59 |
+
This is necessary to make sure all log messages are flushed
|
60 |
+
before the application exits.
|
61 |
+
"""
|
62 |
+
for queue_listener in ROOT_LOGGER_QUEUE_LISTENERS:
|
63 |
+
queue_listener.stop()
|
64 |
+
|
65 |
+
|
66 |
+
def get_logger(name: str, level: int = logging.INFO) -> logging.Logger:
|
67 |
+
"""Setup a logger with the given name and level."""
|
68 |
+
# Don't reconfigure existing loggers.
|
69 |
+
if name in logging.Logger.manager.loggerDict:
|
70 |
+
return logging.getLogger(name)
|
71 |
+
|
72 |
+
logger = logging.getLogger(name)
|
73 |
+
logger.setLevel(level)
|
74 |
+
logger.propagate = True
|
75 |
+
|
76 |
+
return logger
|
spitfight/prompt.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""An abstraction layer for prompting different models."""
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import enum
|
6 |
+
|
7 |
+
from fastchat.model.model_adapter import get_conversation_template
|
8 |
+
|
9 |
+
|
10 |
+
class Task(enum.Enum):
|
11 |
+
"""Different system prompt styles."""
|
12 |
+
|
13 |
+
CHAT = "chat"
|
14 |
+
CHAT_CONCISE = "chat-concise"
|
15 |
+
INSTRUCT = "instruct"
|
16 |
+
INSTRUCT_CONCISE = "instruct-concise"
|
17 |
+
|
18 |
+
|
19 |
+
SYSTEM_PROMPTS = {
|
20 |
+
Task.CHAT: (
|
21 |
+
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
|
22 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
|
23 |
+
),
|
24 |
+
Task.CHAT_CONCISE: (
|
25 |
+
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
|
26 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
|
27 |
+
"The assistant's answers are very concise. "
|
28 |
+
),
|
29 |
+
Task.INSTRUCT: (
|
30 |
+
"Below is an instruction that describes a task. "
|
31 |
+
"Write a response that appropriately completes the request. "
|
32 |
+
),
|
33 |
+
Task.INSTRUCT_CONCISE: (
|
34 |
+
"Below is an instruction that describes a task. "
|
35 |
+
"Write a response that appropriately completes the request. "
|
36 |
+
"The response should be very concise. "
|
37 |
+
),
|
38 |
+
}
|
39 |
+
|
40 |
+
def get_system_prompt(task: Task | str) -> str:
|
41 |
+
"""Get the system prompt for a given task."""
|
42 |
+
if isinstance(task, str):
|
43 |
+
task = Task(task)
|
44 |
+
return SYSTEM_PROMPTS[task]
|
45 |
+
|
46 |
+
|
47 |
+
def apply_model_characteristics(
|
48 |
+
system_prompt: str,
|
49 |
+
prompt: str,
|
50 |
+
model_name: str,
|
51 |
+
) -> tuple[str, str | None, list[int]]:
|
52 |
+
"""Apply and return model-specific differences."""
|
53 |
+
conv = get_conversation_template(model_name)
|
54 |
+
|
55 |
+
if "llama-2" in model_name.lower():
|
56 |
+
conv.system = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
|
57 |
+
elif "stablelm" in model_name.lower():
|
58 |
+
conv.system = f"""<|SYSTEM|># {system_prompt}\n"""
|
59 |
+
else:
|
60 |
+
conv.system = system_prompt
|
61 |
+
conv.messages = []
|
62 |
+
conv.offset = 0
|
63 |
+
|
64 |
+
conv.append_message(conv.roles[0], prompt)
|
65 |
+
conv.append_message(conv.roles[1], "")
|
66 |
+
|
67 |
+
stop_str = None if conv.stop_str is None or not conv.stop_str else conv.stop_str
|
68 |
+
|
69 |
+
return conv.get_prompt(), stop_str, (conv.stop_token_ids or [])
|
spitfight/utils.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import time
|
4 |
+
import heapq
|
5 |
+
import asyncio
|
6 |
+
import unittest
|
7 |
+
from typing import TypeVar, Generic, AsyncGenerator, Any, Coroutine
|
8 |
+
|
9 |
+
from fastapi.logger import logger
|
10 |
+
|
11 |
+
K = TypeVar('K')
|
12 |
+
V = TypeVar('V')
|
13 |
+
|
14 |
+
|
15 |
+
class BoundedExpiringDict(Generic[K, V]):
|
16 |
+
def __init__(self, max_size: int, expiration_time: int) -> None:
|
17 |
+
self.data_dict: dict[K, V] = {}
|
18 |
+
self.timestamp_heap: list[tuple[float, K]] = []
|
19 |
+
self.timeout = expiration_time
|
20 |
+
|
21 |
+
# Without this, the controller is vulnerable to "user flood attacks,"
|
22 |
+
# where someone can create a bunch of users by polling /request before
|
23 |
+
# self.timeout expires and blow up memory.
|
24 |
+
self.max_size = max_size
|
25 |
+
|
26 |
+
def __getitem__(self, key: K) -> V:
|
27 |
+
return self.data_dict[key]
|
28 |
+
|
29 |
+
def __setitem__(self, key: K, value: V) -> None:
|
30 |
+
if len(self.data_dict) >= self.max_size:
|
31 |
+
self.cleanup()
|
32 |
+
|
33 |
+
heapq.heappush(self.timestamp_heap, (time.monotonic(), key))
|
34 |
+
self.data_dict[key] = value
|
35 |
+
|
36 |
+
def __delitem__(self, key: K) -> None:
|
37 |
+
# This is a bit inefficient, but it's not a common case operation.
|
38 |
+
# We still need to do this to keep timestamp_heap in sync.
|
39 |
+
del self.data_dict[key]
|
40 |
+
for i, (_, existing_key) in enumerate(self.timestamp_heap):
|
41 |
+
if existing_key == key:
|
42 |
+
del self.timestamp_heap[i]
|
43 |
+
break
|
44 |
+
heapq.heapify(self.timestamp_heap)
|
45 |
+
|
46 |
+
def __contains__(self, key: K) -> bool:
|
47 |
+
return key in self.data_dict
|
48 |
+
|
49 |
+
def __len__(self) -> int:
|
50 |
+
return len(self.data_dict)
|
51 |
+
|
52 |
+
def get(self, key: K, default: V | None = None) -> V | None:
|
53 |
+
return self.data_dict.get(key, default)
|
54 |
+
|
55 |
+
def pop(self, key: K, default: V | None = None) -> V | None:
|
56 |
+
item = self.data_dict.pop(key, default)
|
57 |
+
if item is not None:
|
58 |
+
for i, (_, existing_key) in enumerate(self.timestamp_heap):
|
59 |
+
if existing_key == key:
|
60 |
+
del self.timestamp_heap[i]
|
61 |
+
break
|
62 |
+
heapq.heapify(self.timestamp_heap)
|
63 |
+
return item
|
64 |
+
|
65 |
+
def cleanup(self) -> None:
|
66 |
+
now = time.monotonic()
|
67 |
+
# After the while loop, the dictionary will be smaller than max_size
|
68 |
+
# and all keys will have been accessed within the timeout.
|
69 |
+
while (self.timestamp_heap and now - self.timestamp_heap[0][0] > self.timeout) or len(self.data_dict) > self.max_size:
|
70 |
+
_, key = heapq.heappop(self.timestamp_heap)
|
71 |
+
del self.data_dict[key]
|
72 |
+
|
73 |
+
assert len(self.data_dict) == len(self.timestamp_heap)
|
74 |
+
|
75 |
+
|
76 |
+
T = TypeVar("T")
|
77 |
+
|
78 |
+
|
79 |
+
async def prepend_generator(
|
80 |
+
first_item: T,
|
81 |
+
generator: AsyncGenerator[T, None],
|
82 |
+
) -> AsyncGenerator[T, None]:
|
83 |
+
"""Prepend an item to an async generator."""
|
84 |
+
yield first_item
|
85 |
+
async for item in generator:
|
86 |
+
yield item
|
87 |
+
|
88 |
+
|
89 |
+
def create_task(coroutine: Coroutine[Any, Any, T]) -> asyncio.Task[T]:
|
90 |
+
"""Create an `asyncio.Task` but ensure that exceptions are logged.
|
91 |
+
|
92 |
+
Reference: https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/
|
93 |
+
"""
|
94 |
+
loop = asyncio.get_running_loop()
|
95 |
+
task = loop.create_task(coroutine)
|
96 |
+
task.add_done_callback(_handle_task_exception)
|
97 |
+
return task
|
98 |
+
|
99 |
+
|
100 |
+
def _handle_task_exception(task: asyncio.Task) -> None:
|
101 |
+
"""Print out exception and tracebook when a task dies with an exception."""
|
102 |
+
try:
|
103 |
+
task.result()
|
104 |
+
except asyncio.CancelledError:
|
105 |
+
# Cancellation should not be logged as an error.
|
106 |
+
pass
|
107 |
+
except Exception: # pylint: disable=broad-except
|
108 |
+
# `logger.exception` automatically handles exception and traceback info.
|
109 |
+
logger.exception("Job task died with an exception!")
|
110 |
+
|
111 |
+
|
112 |
+
class TokenGenerationBuffer:
|
113 |
+
"""A constant sized buffer for tokens, used to handle stop sequences.
|
114 |
+
|
115 |
+
Attributes:
|
116 |
+
token_buffer (str): Internal buffer for tokens.
|
117 |
+
matched_stop_str (bool): Whether the stop string has been seen. When this
|
118 |
+
is True, generation should stop and `pop` will always return None.
|
119 |
+
"""
|
120 |
+
def __init__(self, stop_str: str | None = None) -> None:
|
121 |
+
"""Initialize the buffer.
|
122 |
+
|
123 |
+
If `stop_str` is None, the buffer will just return all tokens as they come.
|
124 |
+
"""
|
125 |
+
self.stop_str = stop_str
|
126 |
+
self.token_len_list = []
|
127 |
+
self.token_buffer = ""
|
128 |
+
self.matched_stop_str = False
|
129 |
+
|
130 |
+
def append(self, text: str) -> None:
|
131 |
+
"""Append a token to the buffer."""
|
132 |
+
if self.stop_str is not None:
|
133 |
+
self.token_len_list.append(len(text))
|
134 |
+
self.token_buffer += text
|
135 |
+
|
136 |
+
def _pop_one(self) -> str:
|
137 |
+
"""Remove and return the first token in the buffer."""
|
138 |
+
token_len = self.token_len_list.pop(0)
|
139 |
+
token, self.token_buffer = self.token_buffer[:token_len], self.token_buffer[token_len:]
|
140 |
+
return token
|
141 |
+
|
142 |
+
def pop(self) -> str | None:
|
143 |
+
"""Try to pop a token from the buffer.
|
144 |
+
|
145 |
+
Return value None means that there is nothing to yield for now.
|
146 |
+
Repeated calls to this method will always just return None before more
|
147 |
+
tokens are appended to the buffer.
|
148 |
+
"""
|
149 |
+
# A short circuit for no stop string.
|
150 |
+
if self.stop_str is None:
|
151 |
+
return_buffer = self.token_buffer or None
|
152 |
+
self.token_buffer = ""
|
153 |
+
return return_buffer
|
154 |
+
|
155 |
+
if self.matched_stop_str:
|
156 |
+
return None
|
157 |
+
|
158 |
+
# The token buffer matched the stop string. We're done generating.
|
159 |
+
if self.stop_str == self.token_buffer:
|
160 |
+
self.matched_stop_str = True
|
161 |
+
return None
|
162 |
+
|
163 |
+
# The tokens in the buffer could potentially be part of the stop string.
|
164 |
+
# We'll stay put until we see more tokens. This also covers the case of
|
165 |
+
# empty token buffer.
|
166 |
+
if self.stop_str.startswith(self.token_buffer):
|
167 |
+
return None
|
168 |
+
|
169 |
+
# We can return tokens from the beginning of the buffer until the buffer
|
170 |
+
# is a prefix of the stop string.
|
171 |
+
return_buffer = ""
|
172 |
+
while self.token_buffer:
|
173 |
+
return_buffer += self._pop_one()
|
174 |
+
if self.stop_str == self.token_buffer:
|
175 |
+
self.matched_stop_str = True
|
176 |
+
break
|
177 |
+
if self.stop_str.startswith(self.token_buffer):
|
178 |
+
break
|
179 |
+
|
180 |
+
return return_buffer or None
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
class TestTokenGenerationBuffer(unittest.TestCase):
|
185 |
+
def test_basic1(self):
|
186 |
+
buffer = TokenGenerationBuffer(stop_str="stop")
|
187 |
+
|
188 |
+
buffer.append("hello")
|
189 |
+
self.assertEqual(buffer.pop(), "hello")
|
190 |
+
self.assertEqual(buffer.pop(), None)
|
191 |
+
self.assertFalse(buffer.matched_stop_str)
|
192 |
+
|
193 |
+
buffer.append("world")
|
194 |
+
self.assertEqual(buffer.pop(), "world")
|
195 |
+
self.assertFalse(buffer.matched_stop_str)
|
196 |
+
|
197 |
+
buffer.append("stop")
|
198 |
+
self.assertEqual(buffer.pop(), None)
|
199 |
+
self.assertTrue(buffer.matched_stop_str)
|
200 |
+
self.assertEqual(buffer.pop(), None)
|
201 |
+
self.assertTrue(buffer.matched_stop_str)
|
202 |
+
self.assertEqual(buffer.pop(), None)
|
203 |
+
self.assertTrue(buffer.matched_stop_str)
|
204 |
+
self.assertEqual(buffer.pop(), None)
|
205 |
+
self.assertTrue(buffer.matched_stop_str)
|
206 |
+
|
207 |
+
def test_basic2(self):
|
208 |
+
buffer = TokenGenerationBuffer(stop_str="stop")
|
209 |
+
|
210 |
+
buffer.append("hi")
|
211 |
+
self.assertEqual(buffer.pop(), "hi")
|
212 |
+
self.assertFalse(buffer.matched_stop_str)
|
213 |
+
|
214 |
+
buffer.append("stole")
|
215 |
+
self.assertEqual(buffer.pop(), "stole")
|
216 |
+
self.assertFalse(buffer.matched_stop_str)
|
217 |
+
|
218 |
+
buffer.append("sto")
|
219 |
+
self.assertEqual(buffer.pop(), None)
|
220 |
+
self.assertFalse(buffer.matched_stop_str)
|
221 |
+
|
222 |
+
buffer.append("ic")
|
223 |
+
self.assertEqual(buffer.pop(), "stoic")
|
224 |
+
self.assertFalse(buffer.matched_stop_str)
|
225 |
+
|
226 |
+
buffer.append("st")
|
227 |
+
self.assertEqual(buffer.pop(), None)
|
228 |
+
self.assertFalse(buffer.matched_stop_str)
|
229 |
+
|
230 |
+
buffer.append("opper")
|
231 |
+
self.assertEqual(buffer.pop(), "stopper")
|
232 |
+
self.assertFalse(buffer.matched_stop_str)
|
233 |
+
|
234 |
+
buffer.append("sto")
|
235 |
+
self.assertEqual(buffer.pop(), None)
|
236 |
+
self.assertFalse(buffer.matched_stop_str)
|
237 |
+
|
238 |
+
buffer.append("p")
|
239 |
+
self.assertEqual(buffer.pop(), None)
|
240 |
+
self.assertTrue(buffer.matched_stop_str)
|
241 |
+
|
242 |
+
def test_falcon1(self):
|
243 |
+
buffer = TokenGenerationBuffer(stop_str="\nUser")
|
244 |
+
|
245 |
+
buffer.append("Hi")
|
246 |
+
self.assertEqual(buffer.pop(), "Hi")
|
247 |
+
self.assertFalse(buffer.matched_stop_str)
|
248 |
+
|
249 |
+
buffer.append("!")
|
250 |
+
self.assertEqual(buffer.pop(), "!")
|
251 |
+
self.assertFalse(buffer.matched_stop_str)
|
252 |
+
|
253 |
+
buffer.append("\n")
|
254 |
+
self.assertEqual(buffer.pop(), None)
|
255 |
+
self.assertFalse(buffer.matched_stop_str)
|
256 |
+
|
257 |
+
buffer.append("User")
|
258 |
+
self.assertEqual(buffer.pop(), None)
|
259 |
+
self.assertTrue(buffer.matched_stop_str)
|
260 |
+
|
261 |
+
def test_falcon2(self):
|
262 |
+
buffer = TokenGenerationBuffer(stop_str="\nUser")
|
263 |
+
|
264 |
+
buffer.append("\n")
|
265 |
+
self.assertEqual(buffer.pop(), None)
|
266 |
+
self.assertFalse(buffer.matched_stop_str)
|
267 |
+
|
268 |
+
buffer.append("\n")
|
269 |
+
self.assertEqual(buffer.pop(), "\n")
|
270 |
+
self.assertFalse(buffer.matched_stop_str)
|
271 |
+
|
272 |
+
buffer.append("\n")
|
273 |
+
self.assertEqual(buffer.pop(), "\n")
|
274 |
+
self.assertFalse(buffer.matched_stop_str)
|
275 |
+
|
276 |
+
buffer.append("\n")
|
277 |
+
self.assertEqual(buffer.pop(), "\n")
|
278 |
+
self.assertFalse(buffer.matched_stop_str)
|
279 |
+
|
280 |
+
buffer.append("User")
|
281 |
+
self.assertEqual(buffer.pop(), None)
|
282 |
+
self.assertEqual(buffer.pop(), None)
|
283 |
+
self.assertTrue(buffer.matched_stop_str)
|
284 |
+
|
285 |
+
def test_no_stop_str(self):
|
286 |
+
buffer = TokenGenerationBuffer(stop_str=None)
|
287 |
+
|
288 |
+
buffer.append("hello")
|
289 |
+
self.assertEqual(buffer.pop(), "hello")
|
290 |
+
self.assertEqual(buffer.pop(), None)
|
291 |
+
self.assertFalse(buffer.matched_stop_str)
|
292 |
+
|
293 |
+
buffer.append("world")
|
294 |
+
self.assertEqual(buffer.pop(), "world")
|
295 |
+
self.assertEqual(buffer.pop(), None)
|
296 |
+
self.assertFalse(buffer.matched_stop_str)
|
297 |
+
|
298 |
+
buffer.append("\n")
|
299 |
+
self.assertEqual(buffer.pop(), "\n")
|
300 |
+
self.assertEqual(buffer.pop(), None)
|
301 |
+
self.assertFalse(buffer.matched_stop_str)
|
302 |
+
|
303 |
+
|
304 |
+
if __name__ == "__main__":
|
305 |
+
unittest.main()
|