# demo_test/monitor/clean_battle_data.py
# Uploaded by yuantao-infini-ai via huggingface_hub ("Upload folder using
# huggingface_hub", commit cf1798b).
"""
Clean chatbot arena battle log.
Usage:
python3 clean_battle_data.py --mode conv_release
"""
import argparse
import datetime
import json
import os
from pytz import timezone
import time
from tqdm import tqdm
from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS
from fastchat.utils import detect_language
# Log event types that record a battle outcome; only these rows are kept.
VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]

# Lower-cased substrings that get a battle dropped when they appear in either
# side's conversation. These are model/organization names whose mention could
# leak which model answered. Two canned error/moderation responses are also
# included, so those battles are dropped too.
IDENTITY_WORDS = [
    word.lower()
    for word in (
        "vicuna",
        "lmsys",
        "koala",
        "uc berkeley",
        "open assistant",
        "laion",
        "chatglm",
        "chatgpt",
        "openai",
        "anthropic",
        "claude",
        "bard",
        "palm",
        "lamda",
        "google",
        "llama",
        "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
        "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
    )
]
def get_log_files(
    max_num_files=None,
    log_dir="~/fastchat_logs",
    num_servers=None,
    year=2023,
    months=range(4, 12),
):
    """Collect the existing per-server daily conversation logs, oldest first.

    Looks for ``{log_dir}/server{i}/{YYYY-MM-DD}-conv.json`` for every real
    calendar day of each month in *months* of *year*, and for every server
    index ``0 .. num_servers - 1``.

    Args:
        max_num_files: if truthy, keep only the last this-many files.
            ``None`` or ``0`` keeps all of them.
        log_dir: root directory of the logs; a leading ``~`` is expanded.
        num_servers: number of ``server{i}`` sub-directories.
            Defaults to ``NUM_SERVERS``.
        year: year of the dates to scan.
        months: month numbers (1-12) to scan. The default ``range(4, 12)``
            covers April through November only.
            NOTE(review): December is not scanned by default — confirm this
            is intended.

    Returns:
        List of existing file paths, ordered by date, then by server index.
    """
    import calendar

    if num_servers is None:
        num_servers = NUM_SERVERS
    root = os.path.expanduser(log_dir)

    filenames = []
    for month in months:
        # Only probe real dates (the old fixed 1..32 scan also tried e.g. 04-31).
        num_days = calendar.monthrange(year, month)[1]
        for day in range(1, num_days + 1):
            date = f"{year}-{month:02d}-{day:02d}"
            for i in range(num_servers):
                name = os.path.join(root, f"server{i}", f"{date}-conv.json")
                if os.path.exists(name):
                    filenames.append(name)

    # `or` makes None and 0 both mean "all files"; slicing keeps the newest.
    max_num_files = max_num_files or len(filenames)
    return filenames[-max_num_files:]
def remove_html(raw):
    """Extract the model name from an ``"<h3>Label: name</h3>\\n"`` header.

    Strings that do not start with ``"<h3>"`` are returned unchanged. The
    closing ``"</h3>"`` tag and any trailing newlines are removed when
    present. If the header has no ``": "`` separator, the text inside the tag
    is returned as-is rather than being sliced at a bogus offset.
    """
    prefix, suffix = "<h3>", "</h3>"
    if not raw.startswith(prefix):
        return raw
    body = raw[len(prefix):].rstrip("\n")
    if body.endswith(suffix):
        body = body[: -len(suffix)]
    # Everything after the first ": " is the name (e.g. "Model A: vicuna").
    _label, sep, name = body.partition(": ")
    return name if sep else body
def to_openai_format(messages):
    """Convert a list of ``(role, content)`` pairs to OpenAI chat dicts.

    Roles are assigned purely by position: even indices become ``"user"``
    and odd indices become ``"assistant"``. The role stored in each pair is
    ignored; only ``pair[1]`` (the content) is used.
    """
    return [
        {"role": "assistant" if idx % 2 else "user", "content": pair[1]}
        for idx, pair in enumerate(messages)
    ]
def replace_model_name(old_name):
    """Rename legacy model names found in the logs to their newer names.

    Each rule is a plain substring replacement, applied in the order listed.
    """
    renames = (
        ("bard", "palm-2"),
        ("claude-v1", "claude-1"),
        ("claude-instant-v1", "claude-instant-1"),
        ("oasst-sft-1-pythia-12b", "oasst-pythia-12b"),
    )
    new_name = old_name
    for legacy, current in renames:
        new_name = new_name.replace(legacy, current)
    return new_name
def _read_vote_rows(log_files):
    """Return every JSON-lines row of *log_files* whose ``"type"`` is in VOTES.

    Opening a file is retried up to 5 times, 2 seconds apart, on
    FileNotFoundError. A file that still cannot be opened is skipped with a
    printed warning.
    """
    rows = []
    for filename in tqdm(log_files, desc="read files"):
        lines = None
        for _ in range(5):
            try:
                with open(filename) as fin:
                    lines = fin.readlines()
                break
            except FileNotFoundError:
                time.sleep(2)
        if lines is None:
            # Must not fall through: `lines` would be undefined on the first
            # file or, worse, still hold the previous file's lines and
            # double-count its votes.
            print(f"Warning: cannot read {filename}, skipping it.")
            continue
        for line in lines:
            row = json.loads(line)
            if row["type"] in VOTES:
                rows.append(row)
    return rows


def _has_leaked_identity(row):
    """Return True if either side's conversation contains an IDENTITY_WORDS entry.

    Only messages from each state's ``offset`` onward are scanned. They are
    lower-cased, and empty/None messages are skipped.
    """
    # Both sides are joined into one string (no separator), so a single
    # substring test per word covers the whole battle.
    text = "".join(
        msg.lower()
        for state in (row["states"][0], row["states"][1])
        for _role, msg in state["messages"][state["offset"]:]
        if msg
    )
    return any(word in text for word in IDENTITY_WORDS)


def clean_battle_data(log_files, exclude_model_names):
    """Turn raw arena vote logs into a list of cleaned battle records.

    A vote row is dropped when:
      - a model slot is None;
      - only one public model name is blank;
      - the public and hidden names of a named (non-anonymous) battle differ;
      - there is no message at the first state's offset;
      - an identity word appears in either conversation;
      - a model (after renaming) is in *exclude_model_names*.

    Args:
        log_files: paths of JSON-lines conversation logs.
        exclude_model_names: model names, compared after
            ``replace_model_name``, whose battles are dropped.

    Returns:
        Battle dicts sorted by ``tstamp``. Keys: question_id, model_a,
        model_b, winner, judge, conversation_a, conversation_b, turn, anony,
        language, tstamp.
    """
    data = _read_vote_rows(log_files)

    convert_type = {
        "leftvote": "model_a",
        "rightvote": "model_b",
        "tievote": "tie",
        "bothbad_vote": "tie (bothbad)",
    }
    excluded = set(exclude_model_names)

    all_models = set()
    all_ips = dict()
    ct_anony = 0
    ct_invalid = 0
    ct_leaked_identity = 0
    battles = []
    for row in data:
        if row["models"][0] is None or row["models"][1] is None:
            continue
        states = row["states"]

        # Resolve model names: the names shown publicly vs. the ones recorded
        # in the conversation states (used for anonymous battles).
        models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
        models_hidden = models_public
        if "model_name" in states[0]:
            hidden = [states[0]["model_name"], states[1]["model_name"]]
            if hidden[0] is not None:
                models_hidden = hidden

        # Exactly one blank public name is inconsistent.
        if (models_public[0] == "") != (models_public[1] == ""):
            ct_invalid += 1
            continue

        if models_public[0] == "" or models_public[0] == "Model A":
            anony = True
            models = models_hidden
            ct_anony += 1
        else:
            anony = False
            models = models_public
            if models_public != models_hidden:
                ct_invalid += 1
                continue

        # Detect language from the first message after the offset.
        state = states[0]
        if state["offset"] >= len(state["messages"]):
            ct_invalid += 1
            continue
        lang_code = detect_language(state["messages"][state["offset"]][1])

        # Drop conversations if the model names are leaked.
        if _has_leaked_identity(row):
            ct_leaked_identity += 1
            continue

        # Replace legacy names (e.g. bard -> palm-2).
        models = [replace_model_name(m) for m in models]

        # Exclude certain models.
        if any(m in excluded for m in models):
            ct_invalid += 1
            continue

        question_id = states[0]["conv_id"]
        conversation_a = to_openai_format(states[0]["messages"][states[0]["offset"]:])
        conversation_b = to_openai_format(states[1]["messages"][states[1]["offset"]:])

        # Map each IP to an integer id, assigned in order of first appearance.
        user_id = all_ips.setdefault(row["ip"], len(all_ips))

        battles.append(
            dict(
                question_id=question_id,
                model_a=models[0],
                model_b=models[1],
                winner=convert_type[row["type"]],
                judge=f"arena_user_{user_id}",
                conversation_a=conversation_a,
                conversation_b=conversation_b,
                turn=len(conversation_a) // 2,
                anony=anony,
                language=lang_code,
                tstamp=row["tstamp"],
            )
        )
        # Report the (renamed) names that actually appear in the output.
        all_models.update(models)

    battles.sort(key=lambda x: x["tstamp"])
    if battles:
        last_updated_datetime = datetime.datetime.fromtimestamp(
            battles[-1]["tstamp"], tz=timezone("US/Pacific")
        ).strftime("%Y-%m-%d %H:%M:%S %Z")
    else:
        last_updated_datetime = "N/A"

    print(
        f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
        f"#leaked_identity: {ct_leaked_identity}"
    )
    print(f"#battles: {len(battles)}, #anony: {ct_anony}")
    print(f"#models: {len(all_models)}, {all_models}")
    print(f"last-updated: {last_updated_datetime}")
    return battles
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    parser.add_argument(
        "--mode", type=str, choices=["simple", "conv_release"], default="simple"
    )
    parser.add_argument("--exclude-model-names", type=str, nargs="+")
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    battles = clean_battle_data(log_files, args.exclude_model_names or [])
    if not battles:
        # There is no timestamp to derive the output file name from.
        raise SystemExit("No valid battles found; nothing to write.")

    # The output file name is stamped with the date (US/Pacific) of the
    # newest battle.
    last_updated_tstamp = battles[-1]["tstamp"]
    cutoff_date = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y%m%d")

    if args.mode == "simple":
        # Drop the conversation text and ids; keep only the outcome metadata.
        for x in battles:
            for key in ("conversation_a", "conversation_b", "question_id"):
                del x[key]
        print("Samples:")
        for x in battles[:4]:
            print(x)
        output = f"clean_battle_{cutoff_date}.json"
    elif args.mode == "conv_release":
        # Only anonymous battles are kept in the conversation release.
        battles = [x for x in battles if x["anony"]]
        output = f"clean_battle_conv_{cutoff_date}.json"

    # ensure_ascii=False writes raw non-ASCII text, so the file encoding
    # must be explicit instead of the platform default.
    with open(output, "w", encoding="utf-8") as fout:
        json.dump(battles, fout, indent=2, ensure_ascii=False)
    print(f"Write cleaned data to {output}")