SaulLu committed
Commit cf4f63b
1 Parent(s): 3046477

start to add bubble animation

Files changed (5)
  1. Makefile +15 -0
  2. app.py +17 -0
  3. dashboard_utils/bubbles.py +122 -0
  4. requirements-dev.txt +3 -0
  5. requirements.txt +1 -0
Makefile ADDED
@@ -0,0 +1,15 @@
+
+.PHONY: quality style test test-examples
+
+# Check that source code meets quality standards
+
+quality:
+	python -m black --check --line-length 119 --target-version py38 .
+	python -m isort --check-only .
+	python -m flake8 --max-line-length 119
+
+# Format source code automatically
+
+style:
+	python -m black --line-length 119 --target-version py38 .
+	python -m isort .
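
Note (usage, not part of the commit): with these targets, "make quality" runs the black, isort, and flake8 checks without modifying files, and "make style" applies the formatting in place; both assume the tools listed in requirements-dev.txt (added below) are installed.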
app.py CHANGED
@@ -1,4 +1,21 @@
+import json
+
 import streamlit as st
+from streamlit_observable import observable
+
+from dashboard_utils.bubbles import get_new_bubble_data
 
 st.title("Training transformers together dashboard")
 st.write("test")
+
+
+serialized_data, profiles = get_new_bubble_data()
+
+
+observers = observable(
+    "Participants",
+    notebook="d/9ae236a507f54046",  # "@huggingface/participants-bubbles-chart",
+    targets=["c_noaws"],
+    # observe=["selectedCounties"]
+    redefine={"serializedData": serialized_data, "profileSimple": profiles},
+)
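
Note (illustrative sketch, not part of the commit): redefine overrides cells in the embedded Observable notebook with values computed in Python. Based on dashboard_utils/bubbles.py below, the two payloads have roughly the following shapes; every field value here is a made-up placeholder, only the field names come from this commit.

# Illustrative payload shapes; all values are placeholders.
profiles = [
    {
        "id": "some-user",                               # HF username
        "name": "some-user",
        "src": "https://huggingface.co/avatars/xyz.svg",  # avatar image
        "url": "https://huggingface.co/some-user",        # profile page
    }
]

serialized_data = {
    "points": [
        [  # one aggregate entry per participant
            {
                "date": "2021-12-01T12:00:00",  # latest W&B timestamp (ISO)
                "profileId": "some-user",
                "batches": 1024,
                "runtime": 3600.0,
                "loss": 2.5,
                "velocity": 0.28,  # steps per second, summed over active runs
                "activeRuns": [],
            }
        ]
    ],
    "maxVelocity": 1,
}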
dashboard_utils/bubbles.py ADDED
@@ -0,0 +1,122 @@
+import datetime
+from urllib import parse
+
+import requests
+import wandb
+
+URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
+WANDB_REPO = "learning-at-home/Worker_logs"
+
+
+def get_new_bubble_data():
+    serialized_data_points, latest_timestamp = get_serialized_data_points()
+    serialized_data = get_serialized_data(serialized_data_points, latest_timestamp)
+    profiles = get_profiles(serialized_data_points)
+
+    return serialized_data, profiles
+
+
+def get_profiles(serialized_data_points):
+    profiles = []
+    for username in serialized_data_points.keys():
+        params = {"type": "user", "q": username}
+        new_url = URL_QUICKSEARCH + parse.urlencode(params)
+        r = requests.get(new_url)
+        response = r.json()
+        try:
+            avatarUrl = response["users"][0]["avatarUrl"]
+        except:
+            avatarUrl = "/avatars/57584cb934354663ac65baa04e6829bf.svg"
+        if avatarUrl.startswith("/avatars/"):
+            avatarUrl = f"https://huggingface.co{avatarUrl}"
+        profiles.append(
+            {"id": username, "name": username, "src": avatarUrl, "url": f"https://huggingface.co/{username}"}
+        )
+    return profiles
+
+
+def get_serialized_data_points():
+    api = wandb.Api()
+    runs = api.runs(WANDB_REPO)
+
+    serialized_data_points = {}
+    latest_timestamp = None
+    print("**start api call")
+    for run in runs:
+        run_summary = run.summary._json_dict
+        run_name = run.name
+
+        if run_name in serialized_data_points:
+            try:
+                timestamp = run_summary["_timestamp"]
+                serialized_data_points[run_name]["Runs"].append(
+                    {
+                        "batches": run_summary["_step"],
+                        "runtime": run_summary["_runtime"],
+                        "loss": run_summary["train/loss"],
+                        "velocity": run_summary["_step"] / run_summary["_runtime"],
+                        "date": datetime.datetime.utcfromtimestamp(timestamp),
+                    }
+                )
+                if not latest_timestamp or timestamp > latest_timestamp:
+                    latest_timestamp = timestamp
+            except Exception as e:
+                pass
+                # print(e)
+                # print([key for key in list(run_summary.keys()) if "gradients" not in key])
+        else:
+            try:
+                timestamp = run_summary["_timestamp"]
+                serialized_data_points[run_name] = {
+                    "profileId": run_name,
+                    "Runs": [
+                        {
+                            "batches": run_summary["_step"],
+                            "runtime": run_summary["_runtime"],
+                            "loss": run_summary["train/loss"],
+                            "velocity": run_summary["_step"] / run_summary["_runtime"],
+                            "date": datetime.datetime.utcfromtimestamp(timestamp),
+                        }
+                    ],
+                }
+                if not latest_timestamp or timestamp > latest_timestamp:
+                    latest_timestamp = timestamp
+            except Exception as e:
+                pass
+                # print(e)
+                # print([key for key in list(run_summary.keys()) if "gradients" not in key])
+    latest_timestamp = datetime.datetime.utcfromtimestamp(latest_timestamp)
+    print("**finish api call")
+    return serialized_data_points, latest_timestamp
+
+
+def get_serialized_data(serialized_data_points, latest_timestamp):
+    serialized_data_points_v2 = []
+    max_velocity = 1
+    for run_name, serialized_data_point in serialized_data_points.items():
+        activeRuns = []
+        loss = 0
+        runtime = 0
+        batches = 0
+        velocity = 0
+        for run in serialized_data_point["Runs"]:
+            if run["date"] == latest_timestamp:
+                run["date"] = run["date"].isoformat()
+                activeRuns.append(run)
+                loss += run["loss"]
+                velocity += run["velocity"]
+            loss = loss / len(activeRuns) if activeRuns else 0
+            runtime += run["runtime"]
+            batches += run["batches"]
+        new_item = {
+            "date": latest_timestamp.isoformat(),
+            "profileId": run_name,
+            "batches": batches,
+            "runtime": runtime,
+            "loss": loss,
+            "velocity": velocity,
+            "activeRuns": activeRuns,
+        }
+        serialized_data_points_v2.append(new_item)
+    serialized_data = {"points": [serialized_data_points_v2], "maxVelocity": max_velocity}
+    return serialized_data
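
Note (illustrative sketch, not part of the commit): get_profiles uses a bare except and an HTTP request without a timeout, so a hung or failing quicksearch call can stall the dashboard. A more defensive variant of the avatar lookup could look like the following; fetch_avatar_url is a hypothetical helper name, and the behavior otherwise mirrors the commit.

# Hypothetical hardened avatar lookup -- illustrative, not part of this commit.
import requests
from urllib import parse

URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
DEFAULT_AVATAR = "/avatars/57584cb934354663ac65baa04e6829bf.svg"


def fetch_avatar_url(username: str) -> str:
    params = {"type": "user", "q": username}
    try:
        r = requests.get(URL_QUICKSEARCH + parse.urlencode(params), timeout=5)
        r.raise_for_status()
        avatar_url = r.json()["users"][0]["avatarUrl"]
    except (requests.RequestException, KeyError, IndexError, ValueError):
        avatar_url = DEFAULT_AVATAR  # fall back to the default avatar
    if avatar_url.startswith("/avatars/"):
        avatar_url = f"https://huggingface.co{avatar_url}"
    return avatar_url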
requirements-dev.txt ADDED
@@ -0,0 +1,3 @@
+black
+isort
+flake8
requirements.txt CHANGED
@@ -1 +1,2 @@
 streamlit
+streamlit-observable