arena / db.py
Kang Suhyun
[#37] Store ELO ratings in DB after calculation (#112)
5352a13 unverified
"""
This module manages database access: it reads and writes model ratings and
reads battle records stored in Firestore.
"""
from dataclasses import dataclass
import enum
import os
from typing import List
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
from google.cloud.firestore_v1 import base_query
import gradio as gr
from credentials import get_credentials_json
def get_required_env(name: str) -> str:
    """Return the value of the environment variable ``name``.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The variable's value. An empty string counts as set and is returned
        as is.

    Raises:
        ValueError: If the variable is not set.
    """
    env_value = os.environ.get(name)
    if env_value is not None:
        return env_value
    raise ValueError(f"Environment variable {name} is not set")
# Firestore collection names, read from the environment when this module is
# imported. get_required_env raises ValueError if any of them is not set.
RATINGS_COLLECTION = get_required_env("RATINGS_COLLECTION")
SUMMARIZATIONS_COLLECTION = get_required_env("SUMMARIZATIONS_COLLECTION")
TRANSLATIONS_COLLECTION = get_required_env("TRANSLATIONS_COLLECTION")

# NOTE(review): presumably guarded by gr.NO_RELOAD so the Firebase app is not
# initialized a second time when Gradio reloads this module — confirm.
if gr.NO_RELOAD:
    firebase_admin.initialize_app(credentials.Certificate(get_credentials_json()))

# Firestore client used by every function in this module.
db = firestore.client()
class Category(enum.Enum):
    """Task categories for which ratings and battles are stored.

    The member value is the first component of a ratings document ID.
    """

    # Battles are read from the summarizations collection.
    SUMMARIZATION = "summarization"

    # Battles are read from the translations collection.
    TRANSLATION = "translation"
@dataclass
class Rating:
    """A single model's rating, as stored in a ratings document."""

    # Model name; used as the field name in the ratings document.
    model: str

    # Rating value stored under the model's field.
    rating: int
def get_ratings(category: Category, source_lang: str | None,
                target_lang: str | None) -> List[Rating] | None:
    """Fetch the stored ratings for a category and optional language pair.

    The document ID is the category value followed by each non-empty
    language, lowercased, joined with "#".

    Args:
        category: Category whose ratings are requested.
        source_lang: Source language, or None/empty to leave it out of the
            document ID.
        target_lang: Target language, or None/empty to leave it out of the
            document ID.

    Returns:
        One Rating per model field in the document, or None if the document
        has no data.
    """
    # set_ratings lowercases the language parts of the document ID, so the
    # lookup must do the same; otherwise mixed-case input would never match.
    lang_parts = [lang.lower() for lang in (source_lang, target_lang) if lang]
    doc_id = "#".join([category.value, *lang_parts])

    # TODO(#37): Make it more clear what fields are in the document.
    snapshot = db.collection(RATINGS_COLLECTION).document(doc_id).get()
    doc_dict = snapshot.to_dict()
    if doc_dict is None:
        return None

    # TODO(#37): Return the timestamp as well.
    # Pop with a default so a document without the field does not raise
    # KeyError.
    doc_dict.pop("timestamp", None)
    return [Rating(model, rating) for model, rating in doc_dict.items()]
def set_ratings(category: Category, ratings: List[Rating], source_lang: str,
                target_lang: str | None):
    """Write ratings to the document for a category and language pair.

    The document ID is the category value, the lowercased source language
    and, when given, the lowercased target language, joined with "#".

    Args:
        category: Category the ratings belong to.
        ratings: Ratings to store; each model name becomes a document field.
        source_lang: Source language (required).
        target_lang: Target language, or None/empty to leave it out of the
            document ID.
    """
    lowered_source = source_lang.lower()
    lowered_target = target_lang.lower() if target_lang else None

    id_parts = [category.value, lowered_source]
    if lowered_target:
        id_parts.append(lowered_target)
    document = db.collection(RATINGS_COLLECTION).document("#".join(id_parts))

    payload = {}
    for entry in ratings:
        payload[entry.model] = entry.rating
    # Set after the model fields, so it overrides a model named "timestamp".
    payload["timestamp"] = firestore.SERVER_TIMESTAMP

    # merge=True updates the given fields instead of replacing the document.
    document.set(payload, merge=True)
@dataclass
class Battle:
    """The outcome of one comparison between two models."""

    # First model in the comparison.
    model_a: str

    # Second model in the comparison.
    model_b: str

    # Recorded result of the battle.
    # NOTE(review): the set of allowed values is not visible here — confirm.
    winner: str
def get_battles(category: Category, source_lang: str | None,
                target_lang: str | None) -> List[Battle]:
    """Load battles for a category, ordered by their "timestamp" field.

    Languages are lowercased before filtering. For summarization, a given
    source language must match both "model_a_response_language" and
    "model_b_response_language"; target_lang is not used. For translation,
    a given source language filters "source_language" and a given target
    language filters "target_language".

    Args:
        category: Category whose battles are requested.
        source_lang: Source language filter, or None/empty for no filter.
        target_lang: Target language filter, or None/empty for no filter.

    Returns:
        One Battle per matching document.

    Raises:
        ValueError: If category is neither SUMMARIZATION nor TRANSLATION.
    """
    source = source_lang.lower() if source_lang else None
    target = target_lang.lower() if target_lang else None

    # (field, value) equality conditions, applied to the query in order.
    conditions = []
    if category == Category.SUMMARIZATION:
        collection_name = SUMMARIZATIONS_COLLECTION
        if source:
            conditions.append(("model_a_response_language", source))
            conditions.append(("model_b_response_language", source))
    elif category == Category.TRANSLATION:
        collection_name = TRANSLATIONS_COLLECTION
        if source:
            conditions.append(("source_language", source))
        if target:
            conditions.append(("target_language", target))
    else:
        raise ValueError(f"Invalid category: {category}")

    query = db.collection(collection_name).order_by("timestamp")
    for field, value in conditions:
        query = query.where(filter=base_query.FieldFilter(field, "==", value))

    return [
        Battle(record["model_a"], record["model_b"], record["winner"])
        for record in (snapshot.to_dict() for snapshot in query.stream())
    ]