"""
This module manages the Firestore database that stores model ratings and
battle results.
"""
from dataclasses import dataclass | |
import enum | |
import os | |
from typing import List | |
import firebase_admin | |
from firebase_admin import credentials | |
from firebase_admin import firestore | |
from google.cloud.firestore_v1 import base_query | |
import gradio as gr | |
from credentials import get_credentials_json | |
def get_required_env(name: str) -> str:
    """Return the value of the environment variable *name*.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The variable's value. An empty string counts as set and is returned.

    Raises:
        ValueError: If the variable is not present in the environment.
    """
    env_value = os.getenv(name)
    if env_value is not None:
        return env_value
    raise ValueError(f"Environment variable {name} is not set")
# Firestore collection names, read from the environment at import time.
# get_required_env raises ValueError for an unset variable, so a missing
# setting makes importing this module fail.
RATINGS_COLLECTION = get_required_env("RATINGS_COLLECTION")
SUMMARIZATIONS_COLLECTION = get_required_env("SUMMARIZATIONS_COLLECTION")
TRANSLATIONS_COLLECTION = get_required_env("TRANSLATIONS_COLLECTION")

# Initialize the Firebase app only under gr.NO_RELOAD.
# NOTE(review): presumably this keeps Gradio's reload mode from initializing
# the default Firebase app a second time - confirm.
if gr.NO_RELOAD:
    firebase_admin.initialize_app(credentials.Certificate(get_credentials_json()))

# Firestore client used by every function in this module.
# NOTE(review): assumed to be created outside the NO_RELOAD guard so that `db`
# is always bound - confirm against the intended indentation.
db = firestore.client()
class Category(enum.Enum):
    """Task categories for which ratings and battles are stored.

    Each member's value is the string used as the first component of a
    ratings document ID.
    """

    SUMMARIZATION = "summarization"
    TRANSLATION = "translation"
@dataclass
class Rating:
    """Rating of a single model.

    Declared as a dataclass so that instances can be built positionally,
    e.g. ``Rating(model, rating)``, as get_ratings does.

    Attributes:
        model: Name of the rated model.
        rating: Rating value stored for the model.
    """

    model: str
    rating: int
def get_ratings(category: Category, source_lang: str | None,
                target_lang: str | None) -> List[Rating] | None:
    """Fetch the stored ratings for a category and optional language pair.

    The document ID is the category value followed by every non-empty
    language, joined with "#". Languages are lowercased because set_ratings
    writes documents under lowercased IDs; without this a mixed-case
    language could never match a stored document.

    Args:
        category: Category whose ratings are requested.
        source_lang: Source language, or None/empty to leave it out of the ID.
        target_lang: Target language, or None/empty to leave it out of the ID.

    Returns:
        One Rating per model field in the document, or None when the
        document does not exist.
    """
    doc_id = "#".join(
        [category.value] +
        [lang.lower() for lang in (source_lang, target_lang) if lang])

    # TODO(#37): Make it more clear what fields are in the document.
    doc_dict = db.collection(RATINGS_COLLECTION).document(doc_id).get().to_dict()
    if doc_dict is None:
        return None

    # TODO(#37): Return the timestamp as well.
    # Every remaining field is treated as a model rating, so drop the
    # bookkeeping field; tolerate documents that lack it.
    doc_dict.pop("timestamp", None)

    return [Rating(model, rating) for model, rating in doc_dict.items()]
def set_ratings(category: Category, ratings: List[Rating], source_lang: str,
                target_lang: str | None):
    """Write model ratings to the ratings document of a category/language pair.

    The document ID is the category value, the lowercased source language
    and, when given, the lowercased target language, joined with "#". The
    ratings are written with ``merge=True`` together with a server-side
    timestamp.

    Args:
        category: Category the ratings belong to.
        ratings: Ratings to store, keyed in the document by model name.
        source_lang: Source language; always part of the document ID.
        target_lang: Target language, or None/empty to leave it out of the ID.
    """
    source_key = source_lang.lower()
    target_key = target_lang.lower() if target_lang else None

    id_parts = [category.value, source_key]
    if target_key:
        id_parts.append(target_key)
    document = db.collection(RATINGS_COLLECTION).document("#".join(id_parts))

    payload = {
        **{entry.model: entry.rating for entry in ratings},
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    document.set(payload, merge=True)
@dataclass
class Battle:
    """Outcome of one comparison between two models.

    Declared as a dataclass so that instances can be built positionally,
    e.g. ``Battle(model_a, model_b, winner)``, as get_battles does.

    Attributes:
        model_a: Value of the stored ``model_a`` field.
        model_b: Value of the stored ``model_b`` field.
        winner: Value of the stored ``winner`` field.
    """

    model_a: str
    model_b: str
    winner: str
def get_battles(category: Category, source_lang: str | None,
                target_lang: str | None) -> List[Battle]:
    """Load the battles of a category, ordered by timestamp.

    Languages are lowercased before being used as equality filters. For
    summarization only the source language is used, and it must match the
    response language of both models. For translation the source and target
    languages filter their respective fields independently.

    Args:
        category: Category whose battles are requested.
        source_lang: Source language to filter on, or None/empty for no filter.
        target_lang: Target language to filter on, or None/empty for no
            filter. Not used for summarization.

    Returns:
        One Battle per matching document.

    Raises:
        ValueError: If the category is neither summarization nor translation.
    """
    source_key = source_lang.lower() if source_lang else None
    target_key = target_lang.lower() if target_lang else None

    # Pick the collection and the (field, value) equality filters to apply.
    if category == Category.SUMMARIZATION:
        collection_name = SUMMARIZATIONS_COLLECTION
        equality_filters = []
        if source_key:
            equality_filters = [
                ("model_a_response_language", source_key),
                ("model_b_response_language", source_key),
            ]
    elif category == Category.TRANSLATION:
        collection_name = TRANSLATIONS_COLLECTION
        equality_filters = [
            (field, value)
            for field, value in (("source_language", source_key),
                                 ("target_language", target_key))
            if value
        ]
    else:
        raise ValueError(f"Invalid category: {category}")

    query = db.collection(collection_name).order_by("timestamp")
    for field, value in equality_filters:
        query = query.where(filter=base_query.FieldFilter(field, "==", value))

    return [
        Battle(record["model_a"], record["model_b"], record["winner"])
        for record in (doc.to_dict() for doc in query.stream())
    ]