File size: 5,271 Bytes
13e8963 f2d4743 13e8963 e698b42 dfae691 13e8963 f2d4743 13e8963 c28665f 13e8963 f2d4743 c28665f f2d4743 c28665f 13e8963 f2d4743 13e8963 f2d4743 13e8963 c28665f 13e8963 ca2e2c2 13e8963 5e531ec f2d4743 5e531ec 13e8963 f2d4743 c28665f f2d4743 13e8963 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import gradio as gr # type: ignore
import plotly.express as px # type: ignore
from backend.data import load_cot_data
from backend.envs import API, REPO_ID, TOKEN
logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
logo2_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/logo_logikon_notext_withborder.png"

# Centered row of two linked logos (AI2, Logikon); both share the same sizing,
# the second one is pushed right by a 10px left margin.
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    f'<a href="https://allenai.org/"><img src="{logo1_url}" alt="AI2" '
    'style="width: 30vw; min-width: 20px; max-width: 60px;"></a> '
    f'<a href="https://logikon.ai"><img src="{logo2_url}" alt="Logikon AI" '
    'style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;"></a>'
    "</div>"
)

# Page heading followed by the logo row.
TITLE = '<h1 align="center" id="space-title"> Open CoT Dashboard</h1> ' + LOGOS

# Markdown shown below the title.
INTRODUCTION_TEXT = (
    "\n"
    "Baseline accuracies and marginal accuracy gains for specific models and CoT regimes "
    "from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).\n"
)
def restart_space():
    """Request a restart of the Space identified by ``REPO_ID`` via ``API``.

    Authenticates with ``TOKEN``. Called as a recovery path when the
    dashboard data cannot be loaded at startup.
    """
    API.restart_space(token=TOKEN, repo_id=REPO_ID)
# Load the evaluation data once at startup. If loading fails, wait 10 seconds,
# ask for a Space restart, and then re-raise: every UI component below needs
# ``df_cot_err`` / ``df_cot_regimes``, so continuing would only crash later with
# a misleading NameError instead of the real cause.
try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception as err:
    import time

    print(err)
    # sleep for 10 seconds before restarting the space
    time.sleep(10)
    restart_space()
    # Surface the original failure rather than falling through to the UI code.
    raise
def plot_evals_init(model_id, regex_model_filter, plotly_mode, request: gr.Request):
    """Render the initial plot on page load, honouring an optional ``?model=`` URL parameter.

    The query parameter replaces ``model_id`` only when it names a model that
    appears in ``df_cot_err``; otherwise the incoming ``model_id`` is kept.
    Returns whatever ``plot_evals`` returns for the resolved model.
    """
    query = request.query_params if request else {}
    if "model" in query:
        candidate = query["model"]
        if candidate in df_cot_err.model.to_list():
            model_id = candidate
    return plot_evals(model_id, regex_model_filter, plotly_mode)
def plot_evals(model_id, regex_model_filter, plotly_mode):
    """Build the per-task scatter plot of marginal accuracy gain vs. base accuracy.

    Args:
        model_id: Model to highlight. It is always plotted, even when it does
            not match ``regex_model_filter``.
        regex_model_filter: Regular expression selecting which models to plot.
            If it cannot be applied (e.g. invalid pattern), a UI warning is shown
            and all models are plotted.
        plotly_mode: ``"dark"`` selects the ``plotly_dark`` template; any other
            value selects ``plotly``.

    Returns:
        Tuple ``(fig, model_id)``: the Plotly figure and the unchanged model id.
    """
    df = df_cot_err.copy()
    # Colour column: the highlighted model vs. everything else.
    df["selected"] = df.model.eq(model_id).map({True: "selected", False: "-"})
    try:
        # na=False keeps the mask purely boolean even if a model name is missing.
        df_filter = df.model.str.contains(regex_model_filter, na=False)
    except Exception as err:
        # User-supplied regex could not be applied: warn and show every model.
        # (The exception must be formatted, not concatenated to a str, or the
        # handler itself raises TypeError and the fallback never happens.)
        gr.Warning("Failed to apply regex filter", duration=4)
        print(f"Failed to apply regex filter: {err}")
        df_filter = df.model.str.contains(".*", na=False)
    # Keep the highlighted model visible even if the regex excludes it.
    df = df[df_filter | df.selected.eq("selected")]
    # NOTE: sorting by ["selected", "model"] currently has no effect with
    # px.scatter; the order of the colour groups is set via category_orders.
    template = "plotly_dark" if plotly_mode == "dark" else "plotly"
    fig = px.scatter(
        df,
        x="base accuracy",
        y="marginal acc. gain",
        color="selected",
        symbol="model",
        facet_col="task",
        facet_col_wrap=3,
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        template=template,
        error_y="acc_gain-err",
        hover_data=["model", "cot accuracy"],
        width=1200,
        height=700,
    )
    fig.update_layout(
        title={"automargin": True},
    )
    return fig, model_id
def styled_model_table_init(model_id, request: gr.Request):
    """Render the initial regime table on page load, honouring an optional ``?model=`` URL parameter.

    The query parameter replaces ``model_id`` only when it names a model that
    appears in ``df_cot_regimes``; otherwise the incoming ``model_id`` is kept.
    Returns whatever ``styled_model_table`` returns for the resolved model.
    """
    query = request.query_params if request else {}
    if "model" in query:
        candidate = query["model"]
        if candidate in df_cot_regimes.model.to_list():
            model_id = candidate
    return styled_model_table(model_id)
def styled_model_table(model_id):
    """Return a styled table of every CoT regime evaluated for ``model_id``.

    Rows come from ``df_cot_regimes`` restricted to ``model_id``, sorted by task
    and CoT chain. ``temperature`` is shown as ``temp`` and the chain name
    ``ReflectBeforeRun`` is shortened to ``Reflect``. Base/CoT accuracies are
    shaded on a 20..100 scale and the accuracy gain on a -20..20 diverging scale.

    Returns:
        A pandas ``Styler`` wrapping the filtered frame (empty if the model
        has no rows).
    """

    def make_pretty(styler):
        # Styler methods mutate the styler in place; it is returned for ``.pipe``.
        styler.hide(axis="index")
        styler.format(precision=1)
        styler.background_gradient(
            axis=None,
            subset=["acc_base", "acc_cot"],
            vmin=20, vmax=100, cmap="YlGnBu",
        )
        styler.background_gradient(
            axis=None,
            subset=["acc_gain"],
            vmin=-20, vmax=20, cmap="coolwarm",
        )
        # Bold the task column (header and cells).
        styler.set_table_styles(
            {"task": [{"selector": "", "props": [("font-weight", "bold")]}]},
            overwrite=False,
        )
        return styler

    columns = [
        "task", "cot_chain", "best_of", "temperature", "top_k", "top_p",
        "acc_base", "acc_cot", "acc_gain",
    ]
    df_cot_model = (
        df_cot_regimes.loc[df_cot_regimes.model.eq(model_id), columns]
        .rename(columns={"temperature": "temp"})
        .replace({"cot_chain": "ReflectBeforeRun"}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )
    return df_cot_model.style.pipe(make_pretty)
# Dropdown choices and default selection. The preferred default falls back to
# the first available model when it is missing from the loaded data; otherwise
# the plot would highlight nothing and the regime table would be empty.
model_choices = list(df_cot_err.model.unique())
default_model = "allenai/tulu-2-70b"
if model_choices and default_model not in model_choices:
    default_model = model_choices[0]

with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)
    with gr.Row():
        selected_model = gr.Dropdown(model_choices, value=default_model, label="Model", info="with performance details below", scale=2)
        regex_model_filter = gr.Textbox(".*", label="Regex", info="to filter models shown in plots", scale=2)
        plotly_mode = gr.Radio(["dark", "light"], value="light", label="Theme", info="of plots", scale=1)
        submit = gr.Button("Update", scale=1)
    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    # "Update" refreshes both the plot and the regime table.
    submit.click(plot_evals, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
    submit.click(styled_model_table, selected_model, table)
    # On page load the *_init variants may override the model from ?model=...
    demo.load(plot_evals_init, [selected_model, regex_model_filter, plotly_mode], [plot, selected_model])
    demo.load(styled_model_table_init, selected_model, table)

demo.launch()