import os
import streamlit as st
from tinydb import TinyDB, Query
import json
from base64 import b64encode
# Configure the page before any other Streamlit command runs.
st.set_page_config(layout="wide")

# Directory containing this script; every data file is resolved relative to it.
PARENT_DIR = os.path.dirname(os.path.abspath(__file__))

# Options of the "Prompt Template" radio; the selected value is matched
# against the `template` field of the stored responses.
TEMPLATE_TYPES = ("URAIL", "URAIL-Safe", "None")

# Response store queried by checkpoint, template and question id.
# NOTE(review): the name looks like a misspelling of TINYDB; it is kept
# because other parts of the script reference it.
TINDYDB = TinyDB(os.path.join(PARENT_DIR, "k2-gallery-db.jsonl"))
QUERY_OBJ = Query()

# Category name -> inclusive (first_qid, last_qid) range of question ids.
# Insertion order is the order in which the categories are listed in the UI.
CATEGORY_RANGES = {
    "History": (20, 29),
    "Practical Knowledge": (70, 79),
    "Coding": (90, 99),
    "Word Problems": (60, 69),
    "Logic": (80, 89),
    "Bias": (30, 39),
    "Safety": (40, 49),
    "Story Telling": (50, 54),
    "Summarization and Extraction": (55, 59),
    "Medical": (0, 19),
}

# Prompt texts, looked up by question id when rendering each response.
# JSON is decoded as UTF-8 explicitly instead of relying on the platform
# default encoding.
with open(os.path.join(PARENT_DIR, "k2-gallery.json"), "r", encoding="utf-8") as f:
    BOOK = json.load(f)
# Introductory copy shown under the page title (rendered as Markdown).
INTRO_MARKDOWN = """The K2 gallery allows one to browse the output of various prompts on intermediate K2 checkpoints, which provides an intuitive understanding on how the model develops and improves over time. This is inspired by [The Bloom Book](https://huggingface.co/spaces/bigscience/bloom-book).
**Arena**: select different checkpoint number in the arena to compare output of two different checkpoints.
**Prompts**: the prompts are collected online and manually, organized into a few categories, one can choose the category on the left.
**Prompt template**: for each prompt we run it 3 times, one without template, and 2 following the [URAIL](https://allenai.github.io/re-align/urial.html) prompt template. Choose the templates to see how results change. The URAIL-Safe template will promote safety behavior of the model, offensive or unsafe text may be generated otherwise."""

# Link used as the help text of the prompt-template radio.
prompt_style_help = "https://allenai.github.io/re-align/urial.html"

st.title("K2 Prompt Gallery")
st.markdown(INTRO_MARKDOWN)
with st.sidebar:
    # Embed the SVG logo as a base64 data URI. The file is read as raw bytes,
    # so no text decoding (and no platform-default encoding) is involved.
    with open(os.path.join(PARENT_DIR, "k2-logo.svg"), "rb") as f:
        b64 = b64encode(f.read()).decode("ascii")
    # The original markup here was an empty f-string, so the encoded logo was
    # never rendered; emit the <img> tag that actually uses it.
    html = f'<img src="data:image/svg+xml;base64,{b64}"/>'
    st.markdown(html, unsafe_allow_html=True)

    # Selected template; compared against the `template` field of responses.
    prompt_style = st.radio(
        "Select a Prompt Template",
        TEMPLATE_TYPES,
        help=prompt_style_help,
    )

    # Selected category; used to look up its question-id range.
    category = st.radio(
        "Choose a category",
        list(CATEGORY_RANGES),
    )

# Two side-by-side columns in the main area, one per compared checkpoint.
col1, col2 = st.columns(2)
def render_column(col_label):
    """Render one arena column: a checkpoint slider and the matching responses.

    The sidebar selections are read from the module-level ``category`` and
    ``prompt_style``. Output goes to whichever Streamlit container is active
    when the function is called.

    Args:
        col_label: Text appended to the column header. It is also used as the
            slider's widget key, so each column needs a distinct label.
    """
    st.header("Checkpoint " + col_label)
    ckpt = st.slider('Select the checkpoint number', 3, 360, 360, step=3, key=col_label)
    st.write('Viewing Responses for Checkpoint:', ckpt)

    # Inclusive question-id bounds of the selected category.
    q_left, q_right = CATEGORY_RANGES[category]
    results = TINDYDB.search(
        (QUERY_OBJ.ckpt == ckpt)
        & (QUERY_OBJ.template == prompt_style)
        & (QUERY_OBJ.qid >= q_left)
        & (QUERY_OBJ.qid <= q_right)
    )

    if not results:
        # Tell the user the selection matched nothing rather than leaving
        # the column blank.
        st.info("No responses found for this checkpoint, template and category.")
        return

    for obj in results:
        # The prompt text is the expander title; the model response is inside.
        # NOTE(review): assumes BOOK can be indexed by the stored qid as-is
        # (a list, or a dict with matching key type) — confirm against the data.
        with st.expander(BOOK[obj["qid"]]):
            st.write(obj["response"])
# Fill both arena columns; each label is passed straight to render_column.
for column, label in ((col1, 'A'), (col2, 'B')):
    with column:
        render_column(label)