clefourrier (HF staff) and alozowski (HF staff) committed
Commit
6b87e28
1 Parent(s): 8ff5577

dataframe-improvement (#671)


- Updated init_space() mostly (e34e357137b1ac54e7f2db292b77c14d4c7cf0ed)
- Updated collections.py (2293858bb036fc9f69040d0210b6db1678b7bdf9)
- Updated populate.py (6b9cbbe716f8f1b4c4d5c3925fbc1d1c27381b5f)
- Updated gitignore (122c7afd045b064431b1ae27c3c543c9dbd1a482)
- bugfix and populate refactoring (2e74c81428ac062c254bc55b88eadf06d877f532)
- updated utils.py (f073c67652ed110738fb31ccb2abf2dd2c2b5156)
- removed comments from populate.py (79ad1ade160afd2ba0f95bd1dec9e8534121f132)
- fixing envs CACHE_PATH check (63dac32758e6a31233c5c57913eda3b53e53c266)
- debugging CACHE_PATH in envs.py (6a5081fbccfd95fb301ba4d8cb446e2c101b337c)
- debugging CACHE_PATH in envs.py (e243a5f654ee69c2a62cb3dd438a7dedc3631c22)
- debugging CACHE_PATH in envs.py (5a8f7dc96273e99cd0895dc13e4a4e476c1eb629)
- small fixes (d8bf61b20d803025270c3395b0b0bf1d68af5576)


Co-authored-by: Alina Lozovskaya <[email protected]>

Files changed (7)
  1. .gitignore +5 -0
  2. .python-version +0 -1
  3. app.py +39 -45
  4. src/display/utils.py +11 -1
  5. src/envs.py +15 -4
  6. src/populate.py +38 -51
  7. src/tools/collections.py +48 -53
.gitignore CHANGED
@@ -1,10 +1,15 @@
 venv/
+.venv/
 __pycache__/
 .env
 .ipynb_checkpoints
 *ipynb
 .vscode/
 .DS_Store
+.ruff_cache/
+.python-version
+.profile_app.python
+*pstats

 eval-queue/
 eval-results/
.python-version DELETED
@@ -1 +0,0 @@
-3.10.0
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import logging
 import gradio as gr
 import pandas as pd
 from apscheduler.schedulers.background import BackgroundScheduler
@@ -47,6 +48,7 @@ from src.submission.submit import add_new_eval
 from src.tools.collections import update_collections
 from src.tools.plots import create_metric_plot_obj, create_plot_df, create_scores_df

+
 # Start ephemeral Spaces on PRs (see config in README.md)
 enable_space_ci()

@@ -55,44 +57,34 @@ def restart_space():
     API.restart_space(repo_id=REPO_ID, token=H4_TOKEN)


-def init_space(full_init: bool = True):
-    if full_init:
-        try:
-            print(EVAL_REQUESTS_PATH)
-            snapshot_download(
-                repo_id=QUEUE_REPO,
-                local_dir=EVAL_REQUESTS_PATH,
-                repo_type="dataset",
-                tqdm_class=None,
-                etag_timeout=30,
-                max_workers=8,
-            )
-        except Exception:
-            restart_space()
+def download_dataset(repo_id, local_dir, repo_type="dataset", max_attempts=3):
+    """Attempt to download dataset with retries."""
+    attempt = 0
+    while attempt < max_attempts:
         try:
-            print(DYNAMIC_INFO_PATH)
+            print(f"Downloading {repo_id} to {local_dir}")
             snapshot_download(
-                repo_id=DYNAMIC_INFO_REPO,
-                local_dir=DYNAMIC_INFO_PATH,
-                repo_type="dataset",
+                repo_id=repo_id,
+                local_dir=local_dir,
+                repo_type=repo_type,
                 tqdm_class=None,
                 etag_timeout=30,
                 max_workers=8,
             )
-        except Exception:
-            restart_space()
-        try:
-            print(EVAL_RESULTS_PATH)
-            snapshot_download(
-                repo_id=RESULTS_REPO,
-                local_dir=EVAL_RESULTS_PATH,
-                repo_type="dataset",
-                tqdm_class=None,
-                etag_timeout=30,
-                max_workers=8,
-            )
-        except Exception:
-            restart_space()
+            return
+        except Exception as e:
+            logging.error(f"Error downloading {repo_id}: {e}")
+            attempt += 1
+            if attempt == max_attempts:
+                restart_space()
+
+
+def init_space(full_init: bool = True):
+    """Initializes the application space, loading only necessary data."""
+    if full_init:
+        download_dataset(QUEUE_REPO, EVAL_REQUESTS_PATH)
+        download_dataset(DYNAMIC_INFO_REPO, DYNAMIC_INFO_PATH)
+        download_dataset(RESULTS_REPO, EVAL_RESULTS_PATH)

     raw_data, original_df = get_leaderboard_df(
         results_path=EVAL_RESULTS_PATH,
@@ -101,18 +93,12 @@ def init_space(full_init: bool = True):
         cols=COLS,
         benchmark_cols=BENCHMARK_COLS,
     )
-    update_collections(original_df.copy())
+    update_collections(original_df)
     leaderboard_df = original_df.copy()
+
+    eval_queue_dfs = get_evaluation_queue_df(EVAL_REQUESTS_PATH, EVAL_COLS)

-    plot_df = create_plot_df(create_scores_df(raw_data))
-
-    (
-        finished_eval_queue_df,
-        running_eval_queue_df,
-        pending_eval_queue_df,
-    ) = get_evaluation_queue_df(EVAL_REQUESTS_PATH, EVAL_COLS)
-
-    return leaderboard_df, original_df, plot_df, finished_eval_queue_df, running_eval_queue_df, pending_eval_queue_df
+    return leaderboard_df, raw_data, original_df, eval_queue_dfs


 # Convert the environment variable "LEADERBOARD_FULL_INIT" to a boolean value, defaulting to True if the variable is not set.
@@ -121,9 +107,14 @@ do_full_init = os.getenv("LEADERBOARD_FULL_INIT", "True") == "True"

 # Calls the init_space function with the `full_init` parameter determined by the `do_full_init` variable.
 # This initializes various DataFrames used throughout the application, with the level of initialization detail controlled by the `do_full_init` flag.
-leaderboard_df, original_df, plot_df, finished_eval_queue_df, running_eval_queue_df, pending_eval_queue_df = (
-    init_space(full_init=do_full_init)
-)
+leaderboard_df, raw_data, original_df, eval_queue_dfs = init_space(full_init=do_full_init)
+finished_eval_queue_df, running_eval_queue_df, pending_eval_queue_df = eval_queue_dfs
+
+
+# Data processing for plots now only on demand in the respective Gradio tab
+def load_and_create_plots():
+    plot_df = create_plot_df(create_scores_df(raw_data))
+    return plot_df


 # Searching and filtering
@@ -406,6 +397,7 @@ with demo:
         with gr.TabItem("📈 Metrics through time", elem_id="llm-benchmark-tab-table", id=2):
             with gr.Row():
                 with gr.Column():
+                    plot_df = load_and_create_plots()
                     chart = create_metric_plot_obj(
                         plot_df,
                         [AutoEvalColumn.average.name],
@@ -413,12 +405,14 @@ with demo:
                     )
                     gr.Plot(value=chart, min_width=500)
                 with gr.Column():
+                    plot_df = load_and_create_plots()
                     chart = create_metric_plot_obj(
                         plot_df,
                         BENCHMARK_COLS,
                         title="Top Scores and Human Baseline Over Time (from last update)",
                     )
                     gr.Plot(value=chart, min_width=500)
+
         with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
             gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")

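Note on the app.py change: the new download_dataset helper folds three copy-pasted snapshot_download blocks into one retry loop that only calls restart_space() once the final attempt has failed. Below is a minimal, self-contained sketch of that retry-then-fallback pattern in isolation; the names attempt_with_retries and fallback are illustrative and not part of the commit.

import logging
from typing import Callable


def attempt_with_retries(action: Callable[[], None], fallback: Callable[[], None], max_attempts: int = 3) -> None:
    """Run `action`; retry up to `max_attempts` times, then invoke `fallback` (e.g. a space restart)."""
    for attempt in range(1, max_attempts + 1):
        try:
            action()
            return
        except Exception as e:  # broad catch, mirroring the helper in app.py
            logging.error(f"Attempt {attempt}/{max_attempts} failed: {e}")
    fallback()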
src/display/utils.py CHANGED
@@ -1,9 +1,19 @@
 from dataclasses import dataclass, make_dataclass
 from enum import Enum
-
+import json
 import pandas as pd


+def load_json_data(file_path):
+    """Safely load JSON data from a file."""
+    try:
+        with open(file_path, "r") as file:
+            return json.load(file)
+    except json.JSONDecodeError:
+        print(f"Error reading JSON from {file_path}")
+        return None  # Or raise an exception
+
+
 def fields(raw_class):
     return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]

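A short usage sketch for the new load_json_data helper (the file path below is made up): callers get None back when a file contains invalid JSON and must handle that case themselves, while a missing file still raises FileNotFoundError because only json.JSONDecodeError is caught.

from src.display.utils import load_json_data

data = load_json_data("eval-queue/some-org/some-model_eval_request.json")  # hypothetical path
if data is None:
    print("Skipping file with malformed JSON")
else:
    print(data.get("status"))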
src/envs.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import logging

 from huggingface_hub import HfApi

@@ -15,11 +16,21 @@ PRIVATE_RESULTS_REPO = "open-llm-leaderboard/private-results"

 IS_PUBLIC = bool(os.environ.get("IS_PUBLIC", True))

-CACHE_PATH = os.getenv("HF_HOME", ".")
+HF_HOME = os.getenv("HF_HOME", ".")

-EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
-EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
-DYNAMIC_INFO_PATH = os.path.join(CACHE_PATH, "dynamic-info")
+# Check HF_HOME write access
+print(f"Initial HF_HOME set to: {HF_HOME}")
+
+if not os.access(HF_HOME, os.W_OK):
+    print(f"No write access to HF_HOME: {HF_HOME}. Resetting to current directory.")
+    HF_HOME = "."
+    os.environ["HF_HOME"] = HF_HOME
+else:
+    print(f"Write access confirmed for HF_HOME")
+
+EVAL_REQUESTS_PATH = os.path.join(HF_HOME, "eval-queue")
+EVAL_RESULTS_PATH = os.path.join(HF_HOME, "eval-results")
+DYNAMIC_INFO_PATH = os.path.join(HF_HOME, "dynamic-info")
 DYNAMIC_INFO_FILE_PATH = os.path.join(DYNAMIC_INFO_PATH, "model_infos.json")

 EVAL_REQUESTS_PATH_PRIVATE = "eval-queue-private"
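The envs.py change renames CACHE_PATH to HF_HOME and falls back to the working directory when the configured location is not writable, which is what the "debugging CACHE_PATH" commits above were chasing. A standalone sketch of the same fallback check, wrapped in a hypothetical resolve_cache_dir helper (not part of the commit):

import os


def resolve_cache_dir(env_var: str = "HF_HOME", default: str = ".") -> str:
    """Return a writable cache directory, falling back to the current directory."""
    cache_dir = os.getenv(env_var, default)
    if not os.access(cache_dir, os.W_OK):
        cache_dir = "."
        os.environ[env_var] = cache_dir  # keep child libraries consistent with the fallback
    return cache_dir


eval_requests_path = os.path.join(resolve_cache_dir(), "eval-queue")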
src/populate.py CHANGED
@@ -1,68 +1,55 @@
 import json
 import os
-
+import pathlib
 import pandas as pd
-
 from src.display.formatting import has_no_nan_values, make_clickable_model
 from src.display.utils import AutoEvalColumn, EvalQueueColumn, baseline_row
 from src.leaderboard.filter_models import filter_models_flags
 from src.leaderboard.read_evals import get_raw_eval_results
+from src.display.utils import load_json_data
+
+
+def _process_model_data(entry, model_name_key="model", revision_key="revision"):
+    """Enrich model data with clickable links and revisions."""
+    entry[EvalQueueColumn.model.name] = make_clickable_model(entry.get(model_name_key, ""))
+    entry[EvalQueueColumn.revision.name] = entry.get(revision_key, "main")
+    return entry


-def get_leaderboard_df(
-    results_path: str, requests_path: str, dynamic_path: str, cols: list, benchmark_cols: list
-) -> pd.DataFrame:
-    raw_data = get_raw_eval_results(results_path=results_path, requests_path=requests_path, dynamic_path=dynamic_path)
-    all_data_json = [v.to_dict() for v in raw_data]
-    all_data_json.append(baseline_row)
-    print([data for data in all_data_json if data["model_name_for_query"] == "databricks/dbrx-base"])
+def get_evaluation_queue_df(save_path, cols):
+    """Generate dataframes for pending, running, and finished evaluation entries."""
+    save_path = pathlib.Path(save_path)
+    all_evals = []
+
+    for path in save_path.rglob('*.json'):
+        data = load_json_data(path)
+        if data:
+            all_evals.append(_process_model_data(data))
+
+    # Organizing data by status
+    status_map = {
+        "PENDING": ["PENDING", "RERUN"],
+        "RUNNING": ["RUNNING"],
+        "FINISHED": ["FINISHED", "PENDING_NEW_EVAL"],
+    }
+    status_dfs = {status: [] for status in status_map}
+    for eval_data in all_evals:
+        for status, extra_statuses in status_map.items():
+            if eval_data["status"] in extra_statuses:
+                status_dfs[status].append(eval_data)
+
+    return tuple(pd.DataFrame(status_dfs[status], columns=cols) for status in ["FINISHED", "RUNNING", "PENDING"])
+
+
+def get_leaderboard_df(results_path, requests_path, dynamic_path, cols, benchmark_cols):
+    """Retrieve and process leaderboard data."""
+    raw_data = get_raw_eval_results(results_path, requests_path, dynamic_path)
+    all_data_json = [model.to_dict() for model in raw_data] + [baseline_row]
     filter_models_flags(all_data_json)

     df = pd.DataFrame.from_records(all_data_json)
-    print(df.columns)
-    print(df[df["model_name_for_query"] == "databricks/dbrx-base"])
     df = df.sort_values(by=[AutoEvalColumn.average.name], ascending=False)
     df = df[cols].round(decimals=2)
-
-    # filter out if any of the benchmarks have not been produced
     df = df[has_no_nan_values(df, benchmark_cols)]
     return raw_data, df

-
-def get_evaluation_queue_df(save_path: str, cols: list) -> list[pd.DataFrame]:
-    entries = [entry for entry in os.listdir(save_path) if not entry.startswith(".")]
-    all_evals = []
-
-    for entry in entries:
-        if ".json" in entry:
-            file_path = os.path.join(save_path, entry)
-            with open(file_path) as fp:
-                data = json.load(fp)
-
-            data[EvalQueueColumn.model.name] = make_clickable_model(data["model"])
-            data[EvalQueueColumn.revision.name] = data.get("revision", "main")
-
-            all_evals.append(data)
-        elif ".md" not in entry:
-            # this is a folder
-            sub_entries = [e for e in os.listdir(f"{save_path}/{entry}") if not e.startswith(".")]
-            for sub_entry in sub_entries:
-                file_path = os.path.join(save_path, entry, sub_entry)
-                with open(file_path) as fp:
-                    try:
-                        data = json.load(fp)
-                    except json.JSONDecodeError:
-                        print(f"Error reading {file_path}")
-                        continue
-
-                data[EvalQueueColumn.model.name] = make_clickable_model(data["model"])
-                data[EvalQueueColumn.revision.name] = data.get("revision", "main")
-                all_evals.append(data)
-
-    pending_list = [e for e in all_evals if e["status"] in ["PENDING", "RERUN"]]
-    running_list = [e for e in all_evals if e["status"] == "RUNNING"]
-    finished_list = [e for e in all_evals if e["status"].startswith("FINISHED") or e["status"] == "PENDING_NEW_EVAL"]
-    df_pending = pd.DataFrame.from_records(pending_list, columns=cols)
-    df_running = pd.DataFrame.from_records(running_list, columns=cols)
-    df_finished = pd.DataFrame.from_records(finished_list, columns=cols)
-    return df_finished[cols], df_running[cols], df_pending[cols]
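get_evaluation_queue_df now walks the queue with pathlib.Path.rglob and buckets entries by status in a single pass instead of nested os.listdir loops. An illustrative sketch of that bucketing step on toy data (the entries below are invented, not real leaderboard data):

import pandas as pd

status_map = {
    "PENDING": ["PENDING", "RERUN"],
    "RUNNING": ["RUNNING"],
    "FINISHED": ["FINISHED", "PENDING_NEW_EVAL"],
}
entries = [  # toy queue entries
    {"model": "org/model-a", "status": "RUNNING"},
    {"model": "org/model-b", "status": "RERUN"},
    {"model": "org/model-c", "status": "FINISHED"},
]

buckets = {bucket: [] for bucket in status_map}
for entry in entries:
    for bucket, statuses in status_map.items():
        if entry["status"] in statuses:
            buckets[bucket].append(entry)

finished_df, running_df, pending_df = (
    pd.DataFrame(buckets[b], columns=["model", "status"]) for b in ("FINISHED", "RUNNING", "PENDING")
)
print(len(finished_df), len(running_df), len(pending_df))  # 1 1 1

One nuance visible in the diff: the old code matched any status starting with "FINISHED", while the new status_map matches "FINISHED" exactly.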
src/tools/collections.py CHANGED
@@ -17,65 +17,60 @@ intervals = {
 }


-def update_collections(df: DataFrame):
-    """This function updates the Open LLM Leaderboard model collection with the latest best models for
-    each size category and type.
-    """
-    collection = get_collection(collection_slug=PATH_TO_COLLECTION, token=H4_TOKEN)
+def _filter_by_type_and_size(df, model_type, size_interval):
+    """Filter DataFrame by model type and parameter size interval."""
+    type_emoji = model_type.value.symbol[0]
+    filtered_df = df[df[AutoEvalColumn.model_type_symbol.name] == type_emoji]
     params_column = pd.to_numeric(df[AutoEvalColumn.params.name], errors="coerce")
+    mask = params_column.apply(lambda x: x in size_interval)
+    return filtered_df.loc[mask]

-    cur_best_models = []

-    ix = 0
-    for type in ModelType:
-        if type.value.name == "":
+def _add_models_to_collection(collection, models, model_type, size):
+    """Add best models to the collection and update positions."""
+    cur_len_collection = len(collection.items)
+    for ix, model in enumerate(models, start=1):
+        try:
+            collection = add_collection_item(
+                PATH_TO_COLLECTION,
+                item_id=model,
+                item_type="model",
+                exists_ok=True,
+                note=f"Best {model_type.to_str(' ')} model of around {size} on the leaderboard today!",
+                token=H4_TOKEN,
+            )
+            # Ensure position is correct if item was added
+            if len(collection.items) > cur_len_collection:
+                item_object_id = collection.items[-1].item_object_id
+                update_collection_item(collection_slug=PATH_TO_COLLECTION, item_object_id=item_object_id, position=ix)
+                cur_len_collection = len(collection.items)
+            break  # assuming we only add the top model
+        except HfHubHTTPError:
             continue
-        for size in intervals:
-            # We filter the df to gather the relevant models
-            type_emoji = [t[0] for t in type.value.symbol]
-            filtered_df = df[df[AutoEvalColumn.model_type_symbol.name].isin(type_emoji)]

-            numeric_interval = pd.IntervalIndex([intervals[size]])
-            mask = params_column.apply(lambda x: any(numeric_interval.contains(x)))
-            filtered_df = filtered_df.loc[mask]

+def update_collections(df: DataFrame):
+    """Update collections by filtering and adding the best models."""
+    collection = get_collection(collection_slug=PATH_TO_COLLECTION, token=H4_TOKEN)
+    cur_best_models = []
+
+    for model_type in ModelType:
+        if not model_type.value.name:
+            continue
+        for size, interval in intervals.items():
+            filtered_df = _filter_by_type_and_size(df, model_type, interval)
             best_models = list(
-                filtered_df.sort_values(AutoEvalColumn.average.name, ascending=False)[AutoEvalColumn.dummy.name]
+                filtered_df.sort_values(AutoEvalColumn.average.name, ascending=False)[AutoEvalColumn.dummy.name][:10]
             )
-            print(type.value.symbol, size, best_models[:10])
+            print(model_type.value.symbol, size, best_models)
+            _add_models_to_collection(collection, best_models, model_type, size)
+            cur_best_models.extend(best_models)

-            # We add them one by one to the leaderboard
-            for model in best_models:
-                ix += 1
-                cur_len_collection = len(collection.items)
-                try:
-                    collection = add_collection_item(
-                        PATH_TO_COLLECTION,
-                        item_id=model,
-                        item_type="model",
-                        exists_ok=True,
-                        note=f"Best {type.to_str(' ')} model of around {size} on the leaderboard today!",
-                        token=H4_TOKEN,
-                    )
-                    if (
-                        len(collection.items) > cur_len_collection
-                    ):  # we added an item - we make sure its position is correct
-                        item_object_id = collection.items[-1].item_object_id
-                        update_collection_item(
-                            collection_slug=PATH_TO_COLLECTION, item_object_id=item_object_id, position=ix
-                        )
-                        cur_len_collection = len(collection.items)
-                    cur_best_models.append(model)
-                    break
-                except HfHubHTTPError:
-                    continue
-
-    collection = get_collection(PATH_TO_COLLECTION, token=H4_TOKEN)
-    for item in collection.items:
-        if item.item_id not in cur_best_models:
-            try:
-                delete_collection_item(
-                    collection_slug=PATH_TO_COLLECTION, item_object_id=item.item_object_id, token=H4_TOKEN
-                )
-            except HfHubHTTPError:
-                continue
+    # Cleanup
+    existing_models = {item.item_id for item in collection.items}
+    to_remove = existing_models - set(cur_best_models)
+    for item_id in to_remove:
+        try:
+            delete_collection_item(collection_slug=PATH_TO_COLLECTION, item_object_id=item_id, token=H4_TOKEN)
+        except HfHubHTTPError:
+            continue
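The new _filter_by_type_and_size helper replaces the pd.IntervalIndex.contains construction with a direct membership test, which works because pd.Interval supports the `in` operator. A small sketch on toy values (the interval bounds are illustrative, not the leaderboard's actual size buckets):

import pandas as pd

interval = pd.Interval(0, 1.5, closed="right")  # illustrative size bucket, in billions of parameters
params = pd.to_numeric(pd.Series(["0.5", "7", "n/a"]), errors="coerce")  # "n/a" becomes NaN
mask = params.apply(lambda x: x in interval)
print(mask.tolist())  # [True, False, False] — NaN falls outside every interval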