chrisjay committed on
Commit
6f4a3dd
•
1 Parent(s): d7529ad

initial commit to RL stats

Files changed (4)
  1. .gitignore +1 -0
  2. app.css +32 -0
  3. app.py +272 -0
  4. utils.py +84 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
app.css ADDED
@@ -0,0 +1,32 @@
+
+ .infoPoint h1 {
+     font-size: 30px;
+     font-weight: bold;
+
+ }
+
+ a {
+     text-decoration: underline;
+     color: #1f3b54;
+ }
+
+ table {
+
+     margin: 25px 0;
+     font-size: 0.9em;
+     font-family: sans-serif;
+     min-width: 400px;
+     box-shadow: 0 0 20px rgba(0, 0, 0, 0.15);
+ }
+
+ table th,
+ table td {
+     padding: 12px 15px;
+ }
+
+ tr {
+     text-align: left;
+ }
+ thead tr {
+     text-align: left;
+ }
app.py ADDED
@@ -0,0 +1,272 @@
+ import requests
+ import pandas as pd
+ from tqdm.auto import tqdm
+ from utils import *
+ import gradio as gr
+ from huggingface_hub import HfApi, hf_hub_download
+ from huggingface_hub.repocard import metadata_load
+
+
+ class DeepRL_Leaderboard:
+     def __init__(self) -> None:
+         self.leaderboard = {}
+
+     def add_leaderboard(self, id=None, title=None):
+         if id is not None and title is not None:
+             id = id.strip()
+             title = title.strip()
+             self.leaderboard.update({id: {'title': title, 'data': get_data_per_env(id)}})
+     def get_data(self):
+         return self.leaderboard
+
+     def get_ids(self):
+         return list(self.leaderboard.keys())
+
+
+
+ # CSS file for the Gradio Blocks interface
+ with open('app.css', 'r') as f:
+     BLOCK_CSS = f.read()
+
+
+
+ LOADED_MODEL_IDS = {}
+
+ def get_data(rl_env):
+     global LOADED_MODEL_IDS
+     data = []
+     model_ids = get_model_ids(rl_env)
+     LOADED_MODEL_IDS[rl_env] = model_ids
+
+     for model_id in tqdm(model_ids):
+         meta = get_metadata(model_id)
+         if meta is None:
+             continue
+         row = {}
+         row["metadata"] = meta
+
+         data.append(row)
+     return pd.DataFrame.from_records(data)
+
+ def get_data_per_env(rl_env):
+     dataframe = get_data(rl_env)
+     return dataframe, dataframe.empty
+
+
+
+ rl_leaderboard = DeepRL_Leaderboard()
+ rl_leaderboard.add_leaderboard('CarRacing-v0', "The Car Racing 🏎️ Leaderboard 🚀")
+ rl_leaderboard.add_leaderboard('MountainCar-v0', "The Mountain Car ⛰️ 🚗 Leaderboard 🚀")
+ rl_leaderboard.add_leaderboard('LunarLander-v2', "The Lunar Lander 🌕 Leaderboard 🚀")
+ rl_leaderboard.add_leaderboard('BipedalWalker-v3', "The BipedalWalker Leaderboard 🚀")
+ rl_leaderboard.add_leaderboard('Taxi-v3', 'The Taxi-v3 🚖 Leaderboard 🚀')
+ rl_leaderboard.add_leaderboard('FrozenLake-v1-4x4-no_slippery', 'The FrozenLake-v1-4x4-no_slippery Leaderboard 🚀')
+ rl_leaderboard.add_leaderboard('FrozenLake-v1-8x8-no_slippery', 'The FrozenLake-v1-8x8-no_slippery Leaderboard 🚀')
+ rl_leaderboard.add_leaderboard('FrozenLake-v1-4x4', 'The FrozenLake-v1-4x4 Leaderboard 🚀')
+ rl_leaderboard.add_leaderboard('FrozenLake-v1-8x8', 'The FrozenLake-v1-8x8 Leaderboard 🚀')
+ rl_leaderboard.add_leaderboard('SpaceInvadersNoFrameskip-v4', 'The SpaceInvadersNoFrameskip-v4 Leaderboard 🚀')
+
+ RL_ENVS = rl_leaderboard.get_ids()
+ RL_DETAILS = rl_leaderboard.get_data()
+
+
+ def update_data(rl_env):
+     global LOADED_MODEL_IDS
+     data = []
+     model_ids = [x for x in get_model_ids(rl_env) if x not in LOADED_MODEL_IDS[rl_env]]
+     LOADED_MODEL_IDS[rl_env] += model_ids
+
+     for model_id in tqdm(model_ids):
+         meta = get_metadata(model_id)
+         if meta is None:
+             continue
+         row = {}
+         row["metadata"] = meta
+         data.append(row)
+     return pd.DataFrame.from_records(data)
+
+
+
+ def update_data_per_env(rl_env):
+     global RL_DETAILS
+
+     old_dataframe, _ = RL_DETAILS[rl_env]['data']
+     new_dataframe = update_data(rl_env)
+
+     new_dataframe = new_dataframe.fillna("")
+
+     dataframe = pd.concat([old_dataframe, new_dataframe])
+
+     return dataframe, dataframe.empty
+
+
+
+
+
+
+ def get_info_display(dataframe, env_name, name_leaderboard, is_empty):
+     if not is_empty:
+         markdown = """
+         <div class='infoPoint'>
+         <h1> {name_leaderboard} </h1>
+         <br>
+         <p> This is a leaderboard of <b>{len_dataframe}</b> agents, from <b>{num_unique_users}</b> unique users, playing {env_name} 👩‍🚀. </p>
+         <br>
+         <p> We use the lower bound (mean_reward - std_reward) to sort the models. </p>
+         <br>
+         <p> You can click on a model's name to be redirected to its model card, which includes documentation. </p>
+         <br>
+         <p> Want to try your model? Read <a href="https://github.com/huggingface/deep-rl-class/blob/Unit1/unit1/README.md" target="_blank">Unit 1</a> of the Deep Reinforcement Learning Class.
+         </p>
+         </div>
+         """.format(len_dataframe = len(dataframe), env_name = env_name, name_leaderboard = name_leaderboard, num_unique_users = len(set(dataframe['User'].values)))
+
+     else:
+         markdown = """
+         <div class='infoPoint'>
+         <h1> {name_leaderboard} </h1>
+         <br>
+         </div>
+         """.format(name_leaderboard = name_leaderboard)
+     return markdown
+
+ def reload_all_data():
+
+     global RL_DETAILS, RL_ENVS
+
+     for rl_env in RL_ENVS:
+         RL_DETAILS[rl_env]['data'] = update_data_per_env(rl_env)
+
+     html = """<div style="color: green">
+     <p> ✅ Leaderboard updated! Click `Show Statistics` to see the current statistics.</p>
+     </div>
+     """
+     return html
+
+
+ def reload_leaderboard(rl_env):
+     global RL_DETAILS
+
+     data_dataframe, is_empty = RL_DETAILS[rl_env]['data']
+
+     markdown = get_info_display(data_dataframe, rl_env, RL_DETAILS[rl_env]['title'], is_empty)
+
+     return markdown
+
+ def get_units_stat():
+     # gets the number of models per unit
+     units = {'Unit 1': [], 'Unit 2': [], 'Unit 3': []}
+     for rl_env in RL_ENVS:
+         rl_env_metadata, is_empty = RL_DETAILS[rl_env]['data']
+         if is_empty is False:
+             # All good! Carry on
+             metadata_list = rl_env_metadata['metadata'].values
+             units['Unit 1'].extend([m for m in metadata_list if 'stable-baselines3' in m['tags']])
+             units['Unit 2'].extend([m for m in metadata_list if 'custom-implementation' in m['tags']])
+             units['Unit 3'].extend([m for m in metadata_list if 'stable-baselines3' in m['tags'] and 'SpaceInvadersNoFrameskip-v4'.lower() in [tag.lower() for tag in m['tags']]])
+
+     # get count
+     for k in units.keys():
+         units[k] = len(units[k])
+
+     return plot_bar(value = list(units.values()), name = list(units.keys()), x_name = "Units", y_name = "Number of model submissions", title = "Number of model submissions per unit")
+
+
+
+ def get_models_stat():
+     # gets the number of models per environment
+     units = {}
+     for rl_env in RL_ENVS:
+         rl_env_metadata, is_empty = RL_DETAILS[rl_env]['data']
+         if is_empty is False:
+             # All good! Carry on
+             metadata_list = rl_env_metadata['metadata'].values
+             units[rl_env] = [m for m in metadata_list]
+
+     # get count
+     for k in units.keys():
+         units[k] = len(units[k])
+
+     return plot_bar(value = list(units.values()), name = list(units.keys()), x_name = "RL Environment", y_name = "Number of model submissions", title = "Number of model submissions per RL environment")
+
+ def get_user_stat():
+     # gets the number of unique users per environment
+     users = {}
+     for rl_env in RL_ENVS:
+         rl_env_metadata, is_empty = RL_DETAILS[rl_env]['data']
+         if is_empty is False:
+             # All good! Carry on
+             metadata_list = rl_env_metadata['metadata'].values
+             users[rl_env] = [m['model_id'].split('/')[0] for m in metadata_list]
+
+     # get count
+     for k in users.keys():
+         users[k] = len(set(users[k]))
+
+     return plot_bar(value = list(users.values()), name = list(users.keys()), x_name = "RL Environment", y_name = "Number of user submissions", title = "Number of user submissions per RL environment")
+
+ def get_stat():
+     # gets the unit, user and model submission counts across all environments
+     units = {'Unit 1': [], 'Unit 2': [], 'Unit 3': []}
+     users = {}
+     models = {}
+     for rl_env in RL_ENVS:
+         rl_env_metadata, is_empty = RL_DETAILS[rl_env]['data']
+         if is_empty is False:
+             # All good! Carry on
+             metadata_list = rl_env_metadata['metadata'].values
+             units['Unit 1'].extend([m for m in metadata_list if 'stable-baselines3' in m['tags']])
+             units['Unit 2'].extend([m for m in metadata_list if 'custom-implementation' in m['tags']])
+             units['Unit 3'].extend([m for m in metadata_list if 'stable-baselines3' in m['tags'] and 'SpaceInvadersNoFrameskip-v4'.lower() in [tag.lower() for tag in m['tags']]])
+
+             users[rl_env] = [m['model_id'].split('/')[0] for m in metadata_list]
+             models[rl_env] = [m for m in metadata_list]
+
+     # get count
+     for k in units.keys():
+         units[k] = len(units[k])
+     for k in users.keys():
+         users[k] = len(set(users[k]))
+     for k in models.keys():
+         models[k] = len(models[k])
+
+     units_plot = plot_bar(value = list(units.values()), name = list(units.keys()), x_name = "Units", y_name = "Number of model submissions", title = "Number of model submissions per unit")
+     user_plot = plot_barh(value = list(users.values()), name = list(users.keys()), x_name = "RL Environment", y_name = "Number of unique user submissions", title = "Number of unique user submissions per RL environment")
+     model_plot = plot_barh(value = list(models.values()), name = list(models.keys()), x_name = "RL Environment", y_name = "Number of model submissions", title = "Number of model submissions per RL environment")
+     return units_plot, user_plot, model_plot
+
+
+
+
+
+ block = gr.Blocks(css=BLOCK_CSS)
+ with block:
+     notification = gr.HTML("""<div style="color: green">
+     <p> ⌛ Updating leaderboard... </p>
+     </div>
+     """)
+     block.load(reload_all_data, [], [notification])
+
+     with gr.Tabs():
+         with gr.TabItem("Dashboard") as rl_tab:
+             # 1. model submissions per unit
+             # 2. model submissions per environment
+             # 3. unique users per environment
+             # get_units_stat()
+             #data_html,data_dataframe,is_empty = RL_DETAILS[rl_env]['data']
+             #markdown = get_info_display(data_dataframe,rl_env,RL_DETAILS[rl_env]['title'],is_empty)
+             #env_state =gr.Variable(default_value=rl_env)
+             #output_markdown = gr.HTML(markdown)
+             reload = gr.Button('Show Statistics')
+
+             units_plot = gr.Plot(type="matplotlib")
+             model_plot = gr.Plot(type="matplotlib")
+             user_plot = gr.Plot(type="matplotlib")
+             #plot_gender = gr.Plot(type="matplotlib")
+
+             #output_html = gr.HTML(data_html)
+
+             reload.click(get_stat, [], [units_plot, user_plot, model_plot])
+             #rl_tab.select(reload_leaderboard,inputs=[env_state],outputs=[output_markdown,output_html])
+
+ block.launch()
utils.py ADDED
@@ -0,0 +1,84 @@
+ import pandas as pd
+ import requests
+ from tqdm.auto import tqdm
+ from huggingface_hub import HfApi, hf_hub_download
+ from huggingface_hub.repocard import metadata_load
+ import matplotlib.pyplot as plt
+
+
+ def plot_bar(value, name, x_name, y_name, title):
+     fig, ax = plt.subplots(figsize=(10, 4), tight_layout=True)
+
+     ax.set(xlabel=x_name, ylabel=y_name, title=title)
+
+     ax.bar(name, value)
+
+
+     return ax.figure
+ def plot_barh(value, name, x_name, y_name, title):
+     fig, ax = plt.subplots(figsize=(10, 4), tight_layout=True)
+
+     ax.set(xlabel=x_name, ylabel=y_name, title=title)
+
+     ax.barh(name, value)
+
+
+     return ax.figure
+ # Based on Omar Sanseviero's work
+ # Make model name a clickable link
+ def make_clickable_model(model_name):
+     # remove user from model name
+     model_name_show = ' '.join(model_name.split('/')[1:])
+
+     link = "https://huggingface.co/" + model_name
+     return f'<a target="_blank" href="{link}">{model_name_show}</a>'
+
+ # Make user name a clickable link
+ def make_clickable_user(user_id):
+     link = "https://huggingface.co/" + user_id
+     return f'<a target="_blank" href="{link}">{user_id}</a>'
+
+
+
+ def get_model_ids(rl_env):
+     api = HfApi()
+     models = api.list_models(filter=rl_env)
+     model_ids = [x.modelId for x in models]
+     return model_ids
+
+ def get_metadata(model_id):
+     try:
+         readme_path = hf_hub_download(model_id, filename="README.md")
+         metadata = metadata_load(readme_path)
+         metadata['model_id'] = model_id
+         return metadata
+     except requests.exceptions.HTTPError:
+         # 404 README.md not found
+         return None
+
+ def parse_metrics_accuracy(meta):
+     if "model-index" not in meta:
+         return None
+     result = meta["model-index"][0]["results"]
+     metrics = result[0]["metrics"]
+     accuracy = metrics[0]["value"]
+     return accuracy
+
+ # We keep the worst case episode
+ def parse_rewards(accuracy):
+     default_std = -1000
+     default_reward = -1000
+     if accuracy is not None:
+         parsed = accuracy.split(' +/- ')
+         if len(parsed) > 1:
+             mean_reward = float(parsed[0])
+             std_reward = float(parsed[1])
+         else:
+             mean_reward = float(default_reward)
+             std_reward = float(default_std)
+
+     else:
+         mean_reward = float(default_reward)
+         std_reward = float(default_std)
+     return mean_reward, std_reward
+