Ricercar commited on
Commit
4d2856f
β€’
1 Parent(s): 20581b8

new summary page

Browse files
Home.py CHANGED
@@ -18,6 +18,8 @@ def login():
18
  )
19
  st.write('You can leave it blank to be anonymous.')
20
 
 
 
21
  # Every form must have a submit button.
22
  submitted = st.form_submit_button("Start")
23
  if submitted:
@@ -40,6 +42,8 @@ def logout():
40
  st.session_state.pop('gallery_state', None)
41
  st.session_state.pop('progress', None)
42
  st.session_state.pop('gallery_focus', None)
 
 
43
 
44
 
45
  def info():
 
18
  )
19
  st.write('You can leave it blank to be anonymous.')
20
 
21
+ st.session_state.show_NSFW = st.checkbox(':orange[show potentially mature content]', help='Inevitably, a few images might be NSFW, even if we tried to elimiate NFSW content in our prompts. We calculate a NSFW score to filter them out. Please check only if you are 18+ and want to take a look at the whole GEMRec-18k dataset', value=False, key='mature_content')
22
+
23
  # Every form must have a submit button.
24
  submitted = st.form_submit_button("Start")
25
  if submitted:
 
42
  st.session_state.pop('gallery_state', None)
43
  st.session_state.pop('progress', None)
44
  st.session_state.pop('gallery_focus', None)
45
+ st.session_state.pop('assigned_rank_mode', None)
46
+ st.session_state.pop('show_NSFW', None)
47
 
48
 
49
  def info():
{pages β†’ css}/style.css RENAMED
File without changes
data/unsafe_prompts.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "world knowledge": [83],
3
+ "abstract": [1, 3]
4
+ }
pages/Gallery.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import requests
3
 
@@ -326,13 +327,15 @@ class GalleryApp:
326
 
327
  # add safety check for some prompts
328
  safety_check = True
329
- unsafe_prompts = {}
330
- # initialize unsafe prompts
 
331
  for prompt_tag in prompt_tags:
332
- unsafe_prompts[prompt_tag] = []
333
- # manually add unsafe prompts
334
- unsafe_prompts['world knowledge'] = [83]
335
- unsafe_prompts['abstract'] = [1, 3]
 
336
 
337
  if int(prompt_id.item()) in unsafe_prompts[tag]:
338
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
@@ -643,7 +646,7 @@ def altair_histogram(hist_data, sort_by, mini, maxi):
643
 
644
 
645
  @st.cache_data
646
- def load_hf_dataset():
647
  # login to huggingface
648
  login(token=os.environ.get("HF_TOKEN"))
649
 
@@ -669,7 +672,9 @@ def load_hf_dataset():
669
  promptBook.loc[:, 'row_idx'] = promptBook.index
670
 
671
  # apply a nsfw filter
672
- promptBook = promptBook[promptBook['nsfw_score'] <= 0.84].reset_index(drop=True)
 
 
673
 
674
  # add a column that adds up 'norm_clip', 'norm_mcos', and 'norm_pop'
675
  score_weights = [1.0, 0.8, 0.2]
@@ -693,9 +698,6 @@ def load_tsne_coordinates(items):
693
  if __name__ == "__main__":
694
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="πŸ–ΌοΈ", layout="wide")
695
 
696
- with open('./pages/style.css') as f:
697
- st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
698
-
699
  if 'user_id' not in st.session_state:
700
  st.warning('Please log in first.')
701
  home_btn = st.button('Go to Home Page')
@@ -703,7 +705,7 @@ if __name__ == "__main__":
703
  switch_page("home")
704
  else:
705
  # st.write('You have already logged in as ' + st.session_state.user_id[0])
706
- roster, promptBook, images_ds = load_hf_dataset()
707
  # print(promptBook.columns)
708
 
709
  # # initialize selected_dict
@@ -713,15 +715,6 @@ if __name__ == "__main__":
713
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
714
  app.app()
715
 
716
- # components.html(
717
- # """
718
- # <script>
719
- # var iframe = window.parent.document.querySelector('[title="streamlit_agraph.agraph"]');
720
- # console.log(iframe);
721
- # var targetElement = iframe.contentDocument.querySelector('div.vis-network div.vis-navigation div.vis-button.vis-zoomExtends');
722
- # console.log(targetElement);
723
- # targetElement.style.background-image = "url(https://www.flaticon.com/free-icon-font/menu-burger_3917215?related_id=3917215#)";
724
- # </script>
725
- # """,
726
- # # unsafe_allow_html=True,
727
- # )
 
1
+ import json
2
  import os
3
  import requests
4
 
 
327
 
328
  # add safety check for some prompts
329
  safety_check = True
330
+
331
+ # load unsafe prompts
332
+ unsafe_prompts = json.load(open('./data/unsafe_prompts.json', 'r'))
333
  for prompt_tag in prompt_tags:
334
+ if prompt_tag not in unsafe_prompts:
335
+ unsafe_prompts[prompt_tag] = []
336
+ # # manually add unsafe prompts
337
+ # unsafe_prompts['world knowledge'] = [83]
338
+ # unsafe_prompts['abstract'] = [1, 3]
339
 
340
  if int(prompt_id.item()) in unsafe_prompts[tag]:
341
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
 
646
 
647
 
648
  @st.cache_data
649
+ def load_hf_dataset(show_NSFW=False):
650
  # login to huggingface
651
  login(token=os.environ.get("HF_TOKEN"))
652
 
 
672
  promptBook.loc[:, 'row_idx'] = promptBook.index
673
 
674
  # apply a nsfw filter
675
+ if not show_NSFW:
676
+ promptBook = promptBook[promptBook['norm_nsfw'] <= 0.8].reset_index(drop=True)
677
+ print('nsfw filter applied', len(promptBook))
678
 
679
  # add a column that adds up 'norm_clip', 'norm_mcos', and 'norm_pop'
680
  score_weights = [1.0, 0.8, 0.2]
 
698
  if __name__ == "__main__":
699
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="πŸ–ΌοΈ", layout="wide")
700
 
 
 
 
701
  if 'user_id' not in st.session_state:
702
  st.warning('Please log in first.')
703
  home_btn = st.button('Go to Home Page')
 
705
  switch_page("home")
706
  else:
707
  # st.write('You have already logged in as ' + st.session_state.user_id[0])
708
+ roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
709
  # print(promptBook.columns)
710
 
711
  # # initialize selected_dict
 
715
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
716
  app.app()
717
 
718
+ with open('./css/style.css') as f:
719
+ st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
720
+
 
 
 
 
 
 
 
 
 
pages/Ranking.py CHANGED
@@ -366,9 +366,6 @@ def connect_to_db():
366
  if __name__ == "__main__":
367
  st.set_page_config(page_title="Personal Image Ranking", page_icon="πŸŽ–οΈοΈ", layout="wide")
368
 
369
- with open('./pages/style.css') as f:
370
- st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
371
-
372
  if 'user_id' not in st.session_state:
373
  st.warning('Please log in first.')
374
  home_btn = st.button('Go to Home Page')
@@ -391,7 +388,7 @@ if __name__ == "__main__":
391
  switch_page('gallery')
392
  else:
393
  # st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
394
- roster, promptBook, images_ds = load_hf_dataset()
395
  print(st.session_state.selected_dict)
396
 
397
  # st.write("# Full function is coming soon.")
@@ -422,3 +419,5 @@ if __name__ == "__main__":
422
  app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
423
  app.app()
424
 
 
 
 
366
  if __name__ == "__main__":
367
  st.set_page_config(page_title="Personal Image Ranking", page_icon="πŸŽ–οΈοΈ", layout="wide")
368
 
 
 
 
369
  if 'user_id' not in st.session_state:
370
  st.warning('Please log in first.')
371
  home_btn = st.button('Go to Home Page')
 
388
  switch_page('gallery')
389
  else:
390
  # st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
391
+ roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
392
  print(st.session_state.selected_dict)
393
 
394
  # st.write("# Full function is coming soon.")
 
419
  app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
420
  app.app()
421
 
422
+ with open('./css/style.css') as f:
423
+ st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
pages/Summary.py CHANGED
@@ -24,15 +24,16 @@ class DashboardApp:
24
  def sidebar(self, tags, mode):
25
  with st.sidebar:
26
  tag = st.selectbox('Select a tag', tags, key='tag')
27
- st.write('---')
28
- st.write('## Want a more comprehensive summary?')
29
- st.write('Jump back to gallery and select more images to rank!')
30
- back_to_gallery = st.button('πŸ–ΌοΈ Go to Gallery', key='summary_sidebar_gallery')
31
- if back_to_gallery:
32
- switch_page('gallery')
33
- back_to_ranking = st.button('πŸŽ–οΈ Go to Ranking', key='summary_sidebar_ranking')
34
- if back_to_ranking:
35
- switch_page('ranking')
 
36
 
37
  return tag
38
 
@@ -108,15 +109,8 @@ class DashboardApp:
108
  image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{images[i]}.png"
109
  st.image(image, use_column_width=True)
110
 
111
-
112
- # # st.write('---')
113
- # expander = st.expander(f'# {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{modelVersion_id})')
114
- # with expander:
115
- # images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
116
- # st.write(images)
117
-
118
  def podium_expander(self, modelVersion_standings, n=3):
119
- st.write('## Top picks')
120
  # metric_cols = st.columns(n)
121
  for i in range(n):
122
  # with metric_cols[i]:
@@ -126,15 +120,24 @@ class DashboardApp:
126
  model_id, model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
127
 
128
  icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰'
129
- with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
130
- images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
131
- st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
132
- col_num = 4
133
- image_cols = st.columns(col_num)
134
- for i in range(len(images)):
135
- with image_cols[i % col_num]:
136
- image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{images[i]}.png"
137
- st.image(image, use_column_width=True)
 
 
 
 
 
 
 
 
 
138
 
139
  def score_calculator(self, results, db_table):
140
  modelVersion_standings = {}
@@ -160,9 +163,8 @@ class DashboardApp:
160
 
161
  return modelVersion_standings
162
 
163
-
164
  def app(self):
165
- st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
166
 
167
  # mode = st.sidebar.radio('Ranking mode', ['Drag and Sort', 'Battle'], horizontal=True, index=1)
168
  mode = st.session_state.assigned_rank_mode
@@ -219,7 +221,7 @@ if __name__ == "__main__":
219
  switch_page('gallery')
220
 
221
  else:
222
- roster, promptBook, images_ds = load_hf_dataset()
223
  RANKING_CONN = connect_to_db()
224
  app = DashboardApp(roster, promptBook, session_finished)
225
  app.app()
 
24
  def sidebar(self, tags, mode):
25
  with st.sidebar:
26
  tag = st.selectbox('Select a tag', tags, key='tag')
27
+ # st.write('---')
28
+ with st.form('summary_sidebar_form'):
29
+ st.write('## Want a more comprehensive summary?')
30
+ st.write('Jump back to gallery and select more images to rank!')
31
+ back_to_gallery = st.form_submit_button('πŸ–ΌοΈ Go to Gallery')
32
+ if back_to_gallery:
33
+ switch_page('gallery')
34
+ back_to_ranking = st.form_submit_button('πŸŽ–οΈ Go to Ranking')
35
+ if back_to_ranking:
36
+ switch_page('ranking')
37
 
38
  return tag
39
 
 
109
  image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{images[i]}.png"
110
  st.image(image, use_column_width=True)
111
 
 
 
 
 
 
 
 
112
  def podium_expander(self, modelVersion_standings, n=3):
113
+ # st.write('## Top picks')
114
  # metric_cols = st.columns(n)
115
  for i in range(n):
116
  # with metric_cols[i]:
 
120
  model_id, model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
121
 
122
  icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰'
123
+ podium_display = st.columns([1, 14])
124
+ with podium_display[0]:
125
+ st.title(f'{icon}')
126
+ with podium_display[1]:
127
+ st.write(f'##### {model_name}, {modelVersion_name}')
128
+ st.write(f'[Civitai Page](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}), [Model Download Link]({url}), Ranking Score: {winning_times}')
129
+ # with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
130
+ with st.expander(f'Show Images'):
131
+ images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
132
+ # st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
133
+ col_num = 4
134
+ image_cols = st.columns(col_num)
135
+ for j in range(len(images)):
136
+ with image_cols[j % col_num]:
137
+ image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{images[j]}.png"
138
+ st.image(image, use_column_width=True)
139
+ if i != n - 1:
140
+ st.write('---')
141
 
142
  def score_calculator(self, results, db_table):
143
  modelVersion_standings = {}
 
163
 
164
  return modelVersion_standings
165
 
 
166
  def app(self):
167
+ st.write('### Your Preferred Models')
168
 
169
  # mode = st.sidebar.radio('Ranking mode', ['Drag and Sort', 'Battle'], horizontal=True, index=1)
170
  mode = st.session_state.assigned_rank_mode
 
221
  switch_page('gallery')
222
 
223
  else:
224
+ roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
225
  RANKING_CONN = connect_to_db()
226
  app = DashboardApp(roster, promptBook, session_finished)
227
  app.app()