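"""Streamlit app that mirrors TensorBoard event files from Hugging Face Hub repos.

It lists the author's model repos, downloads their `events.out.tfevents*`
files into a local log directory, and keeps a small dataframe as a download
cache so files that are unchanged on the Hub are skipped on reruns.
"""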
import os
import shutil

import pandas as pd
import streamlit as st
# from streamlit_tensorboard import st_tensorboard
from huggingface_hub import HfApi, list_models
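# NOTE: this assumes a huggingface_hub release that still provides
# HfApi.list_files_info; on newer releases, HfApi.get_paths_info is the
# equivalent call for fetching per-file metadata.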
# ==============================================================
st.set_page_config(layout="wide")
# ==============================================================
logdir = "/tmp/tensorboard_logs"
os.makedirs(logdir, exist_ok=True)
def clean_logdir(logdir):
    """Delete the whole log directory, ignoring any filesystem errors."""
    try:
        shutil.rmtree(logdir)
    except Exception as e:
        print(e)
@st.cache_data
def get_models():
    """List the author's emofs2 repos, most recently modified first.

    Cached so the "Clean all" button below can invalidate it via .clear().
    """
    _author = "hahunavth"
    _filter = "emofs2"
    return list(
        list_models(author=_author, filter=_filter, sort="last_modified", direction=-1)
    )
TB_FILE_PREFIX = "events.out.tfevents"
def download_repo_tb(repo_id, api, log_dir, df):
    """Download a repo's TensorBoard event files into log_dir.

    Files already recorded in the cache dataframe (same repo, path, and
    size) and present on disk are skipped; returns df with new rows added.
    """
    repo_name = repo_id.split("/")[-1]
    if api.repo_exists(repo_id):
        files = api.list_repo_files(repo_id)
    else:
        raise ValueError(f"Repo {repo_id} does not exist")
    # Keep only TensorBoard event files that actually have content.
    tb_files = [f for f in files if f.split("/")[-1].startswith(TB_FILE_PREFIX)]
    tb_files_info = list(api.list_files_info(repo_id, tb_files))
    tb_files_info = [f for f in tb_files_info if f.size > 0]
    for repo_file in tb_files_info:
        path = repo_file.path
        size = repo_file.size
        fname = path.split("/")[-1]
        # Subfolder of the file inside the repo; None for root-level files
        # (a plain replace() would wrongly leave the filename behind here).
        sub_folder = os.path.dirname(path) or None
        cached = ((df["repo"] == repo_name) & (df["path"] == path) & (df["size"] == size)).any()
        if cached and os.path.exists(os.path.join(log_dir, repo_name, path)):
            print(f"Skipping {repo_name}/{path}")
            continue
        print(f"Downloading {repo_name}/{path}")
        api.hf_hub_download(
            repo_id=repo_id,
            filename=fname,
            subfolder=sub_folder,
            local_dir=os.path.join(log_dir, repo_name),
        )
        new_df = pd.DataFrame([{
            "repo": repo_name,
            "path": path,
            "size": size,
        }])
        df = pd.concat([df, new_df], ignore_index=True)
    return df
@st.cache_data
def create_cache_dataframe():
    """Empty download-cache frame; cached so "Clean all" can .clear() it."""
    return pd.DataFrame(columns=["repo", "path", "size"])
# ==============================================================
api = HfApi()
df = create_cache_dataframe()
models = get_models()
model_ids = [model.id for model in models]
# Select one or more repos and fetch their logs.
with st.expander("Download tf", expanded=True):
    with st.form("my_form"):
        selected_models = st.multiselect("Select models", model_ids, default=None)
        submit = st.form_submit_button("Download logs")
        if submit:
            # Download the TensorBoard logs of every selected repo.
            with st.spinner("Downloading logs..."):
                for model_id in selected_models:
                    st.write(f"Downloading logs for {model_id}")
                    df = download_repo_tb(model_id, api, logdir, df)
            st.write("Done")
    # "Clean all" sits outside the form, since st.button is not allowed inside one.
    clean_btn = st.button("Clean all")
    if clean_btn:
        clean_logdir(logdir)
        # Drop the cached model list and the download-cache frame.
        create_cache_dataframe.clear()
        get_models.clear()
# with st.expander("...", expanded=True):
#     st_tensorboard(logdir=logdir, port=6006, width=1760, scrolling=False)
# st.text(st)
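# To render the downloaded logs inline, install the streamlit-tensorboard
# package and uncomment both the import at the top and the st_tensorboard
# block above.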