ZhangYuhan commited on
Commit
7c1eee1
1 Parent(s): 8862cab
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ from serve.gradio_web_t2s import *
5
+ from serve.gradio_web_i2s import *
6
+ from serve.leaderboard import build_leaderboard_tab
7
+ from model.model_manager import ModelManager
8
+ from pathlib import Path
9
+ from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
10
+
11
+
12
+ def build_combine_demo(models, elo_results_file, leaderboard_table_file):
13
+ with gr.Blocks(
14
+ title="Play with Open 3D Generative Models",
15
+ theme=gr.themes.Default(),
16
+ css=block_css,
17
+ ) as demo:
18
+ with gr.Tabs() as tabs_combine:
19
+ with gr.Tab("Text-to-3D Generation", id=0):
20
+ with gr.Tabs() as tabs_ig:
21
+ with gr.Tab("Text-to-3D Arena (battle)", id=0):
22
+ build_t2s_ui_side_by_side_anony(models)
23
+ with gr.Tab("Text-to-3D Arena (side-by-side)", id=1):
24
+ build_t2s_ui_side_by_side_named(models)
25
+ with gr.Tab("Text-to-3D Direct Chat", id=2):
26
+ build_t2s_ui_single_model(models)
27
+ if elo_results_file:
28
+ with gr.Tab("Text-to-3D Leaderboard", id=3):
29
+ build_leaderboard_tab(elo_results_file['t2s_generation'], leaderboard_table_file['t2s_generation'])
30
+ with gr.Tab("About Us", id=4):
31
+ build_about()
32
+
33
+ with gr.Tab("Image-to-3D Generation", id=5):
34
+ with gr.Tabs() as tabs_ie:
35
+ with gr.Tab("Image-to-3D Arena (battle)", id=5):
36
+ build_i2s_ui_side_by_side_anony(models)
37
+ with gr.Tab("Image-to-3D Arena (side-by-side)", id=6):
38
+ build_i2s_ui_side_by_side_named(models)
39
+ with gr.Tab("Image-to-3D Direct Chat", id=7):
40
+ build_i2s_ui_single_model(models)
41
+ if elo_results_file:
42
+ with gr.Tab("Image-to-3D Leaderboard", id=8):
43
+ build_leaderboard_tab(elo_results_file['i2s_generation'], leaderboard_table_file['i2s_generation'])
44
+ with gr.Tab("About Us", id=9):
45
+ build_about()
46
+
47
+ return demo
48
+
49
+
50
+ def load_elo_results(elo_results_dir):
51
+ from collections import defaultdict
52
+ elo_results_file = defaultdict(lambda: None)
53
+ leaderboard_table_file = defaultdict(lambda: None)
54
+ if elo_results_dir is not None:
55
+ elo_results_dir = Path(elo_results_dir)
56
+ elo_results_file = {}
57
+ leaderboard_table_file = {}
58
+ for file in elo_results_dir.glob('elo_results_*.pkl'):
59
+ if 't2s_generation' in file.name:
60
+ elo_results_file['t2s_generation'] = file
61
+ elif 'i2s_generation' in file.name:
62
+ elo_results_file['i2s_generation'] = file
63
+ else:
64
+ raise ValueError(f"Unknown file name: {file.name}")
65
+ for file in elo_results_dir.glob('*_leaderboard.csv'):
66
+ if 't2s_generation' in file.name:
67
+ leaderboard_table_file['t2s_generation'] = file
68
+ elif 'i2s_generation' in file.name:
69
+ leaderboard_table_file['i2s_generation'] = file
70
+ else:
71
+ raise ValueError(f"Unknown file name: {file.name}")
72
+
73
+ return elo_results_file, leaderboard_table_file
74
+
75
+ if __name__ == "__main__":
76
+ server_port = int(SERVER_PORT)
77
+ root_path = ROOT_PATH
78
+ elo_results_dir = ELO_RESULTS_DIR
79
+ models = ModelManager()
80
+
81
+ # elo_results_file, leaderboard_table_file = load_elo_results(elo_results_dir)
82
+ elo_results_file, leaderboard_table_file = None, None
83
+ demo = build_combine_demo(models, elo_results_file, leaderboard_table_file)
84
+ demo.queue(max_size=20).launch(server_port=server_port, root_path=ROOT_PATH)
model/__init__.py ADDED
File without changes
model/model_config.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from typing import List
3
+
4
+ ModelConfig = namedtuple("ModelConfig", ["model_name", "i2s_model", "online_model", "model_path"])
5
+ model_config = {}
6
+
7
+ def register_model_config(
8
+ model_name: str, i2s_model: bool, online_model: bool, model_path: str = None
9
+ ):
10
+ config = ModelConfig(model_name, i2s_model, online_model, model_path)
11
+ model_config[model_name] = config
12
+
13
+ def get_model_config(model_name: str) -> ModelConfig:
14
+ assert model_name in model_config
15
+ return model_config[model_name]
16
+
17
+ register_model_config(
18
+ model_name="dreamfusion",
19
+ i2s_model=False,
20
+ online_model=False
21
+ )
22
+
23
+ register_model_config(
24
+ model_name="instant3d",
25
+ i2s_model=False,
26
+ online_model=False
27
+ )
28
+
29
+ register_model_config(
30
+ model_name="latent-nerf",
31
+ i2s_model=False,
32
+ online_model=False
33
+ )
34
+
35
+ register_model_config(
36
+ model_name="magic3d",
37
+ i2s_model=False,
38
+ online_model=False
39
+ )
40
+
41
+ register_model_config(
42
+ model_name="mvdream",
43
+ i2s_model=False,
44
+ online_model=False
45
+ )
46
+
47
+ register_model_config(
48
+ model_name="prolificdreamer",
49
+ i2s_model=False,
50
+ online_model=False
51
+ )
model/model_manager.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import random
3
+ import gradio as gr
4
+ import requests
5
+ import io, base64, json
6
+ # import spaces
7
+ from PIL import Image
8
+
9
+ from .model_config import model_config
10
+ from .model_worker import BaseModelWorker
11
+
12
+ class ModelManager:
13
+ def __init__(self):
14
+ self.models_config = model_config
15
+ self.models_worker: list[BaseModelWorker] = {}
16
+
17
+ self.build_model_workers()
18
+
19
+ def build_model_workers(self):
20
+ for cfg in self.models_config.values():
21
+ worker = BaseModelWorker(cfg.model_name, cfg.i2s_model, cfg.online_model, cfg.model_path)
22
+ self.models_worker[cfg.model_name] = worker
23
+
24
+ def get_all_models(self):
25
+ models = []
26
+ for model_name in self.models_config.keys():
27
+ models.append(model_name)
28
+ return models
29
+
30
+ def get_t2s_models(self):
31
+ models = []
32
+ for cfg in self.models_config.values():
33
+ if not cfg.i2s_model:
34
+ models.append(cfg.model_name)
35
+ return models
36
+
37
+ def get_i2s_models(self):
38
+ models = []
39
+ for cfg in self.models_config.values():
40
+ if cfg.i2s_model:
41
+ models.append(cfg.model_name)
42
+ return models
43
+
44
+ def get_online_models(self):
45
+ models = []
46
+ for cfg in self.models_config.values():
47
+ if cfg.online_model:
48
+ models.append(cfg.model_name)
49
+ return models
50
+
51
+ def get_models(self, i2s_model:bool, online_model:bool):
52
+ models = []
53
+ for cfg in self.models_config.values():
54
+ if cfg.i2s_model==i2s_model and cfg.online_model==online_model:
55
+ models.append(cfg.model_name)
56
+ return models
57
+
58
+ def check_online(self, name):
59
+ worker = self.models_worker[name]
60
+ if not worker.online_model:
61
+ return
62
+
63
+ # @spaces.GPU(duration=120)
64
+ def inference(self, prompt, model_name):
65
+ worker = self.models_worker[model_name]
66
+ result = worker.inference(prompt=prompt)
67
+ return result
68
+
69
+ def render(self, prompt, model_name):
70
+ worker = self.models_worker[model_name]
71
+ result = worker.render(prompt=prompt)
72
+ return result
73
+
74
+ def inference_parallel(self, prompt, model_A, model_B):
75
+ results = []
76
+ model_names = [model_A, model_B]
77
+ with concurrent.futures.ThreadPoolExecutor() as executor:
78
+ future_to_result = {executor.submit(self.inference, prompt, model): model
79
+ for model in model_names}
80
+ for future in concurrent.futures.as_completed(future_to_result):
81
+ result = future.result()
82
+ results.append(result)
83
+ return results[0], results[1]
84
+
85
+ def inference_parallel_anony(self, prompt, model_A, model_B, i2s_model):
86
+ if model_A == model_B == "":
87
+ model_A, model_B = random.sample(self.get_models(i2s_model=i2s_model, online_model=True), 2)
88
+ model_names = [model_A, model_B]
89
+ results = []
90
+ with concurrent.futures.ThreadPoolExecutor() as executor:
91
+ future_to_result = {executor.submit(self.inference, prompt, model): model
92
+ for model in model_names}
93
+ for future in concurrent.futures.as_completed(future_to_result):
94
+ result = future.result()
95
+ results.append(result)
96
+ return results[0], results[1]
97
+
98
+
99
+ def render_parallel(self, prompt, model_A, model_B):
100
+ results = []
101
+ model_names = [model_A, model_B]
102
+ with concurrent.futures.ThreadPoolExecutor() as executor:
103
+ future_to_result = {executor.submit(self.render, prompt, model): model
104
+ for model in model_names}
105
+ for future in concurrent.futures.as_completed(future_to_result):
106
+ result = future.result()
107
+ results.append(result)
108
+ return results[0], results[1]
109
+
110
+ # def i2s_inference_parallel(self, image, model_A, model_B):
111
+ # results = []
112
+ # model_names = [model_A, model_B]
113
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
114
+ # future_to_result = {executor.submit(self.inference, image, model): model
115
+ # for model in model_names}
116
+ # for future in concurrent.futures.as_completed(future_to_result):
117
+ # result = future.result()
118
+ # results.append(result)
119
+ # return results[0], results[1]
120
+
model/model_registry.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from typing import List
3
+
4
+ ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"])
5
+ model_info = {}
6
+
7
+ def register_model_info(
8
+ full_names: List[str], simple_name: str, link: str, description: str
9
+ ):
10
+ info = ModelInfo(simple_name, link, description)
11
+
12
+ for full_name in full_names:
13
+ model_info[full_name] = info
14
+
15
+ def get_model_info(name: str) -> ModelInfo:
16
+ if name in model_info:
17
+ return model_info[name]
18
+ else:
19
+ # To fix this, please use `register_model_info` to register your model
20
+ return ModelInfo(
21
+ name, "", "Register the description at fastchat/model/model_registry.py"
22
+ )
23
+
24
+ def get_model_description_md(model_list):
25
+ model_description_md = """
26
+ | | | |
27
+ | ---- | ---- | ---- |
28
+ """
29
+ ct = 0
30
+ visited = set()
31
+ for i, name in enumerate(model_list):
32
+ minfo = get_model_info(name)
33
+ if minfo.simple_name in visited:
34
+ continue
35
+ visited.add(minfo.simple_name)
36
+ one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
37
+
38
+ if ct % 3 == 0:
39
+ model_description_md += "|"
40
+ model_description_md += f" {one_model_md} |"
41
+ if ct % 3 == 2:
42
+ model_description_md += "\n"
43
+ ct += 1
44
+ return model_description_md
45
+
46
+ # regist text-to-shape generation models
47
+
48
+ register_model_info(
49
+ ["dreamfusion"],
50
+ "DreamFusion",
51
+ "https://dreamfusion3d.github.io/",
52
+ "Text-to-3D using 2D Diffusion and SDS Loss",
53
+ )
54
+
55
+ register_model_info(
56
+ ["dreamgaussian"],
57
+ "DreamGaussian",
58
+ "https://github.com/dreamgaussian/dreamgaussian",
59
+ "Generative Gaussian Splatting for Efficient 3D Content Creation",
60
+ )
61
+
62
+ register_model_info(
63
+ ["fantasia3d"],
64
+ "Fantasia3D",
65
+ "https://fantasia3d.github.io/",
66
+ "Disentangling Geometry and Appearance for High-quality Text-to-3D Content Creation",
67
+ )
68
+
69
+ register_model_info(
70
+ ["latent-nerf"],
71
+ "Latent-NeRF",
72
+ "https://github.com/eladrich/latent-nerf",
73
+ "Latent-NeRF for Shape-Guided Generation of 3D Shapes and Textures",
74
+ )
75
+
76
+ register_model_info(
77
+ ["magic3d"],
78
+ "Magic3D",
79
+ "https://research.nvidia.com/labs/dir/magic3d/",
80
+ "High-Resolution Text-to-3D Content Creation",
81
+ )
82
+
83
+ register_model_info(
84
+ ["geodream"],
85
+ "GeoDream",
86
+ "https://mabaorui.github.io/GeoDream_page/",
87
+ "Disentangling 2D and Geometric Priors for High-Fidelity and Consistent 3D Generation",
88
+ )
89
+
90
+ register_model_info(
91
+ ["mvdream"],
92
+ "MVDream",
93
+ "https://github.com/bytedance/MVDream",
94
+ "Multi-view Diffusion for 3D Generation",
95
+ )
96
+
97
+ register_model_info(
98
+ ["prolificdreamer"],
99
+ "ProlificDreamer",
100
+ "https://ml.cs.tsinghua.edu.cn/prolificdreamer/",
101
+ "High-Fidelity and Diverse Text-to-3D Generation with Variational Score Distillation",
102
+ )
103
+
104
+ register_model_info(
105
+ ["syncdreamer"],
106
+ "SyncDreamer",
107
+ "https://github.com/liuyuan-pal/SyncDreamer",
108
+ "Generating Multiview-consistent Images from a Single-view Image",
109
+ )
110
+
111
+ register_model_info(
112
+ ["wonder3d"],
113
+ "Wonder3D",
114
+ "https://github.com/xxlong0/Wonder3D",
115
+ "Single Image to 3D using Cross-Domain Diffusion",
116
+ )
117
+
118
+ register_model_info(
119
+ ["dreamcraft3d"],
120
+ "Dreamcraft3d",
121
+ "https://github.com/deepseek-ai/DreamCraft3D",
122
+ "Hierarchical 3d generation with bootstrapped diffusion prior",
123
+ )
124
+
125
+
126
+ # register_model_info(
127
+ # [],
128
+ # "",
129
+ # "",
130
+ # "",
131
+ # )
132
+
133
+
134
+ # regist image edition models
135
+
136
+ register_model_info(
137
+ ["zero123"],
138
+ "Zero-1-to-3",
139
+ "https://github.com/cvlab-columbia/zero123",
140
+ "Zero-shot One Image to 3D Object",
141
+ )
142
+
143
+ register_model_info(
144
+ ["stable-zero123", "zero123-xl"],
145
+ "Stable Zero123",
146
+ "https://stability.ai/news/stable-zero123-3d-generation",
147
+ "Quality 3D Object Generation from Single Images",
148
+ )
149
+
150
+ register_model_info(
151
+ ["magic123"],
152
+ "Magic123",
153
+ "https://guochengqian.github.io/project/magic123/",
154
+ "One Image to High-Quality 3D Object Generation Using Both 2D and 3D Diffusion Priors",
155
+ )
156
+
157
+ register_model_info(
158
+ ["imagedream"],
159
+ "ImageDream",
160
+ "https://github.com/bytedance/ImageDream",
161
+ "Image-Prompt Multi-view Diffusion for 3D Generation",
162
+ )
163
+
164
+ register_model_info(
165
+ ["lucid-dreamer"],
166
+ "LucidDreamer",
167
+ "https://github.com/EnVision-Research/LucidDreamer",
168
+ "Towards High-Fidelity Text-to-3D Generation via Interval Score Matching",
169
+ )
170
+
171
+ register_model_info(
172
+ ["make-it-3d"],
173
+ "Make-It-3D",
174
+ "https://github.com/junshutang/Make-It-3D",
175
+ "High-Fidelity 3D Creation from A Single Image with Diffusion Prior",
176
+ )
177
+
178
+ register_model_info(
179
+ ["triplane-gaussian"],
180
+ "TriplaneGaussian",
181
+ "https://github.com/VAST-AI-Research/TriplaneGaussian",
182
+ "Triplane Meets Gaussian Splatting: Fast and Generalizable Single-View 3D Reconstruction with Transformers",
183
+ )
184
+
185
+ register_model_info(
186
+ ["free3d"],
187
+ "Free3D",
188
+ "https://github.com/lyndonzheng/Free3D",
189
+ "Consistent Novel View Synthesis without 3D Representation",
190
+ )
191
+
192
+ register_model_info(
193
+ ["escher-net"],
194
+ "EscherNet",
195
+ "https://github.com/kxhit/EscherNet",
196
+ "A Generative Model for Scalable View Synthesis",
197
+ )
198
+
199
+ register_model_info(
200
+ ["v3d"],
201
+ "V3D",
202
+ "https://github.com/heheyas/V3D",
203
+ "Video Diffusion Models are Effective 3D Generators",
204
+ )
205
+
206
+ register_model_info(
207
+ ["lgm"],
208
+ "LGM",
209
+ "https://github.com/3DTopia/LGM",
210
+ "Large Multi-View Gaussian Model for High-Resolution 3D Content Creation",
211
+ )
212
+
213
+ # register_model_info(
214
+ # [],
215
+ # "",
216
+ # "",
217
+ # "",
218
+ # )
219
+
model/model_worker.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import List
4
+ import replicate
5
+
6
+ os.environ("REPLICATE_API_TOKEN", "r8_0BaoQW0G8nWFXY8YWBCCUDurANxCtY72rarv9")
7
+
8
+ class BaseModelWorker:
9
+ def __init__(self,
10
+ model_name: str,
11
+ i2s_model: bool,
12
+ online_model: bool,
13
+ model_path: str = None,
14
+ ):
15
+ self.model_name = model_name
16
+ self.i2s_model = i2s_model
17
+ self.online_model = online_model
18
+ self.model_path = model_path
19
+ self.model = None
20
+
21
+ if self.online_model:
22
+ assert not self.model_path, f"Please give model_path of {model_name}"
23
+ self.model = self.load_model()
24
+
25
+ def check_online(self) -> bool:
26
+ if self.online_model and not self.model:
27
+ return True
28
+ else:
29
+ return False
30
+
31
+ def load_model(self):
32
+ pass
33
+
34
+ def inference(self, prompt):
35
+ pass
36
+
37
+ def render(self, shape):
38
+ pass
39
+
40
+ class HuggingfaceApiWorker(BaseModelWorker):
41
+ def __init__(
42
+ self,
43
+ model_name: str,
44
+ i2s_model: bool,
45
+ online_model: bool,
46
+ model_api: str,
47
+ model_path: str = None,
48
+ ):
49
+ super().__init__(
50
+ model_name,
51
+ i2s_model,
52
+ online_model,
53
+ model_path,
54
+ )
55
+ self.model_api = model_api
56
+
57
+ class PointE_Worker(BaseModelWorker):
58
+ def __init__(self,
59
+ model_name: str,
60
+ i2s_model: bool,
61
+ online_model: bool,
62
+ model_api: str,
63
+ model_path: str = None):
64
+ super().__init__(model_name, i2s_model, online_model, model_path)
65
+ self.model_api = model_api
66
+
67
+
68
+ class LGM_Worker(BaseModelWorker):
69
+ def __init__(self,
70
+ model_name: str,
71
+ i2s_model: bool,
72
+ online_model: bool,
73
+ model_path: str = "camenduru/lgm-ply-to-glb:eb217314ab0d025370df16b8c9127f9ac1a0e4b3ffbff6b323d598d3c814d258"):
74
+ super().__init__(model_name, i2s_model, online_model, model_path)
75
+
76
+ def inference(self, image):
77
+ output = replicate.run(
78
+ self.model_path,
79
+ input={"ply_file_url": image}
80
+ )
81
+ #=> .glb file url: "https://replicate.delivery/pbxt/r4iOSfk7cv2wACJL539ACB4E...
82
+ return output
83
+
84
+
85
+ if __name__=="__main__":
86
+ input = {
87
+ "ply_file_url": "https://replicate.delivery/pbxt/UvKKgNj9mT7pIVHzwerhcjkp5cMH4FS5emPVghk2qyzMRwUSA/gradio_output.ply"
88
+ }
89
+ print("Start...")
90
+ output = replicate.run(
91
+ "camenduru/lgm-ply-to-glb:eb217314ab0d025370df16b8c9127f9ac1a0e4b3ffbff6b323d598d3c814d258",
92
+ input=input
93
+ )
94
+ print("output: ", output)
95
+ #=> "https://replicate.delivery/pbxt/r4iOSfk7cv2wACJL539ACB4E...
model/models/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .imagenhub_models import load_imagenhub_model
2
+ from .playground_api import load_playground_model
3
+
4
+ IMAGE_GENERATION_MODELS = ['imagenhub_LCM_generation','imagenhub_SDXLTurbo_generation','imagenhub_SDXL_generation', 'imagenhub_PixArtAlpha_generation',
5
+ 'imagenhub_OpenJourney_generation','imagenhub_SDXLLightning_generation', 'imagenhub_StableCascade_generation',
6
+ 'playground_PlayGroundV2_generation', 'playground_PlayGroundV2.5_generation']
7
+ IMAGE_EDITION_MODELS = ['imagenhub_CycleDiffusion_edition', 'imagenhub_Pix2PixZero_edition', 'imagenhub_Prompt2prompt_edition',
8
+ 'imagenhub_SDEdit_edition', 'imagenhub_InstructPix2Pix_edition', 'imagenhub_MagicBrush_edition', 'imagenhub_PNP_edition']
9
+
10
+
11
+ def load_pipeline(model_name):
12
+ """
13
+ Load a model pipeline based on the model name
14
+ Args:
15
+ model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
16
+ the source can be either imagenhub or playground
17
+ the name is the name of the model used to load the model
18
+ the type is the type of the model, either generation or edition
19
+ """
20
+ model_source, model_name, model_type = model_name.split("_")
21
+ if model_source == "imagenhub":
22
+ pipe = load_imagenhub_model(model_name, model_type)
23
+ elif model_source == "playground":
24
+ pipe = load_playground_model(model_name)
25
+ else:
26
+ raise ValueError(f"Model source {model_source} not supported")
27
+ return pipe
model/models/fal_api_models.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fal
2
+
3
+ class FalModel():
4
+ def __init__(self, model_name, model_type):
5
+ self.model_name = model_name
6
+ self.modle_type = model_type
7
+
8
+ def __call__(self, *args, **kwargs):
9
+
10
+ if self.model_type == "text2image":
11
+ assert "prompt" in kwargs, "prompt is required for text2image model"
12
+ handler = fal.apps.submit(
13
+ f"fal-ai/{self.model_name}",
14
+ arguments={
15
+ "prompt": kwargs["prompt"]
16
+ },
17
+ )
18
+
19
+ for event in handler.iter_events():
20
+ if isinstance(event, fal.apps.InProgress):
21
+ print('Request in progress')
22
+ print(event.logs)
23
+
24
+ result = handler.get()
25
+ return result
26
+ elif self.model_type == "image2image":
27
+ assert "image" in kwargs or "image_url" in kwargs, "image or image_url is required for image2image model"
28
+ if "image" in kwargs:
29
+ image_url = None
30
+ pass
31
+ handler = fal.apps.submit(
32
+ f"fal-ai/{self.model_name}",
33
+ arguments={
34
+ "image_url": image_url
35
+ },
36
+ )
37
+
38
+ for event in handler.iter_events():
39
+ if isinstance(event, fal.apps.InProgress):
40
+ print('Request in progress')
41
+ print(event.logs)
42
+
43
+ result = handler.get()
44
+ return result
45
+ else:
46
+ raise ValueError("model_type must be text2image or image2image")
47
+
48
+ def load_fal_model(model_name, model_type):
49
+ return FalModel(model_name, model_type)
model/models/imagenhub_models.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imagen_hub
2
+
3
+ class ImagenHubModel():
4
+ def __init__(self, model_name):
5
+ self.model = imagen_hub.load(model_name)
6
+
7
+ def __call__(self, *args, **kwargs):
8
+ return self.model.infer_one_image(*args, **kwargs)
9
+
10
+ class PNP(ImagenHubModel):
11
+ def __init__(self):
12
+ super().__init__('PNP')
13
+
14
+ def __call__(self, *args, **kwargs):
15
+ if "num_inversion_steps" not in kwargs:
16
+ kwargs["num_inversion_steps"] = 200
17
+ return super().__call__(*args, **kwargs)
18
+
19
+ class Prompt2prompt(ImagenHubModel):
20
+ def __init__(self):
21
+ super().__init__('Prompt2prompt')
22
+
23
+ def __call__(self, *args, **kwargs):
24
+ if "num_inner_steps" not in kwargs:
25
+ kwargs["num_inner_steps"] = 3
26
+ return super().__call__(*args, **kwargs)
27
+
28
+ def load_imagenhub_model(model_name, model_type=None):
29
+ if model_name == 'PNP':
30
+ return PNP()
31
+ if model_name == 'Prompt2prompt':
32
+ return Prompt2prompt()
33
+ return ImagenHubModel(model_name)
model/models/playground_api.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ class PlayGround():
8
+ def __init__(self, model_name, model_type=None):
9
+ self.model_name = model_name
10
+ self.model_type = model_type
11
+ self.api_key = os.environ['PlaygroundAPI']
12
+ if model_name == "PlayGroundV2":
13
+ self._model_name = "Playground_v2"
14
+ elif model_name == "PlayGroundV2.5":
15
+ self._model_name = "Playground_v2.5"
16
+
17
+
18
+ def __call__(self, prompt):
19
+ headers = {
20
+ 'Content-Type': 'application/json',
21
+ 'Authorization': "Bearer " + self.api_key,
22
+ }
23
+
24
+ data = json.dumps({"prompt": prompt, "filter_model": self._model_name, "scheduler": "DPMPP_2M_K", "guidance_scale": 3})
25
+
26
+ response = requests.post('https://playground.com/api/models/external/v1', headers=headers, data=data)
27
+ response.raise_for_status()
28
+ json_obj = response.json()
29
+ image_base64 = json_obj['images'][0]
30
+ img = Image.open(io.BytesIO(base64.decodebytes(bytes(image_base64, "utf-8"))))
31
+
32
+ return img
33
+
34
+ def load_playground_model(model_name, model_type="generation"):
35
+ return PlayGround(model_name, model_type)
model/other_models.py ADDED
File without changes
offline/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import requests
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+ # MAX_LEN = 40
9
+ # STEP = 2
10
+ # x = np.arange(0, MAX_LEN, STEP)
11
+ # token_counts = [0] * (MAX_LEN//STEP)
12
+
13
+ # with open("prompts.json", 'r') as f:
14
+ # prompts = json.load(f)
15
+
16
+ # for prompt in prompts:
17
+ # tokens = len(prompt.strip().split(' '))
18
+ # token_counts[min(tokens//STEP, MAX_LEN//STEP-1)] += 1
19
+
20
+ # plt.xticks(x, x+1)
21
+ # plt.xlabel("token counts")
22
+ # plt.bar(x, token_counts, width=1.3)
23
+ # # plt.show()
24
+ # plt.savefig("token_counts.png")
25
+
26
+ ## Generate image prompts
27
+ with open("prompts.json") as f:
28
+ text_prompts = json.load(f)
29
+
30
+ engine_id = "stable-diffusion-v1-6"
31
+ api_host = os.getenv('API_HOST', 'https://api.stability.ai')
32
+ api_key = os.getenv("STABILITY_API_KEY", "sk-ZvoFiXEbln6yh0hvSlm1K60WYcWFY5rmyW8a9FgoVBrKKP9N")
33
+
34
+ if api_key is None:
35
+ raise Exception("Missing Stability API key.")
36
+
37
+ for idx, text in enumerate(text_prompts):
38
+ if idx<=20: continue
39
+ print(f"Start generate prompt[{idx}]: {text}")
40
+ response = requests.post(
41
+ f"{api_host}/v1/generation/{engine_id}/text-to-image",
42
+ headers={
43
+ "Content-Type": "application/json",
44
+ "Accept": "application/json",
45
+ "Authorization": f"Bearer {api_key}"
46
+ },
47
+ json={
48
+ "text_prompts": [
49
+ {
50
+ "text": text.strip()
51
+ }
52
+ ],
53
+ "cfg_scale": 7,
54
+ "height": 1024,
55
+ "width": 1024,
56
+ "samples": 3,
57
+ "steps": 30,
58
+ },
59
+ )
60
+
61
+ if response.status_code != 200:
62
+ # raise Exception("Non-200 response: " + str(response.text))
63
+ print(f"{idx} Failed!!! {str(response.text)}")
64
+ continue
65
+
66
+ print("Finished!")
67
+ data = response.json()
68
+
69
+ for i, image in enumerate(data["artifacts"]):
70
+ img_path = f"./images/{idx}/v1_txt2img_{i}.png"
71
+ os.makedirs(os.path.dirname(img_path), exist_ok=True)
72
+ with open(img_path, "wb") as f:
73
+ f.write(base64.b64decode(image["base64"]))
serve/constants.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ LOGDIR = os.getenv("LOGDIR", "./3DGen-Arena-logs/vote_log")
4
+ IMAGE_DIR = os.getenv("IMAGE_DIR", f"{LOGDIR}/images")
5
+ OFFLINE_DIR = "./offline"
6
+
7
+ SERVER_PORT = os.getenv("SERVER_PORT", 7860)
8
+ ROOT_PATH = os.getenv("ROOT_PATH", None)
9
+ ELO_RESULTS_DIR = os.getenv("ELO_RESULTS_DIR", "./arena_elo/results/latest")
10
+
11
+ LOG_SERVER = os.getenv("LOG_SERVER", "https://tigerai.ca")
12
+ LOG_SERVER_SUBDOAMIN = os.getenv("LOG_SERVER_SUBDIR", "GenAI-Arena-hf-logs")
13
+ LOG_SERVER_ADDR = os.getenv("LOG_SERVER_ADDR", f"{LOG_SERVER}/{LOG_SERVER_SUBDOAMIN}")
14
+ # LOG SERVER API ENDPOINTS
15
+ APPEND_JSON = "append_json"
16
+ SAVE_IMAGE = "save_image"
17
+ SAVE_LOG = "save_log"
18
+
19
+ NUM_SIDES = 2
20
+ TEXT_PROMPT_PATH = "offline/prompts.json"
serve/gradio_web_i2s.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import partial
3
+
4
+ from .utils import *
5
+ from .vote_utils import (
6
+ upvote_last_response_i2s as upvote_last_response,
7
+ downvote_last_response_i2s as downvote_last_response,
8
+ flag_last_response_i2s as flag_last_response,
9
+ leftvote_last_response_i2s_anony as leftvote_last_response_anony,
10
+ rightvote_last_response_i2s_anony as rightvote_last_response_anony,
11
+ tievote_last_response_i2s_anony as tievote_last_response_anony,
12
+ bothbad_vote_last_response_i2s_anony as bothbad_vote_last_response_anony,
13
+ leftvote_last_response_i2s_named as leftvote_last_response_named,
14
+ rightvote_last_response_i2s_named as rightvote_last_response_named,
15
+ tievote_last_response_i2s_named as tievote_last_response_named,
16
+ bothbad_vote_last_response_i2s_named as bothbad_vote_last_response_named,
17
+ share_click_i2s_multi as share_click,
18
+ share_js
19
+ )
20
+ from .inference import(
21
+ sample_i2s_model as sample_model,
22
+ sample_image,
23
+ sample_image_side_by_side,
24
+ generate_i2s,
25
+ generate_i2s_multi,
26
+ generate_i2s_multi_annoy
27
+ )
28
+
29
+
30
+ def build_i2s_ui_side_by_side_anony(models):
31
+ notice_markdown = """
32
+ # ⚔️ GenAI-Arena ⚔️ : Benchmarking Image-to-3D generative models
33
+ ## 📜 Rules
34
+ - Upload image to two anonymous models in same area and vote for the better one!
35
+ - When the results are ready, click the button below to vote.
36
+ - Vote won't be counted if model identity is revealed during conversation.
37
+ - Click "Clear" to start a new round.
38
+
39
+ ## 🏆 Arena Elo
40
+ Find out who is the 🥇conditional image generation models! More models are going to be supported.
41
+
42
+ ## 👇 Generating now!
43
+
44
+ """
45
+ model_list = models.get_i2s_models()
46
+ gen_func = partial(generate_i2s_multi_annoy, models.inference_parallel, models.render_parallel)
47
+
48
+ state_0 = gr.State()
49
+ state_1 = gr.State()
50
+
51
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
52
+ model_str = gr.Markdown(str(model_list), visible=False, elem_id="moedl list string")
53
+
54
+ with gr.Group(elem_id="share-region-anony"):
55
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
56
+ model_description_md = get_model_description_md(model_list)
57
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
58
+
59
+ with gr.Row():
60
+ with gr.Column():
61
+ normal_left = gr.Image(width=512, label = "Model A",
62
+ interactive=False, show_download_button=True)
63
+ rgb_left = gr.Image(width=512, label = "Model A",
64
+ interactive=False, show_download_button=True)
65
+ with gr.Column():
66
+ normal_right = gr.Image(width=512, label = "Model B",
67
+ interactive=False, show_download_button=True,)
68
+ rgb_right = gr.Image(width=512, label = "Model B",
69
+ interactive=False, show_download_button=True,)
70
+
71
+ with gr.Row():
72
+ with gr.Column():
73
+ model_selector_left =gr.Markdown("", visible=False)
74
+ with gr.Column():
75
+ model_selector_right = gr.Markdown("", visible=False)
76
+ with gr.Row():
77
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
78
+
79
+ with gr.Row(elem_id="Geometry Quality"):
80
+ geo_md = gr.Markdown("Geometry Quality: ", visible=False, elem_id="evaldim_markdown")
81
+ geo_leftvote_btn = gr.Button(
82
+ value="👈 A is better", visible=False, interactive=False
83
+ )
84
+ geo_rightvote_btn = gr.Button(
85
+ value="👉 B is better", visible=False, interactive=False
86
+ )
87
+ geo_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
88
+ geo_bothbad_btn = gr.Button(
89
+ value="👎 Both are bad", visible=False, interactive=False
90
+ )
91
+
92
+ with gr.Row(elem_id="Texture Quality"):
93
+ text_md = gr.Markdown("Texture Quality: ", visible=False, elem_id="evaldim_markdown")
94
+ text_leftvote_btn = gr.Button(
95
+ value="👈 A is better", visible=False, interactive=False
96
+ )
97
+ text_rightvote_btn = gr.Button(
98
+ value="👉 B is better", visible=False, interactive=False
99
+ )
100
+ text_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
101
+ text_bothbad_btn = gr.Button(
102
+ value="👎 Both are bad", visible=False, interactive=False
103
+ )
104
+
105
+ with gr.Row(elem_id="Alignment Quality"):
106
+ align_md = gr.Markdown("Alignment Quality: ", visible=False, elem_id="evaldim_markdown")
107
+ align_leftvote_btn = gr.Button(
108
+ value="👈 A is better", visible=False, interactive=False
109
+ )
110
+ align_rightvote_btn = gr.Button(
111
+ value="👉 B is better", visible=False, interactive=False
112
+ )
113
+ align_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
114
+ align_bothbad_btn = gr.Button(
115
+ value="👎 Both are bad", visible=False, interactive=False
116
+ )
117
+
118
+ with gr.Row():
119
+ imagebox = gr.Image(
120
+ width=512,
121
+ show_label=False,
122
+ visible=True,
123
+ interactive=True,
124
+ elem_id="input_box",
125
+ )
126
+ with gr.Column():
127
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary")
128
+ send_btn = gr.Button(value="📤 Send", variant="primary")
129
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
130
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
131
+ share_btn = gr.Button(value="📷 Share")
132
+
133
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
134
+
135
+
136
+ geo_btn_list = [geo_leftvote_btn, geo_rightvote_btn, geo_tie_btn, geo_bothbad_btn]
137
+ text_btn_list = [text_leftvote_btn, text_rightvote_btn, text_tie_btn, text_bothbad_btn]
138
+ align_btn_list = [align_leftvote_btn, align_rightvote_btn, align_tie_btn, align_bothbad_btn]
139
+ states = [state_0, state_1]
140
+ model_selectors = [model_selector_left, model_selector_right]
141
+ results = [normal_left, rgb_left, normal_right, rgb_right]
142
+
143
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
144
+ leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
145
+
146
+ leftvote_btn.click(
147
+ leftvote_last_response_anony,
148
+ states + model_selectors,
149
+ [imagebox] + btn_list + model_selectors
150
+ )
151
+ rightvote_btn.click(
152
+ rightvote_last_response_anony,
153
+ states + model_selectors,
154
+ [imagebox] + btn_list + model_selectors
155
+ )
156
+ tie_btn.click(
157
+ tievote_last_response_anony,
158
+ states + model_selectors,
159
+ [imagebox] + btn_list + model_selectors
160
+ )
161
+ bothbad_btn.click(
162
+ bothbad_vote_last_response_anony,
163
+ states + model_selectors,
164
+ [imagebox] + btn_list + model_selectors
165
+ )
166
+
167
+ sample_btn.click(
168
+ sample_image_side_by_side,
169
+ states + model_selectors,
170
+ states + [imagebox],
171
+ api_name="sample_btn_anony"
172
+ )
173
+
174
+ imagebox.upload(
175
+ gen_func,
176
+ states + [imagebox] + model_selectors,
177
+ states + results + model_selectors,
178
+ api_name="submit_btn_named"
179
+ ).then(
180
+ enable_mds,
181
+ None,
182
+ [geo_md, text_md, align_md]
183
+ ).then(
184
+ enable_buttons_side_by_side,
185
+ None,
186
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
187
+ )
188
+
189
+ send_btn.click(
190
+ sample_model,
191
+ states + [model_str],
192
+ states + model_selectors
193
+ ).then(
194
+ gen_func,
195
+ states + [imagebox] + model_selectors,
196
+ states + results + model_selectors,
197
+ api_name="send_btn_anony"
198
+ ).then(
199
+ enable_mds,
200
+ None,
201
+ [geo_md, text_md, align_md]
202
+ ).then(
203
+ enable_buttons_side_by_side,
204
+ None,
205
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
206
+ )
207
+
208
+ clear_btn.click(
209
+ clear_i2s_history_side_by_side_anony,
210
+ None,
211
+ states + [imagebox] + results + model_selectors,
212
+ api_name="clear_btn_anony"
213
+ ).then(
214
+ disable_mds,
215
+ None,
216
+ [geo_md, text_md, align_md]
217
+ ).then(
218
+ disable_buttons_side_by_side,
219
+ None,
220
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
221
+ )
222
+
223
+ regenerate_btn.click(
224
+ sample_model,
225
+ states + [model_str],
226
+ states + model_selectors
227
+ ).then(
228
+ gen_func,
229
+ states + [imagebox] + model_selectors,
230
+ states + results + model_selectors,
231
+ api_name="regenerate_btn_anony"
232
+ ).then(
233
+ enable_mds,
234
+ None,
235
+ [geo_md, text_md, align_md]
236
+ ).then(
237
+ enable_buttons_side_by_side,
238
+ None,
239
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
240
+ )
241
+
242
+ share_btn.click(
243
+ share_click,
244
+ states + model_selectors,
245
+ [],
246
+ js=share_js
247
+ )
248
+
249
+
250
+ def build_i2s_ui_side_by_side_named(models):
251
+ notice_markdown = """
252
+ # ⚔️ GenAI-Arena ⚔️ : Benchmarking Image-to-3D generative models
253
+ ## 📜 Rules
254
+ - Generate with any two selected models side-by-side and vote!
255
+ - Sample or Input prompt you want to generate.
256
+ - Click "Send" to submit the prompt.
257
+ - Click "Clear" to start a new round.
258
+
259
+ ## 🏆 Arena Elo
260
+ Find out who is the 🥇conditional image generation models! More models are going to be supported.
261
+
262
+ ## 👇 Generating now!
263
+
264
+ """
265
+ model_list = models.get_i2s_models()
266
+ gen_func = partial(generate_i2s_multi, models.inference_parallel, models.render_parallel)
267
+
268
+ state_0 = gr.State()
269
+ state_1 = gr.State()
270
+
271
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
272
+
273
+ with gr.Group(elem_id="share-region-named"):
274
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
275
+ model_description_md = get_model_description_md(model_list)
276
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
277
+
278
+ with gr.Row():
279
+ with gr.Column():
280
+ model_selector_left = gr.Dropdown(
281
+ choices=model_list,
282
+ value=model_list[0] if len(model_list) > 0 else "",
283
+ interactive=True,
284
+ show_label=False,
285
+ container=False,
286
+ )
287
+ with gr.Column():
288
+ model_selector_right = gr.Dropdown(
289
+ choices=model_list,
290
+ value=model_list[1] if len(model_list) > 1 else "",
291
+ interactive=True,
292
+ show_label=False,
293
+ container=False,
294
+ )
295
+
296
+ with gr.Row():
297
+
298
+ with gr.Column():
299
+ normal_left = gr.Image(width=512, label = "Model A", show_download_button=True)
300
+ rgb_left = gr.Image(width=512, label = "Model A", show_download_button=True)
301
+ with gr.Column():
302
+ normal_right = gr.Image(width=512, label = "Model B", show_download_button=True,)
303
+ rgb_right = gr.Image(width=512, label = "Model B", show_download_button=True,)
304
+
305
+ with gr.Row(elem_id="Geometry Quality"):
306
+ geo_md = gr.Markdown("Geometry Quality: ", visible=False, elem_id="evaldim_markdown")
307
+ geo_leftvote_btn = gr.Button(
308
+ value="👈 A is better", visible=False, interactive=False
309
+ )
310
+ geo_rightvote_btn = gr.Button(
311
+ value="👉 B is better", visible=False, interactive=False
312
+ )
313
+ geo_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
314
+ geo_bothbad_btn = gr.Button(
315
+ value="👎 Both are bad", visible=False, interactive=False
316
+ )
317
+
318
+ with gr.Row(elem_id="Texture Quality"):
319
+ text_md = gr.Markdown("Texture Quality: ", visible=False, elem_id="evaldim_markdown")
320
+ text_leftvote_btn = gr.Button(
321
+ value="👈 A is better", visible=False, interactive=False
322
+ )
323
+ text_rightvote_btn = gr.Button(
324
+ value="👉 B is better", visible=False, interactive=False
325
+ )
326
+ text_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
327
+ text_bothbad_btn = gr.Button(
328
+ value="👎 Both are bad", visible=False, interactive=False
329
+ )
330
+
331
+ with gr.Row(elem_id="Alignment Quality"):
332
+ align_md = gr.Markdown("Alignment Quality: ", visible=False, elem_id="evaldim_markdown")
333
+ align_leftvote_btn = gr.Button(
334
+ value="👈 A is better", visible=False, interactive=False
335
+ )
336
+ align_rightvote_btn = gr.Button(
337
+ value="👉 B is better", visible=False, interactive=False
338
+ )
339
+ align_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
340
+ align_bothbad_btn = gr.Button(
341
+ value="👎 Both are bad", visible=False, interactive=False
342
+ )
343
+
344
+ with gr.Row():
345
+ imagebox = gr.Image(
346
+ width=512,
347
+ show_label=False,
348
+ visible=True,
349
+ interactive=True,
350
+ elem_id="input_box",
351
+ )
352
+
353
+ with gr.Column():
354
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary")
355
+ send_btn = gr.Button(value="📤 Send", variant="primary")
356
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
357
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
358
+ share_btn = gr.Button(value="📷 Share")
359
+
360
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
361
+
362
+ geo_btn_list = [geo_leftvote_btn, geo_rightvote_btn, geo_tie_btn, geo_bothbad_btn]
363
+ text_btn_list = [text_leftvote_btn, text_rightvote_btn, text_tie_btn, text_bothbad_btn]
364
+ align_btn_list = [align_leftvote_btn, align_rightvote_btn, align_tie_btn, align_bothbad_btn]
365
+ states = [state_0, state_1]
366
+ model_selectors = [model_selector_left, model_selector_right]
367
+ results = [normal_left, rgb_left, normal_right, rgb_right]
368
+
369
+ model_selector_left.change(
370
+ clear_i2s_history_side_by_side,
371
+ None,
372
+ states + [imagebox] + results,
373
+ api_name="model_selector_left"
374
+ )
375
+ model_selector_right.change(
376
+ clear_i2s_history_side_by_side,
377
+ None,
378
+ states + [imagebox] + results,
379
+ api_name="model_selector_right"
380
+ )
381
+
382
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
383
+ leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
384
+
385
+ leftvote_btn.click(
386
+ leftvote_last_response_named,
387
+ states + model_selectors,
388
+ [imagebox] + btn_list
389
+ )
390
+ rightvote_btn.click(
391
+ rightvote_last_response_named,
392
+ states + model_selectors,
393
+ [imagebox] + btn_list
394
+ )
395
+ tie_btn.click(
396
+ tievote_last_response_named,
397
+ states + model_selectors,
398
+ [imagebox] + btn_list
399
+ )
400
+ bothbad_btn.click(
401
+ bothbad_vote_last_response_named,
402
+ states + model_selectors,
403
+ [imagebox] + btn_list
404
+ )
405
+
406
+ sample_btn.click(
407
+ sample_image_side_by_side,
408
+ states + model_selectors,
409
+ states + [imagebox],
410
+ api_name="sample_btn_named"
411
+ )
412
+
413
+ imagebox.upload(
414
+ gen_func,
415
+ states + [imagebox] + model_selectors,
416
+ states + results + model_selectors,
417
+ api_name="submit_btn_named"
418
+ ).then(
419
+ enable_mds,
420
+ None,
421
+ [geo_md, text_md, align_md]
422
+ ).then(
423
+ enable_buttons_side_by_side,
424
+ None,
425
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
426
+ )
427
+
428
+ send_btn.click(
429
+ gen_func,
430
+ states + [imagebox] + model_selectors,
431
+ states + results + model_selectors,
432
+ api_name="send_btn_named"
433
+ ).then(
434
+ enable_mds,
435
+ None,
436
+ [geo_md, text_md, align_md]
437
+ ).then(
438
+ enable_buttons_side_by_side,
439
+ None,
440
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
441
+ )
442
+
443
+ clear_btn.click(
444
+ clear_i2s_history_side_by_side,
445
+ None,
446
+ states + [imagebox] + results,
447
+ api_name="clear_btn_named"
448
+ ).then(
449
+ disable_mds,
450
+ None,
451
+ [geo_md, text_md, align_md]
452
+ ).then(
453
+ disable_buttons_side_by_side,
454
+ None,
455
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
456
+ )
457
+
458
+ regenerate_btn.click(
459
+ gen_func,
460
+ states + [imagebox] + model_selectors,
461
+ states + results + model_selectors,
462
+ api_name="regenerate_btn_named"
463
+ ).then(
464
+ enable_mds,
465
+ None,
466
+ [geo_md, text_md, align_md]
467
+ ).then(
468
+ enable_buttons_side_by_side,
469
+ None,
470
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
471
+ )
472
+
473
+ share_btn.click(
474
+ share_click,
475
+ states + model_selectors,
476
+ [],
477
+ js=share_js
478
+ )
479
+
480
+
481
+ def build_i2s_ui_single_model(models):
482
+ notice_markdown = """
483
+ # 🏔️ Play with Image Generation Models
484
+ {promotion}
485
+
486
+ ## 🤖 Choose any model to generate
487
+
488
+ """
489
+ model_list = models.get_i2s_models()
490
+ gen_func = partial(generate_i2s, models.inference_parallel, models.render_parallel)
491
+
492
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
493
+
494
+ with gr.Row():
495
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
496
+ model_description_md = get_model_description_md(model_list)
497
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
498
+
499
+ with gr.Row(elem_id="model_selector_row"):
500
+ model_selector = gr.Dropdown(
501
+ choices=model_list,
502
+ value=model_list[0] if len(model_list) > 0 else "",
503
+ interactive=True,
504
+ show_label=False
505
+ )
506
+
507
+ with gr.Row():
508
+ normal = gr.Image(width=512, label = "Normal", show_download_button=True)
509
+ rgb = gr.Image(width=512, label = "RGB", show_download_button=True,)
510
+
511
+ with gr.Row():
512
+ imagebox = gr.Image(
513
+ width=512,
514
+ show_label=False,
515
+ visible=True,
516
+ interactive=True,
517
+ elem_id="input_box",
518
+ )
519
+ with gr.Column():
520
+ # with gr.Row():
521
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary")
522
+ send_btn = gr.Button(value="📤 Send", variant="primary")
523
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
524
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
525
+
526
+
527
+ with gr.Row(elem_id="Geometry Quality"):
528
+ gr.Markdown("Geometry Quality: ")
529
+ geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
530
+ geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
531
+ geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
532
+
533
+ with gr.Row(elem_id="Texture Quality"):
534
+ gr.Markdown("Texture Quality: ")
535
+ text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
536
+ text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
537
+ text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
538
+
539
+ with gr.Row(elem_id="Alignment Quality"):
540
+ gr.Markdown("Alignment Quality: ")
541
+ align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
542
+ align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
543
+ align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
544
+
545
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
546
+
547
+ state = gr.State()
548
+ geo_btn_list = [geo_upvote_btn, geo_downvote_btn, geo_flag_btn]
549
+ text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
550
+ align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
551
+
552
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
553
+ upvote_btn, downvote_btn, flag_btn = btn_list
554
+
555
+ upvote_btn.click(
556
+ upvote_last_response,
557
+ [state, model_selector],
558
+ [imagebox] + btn_list
559
+ )
560
+
561
+ downvote_btn.click(
562
+ downvote_last_response,
563
+ [state, model_selector],
564
+ [imagebox] + btn_list
565
+ )
566
+ flag_btn.click(
567
+ flag_last_response,
568
+ [state, model_selector],
569
+ [imagebox] + btn_list
570
+ )
571
+
572
+ sample_btn.click(
573
+ sample_image,
574
+ [state, model_selector],
575
+ [state, imagebox],
576
+ api_name="sample_btn_single"
577
+ )
578
+
579
+ imagebox.upload(
580
+ gen_func,
581
+ [state, imagebox, model_selector],
582
+ [state, normal, rgb],
583
+ api_name="submit_btn_single",
584
+ show_progress = "full"
585
+ ).then(
586
+ enable_buttons,
587
+ None,
588
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
589
+ )
590
+
591
+ send_btn.click(
592
+ gen_func,
593
+ [state, imagebox, model_selector],
594
+ [state, normal, rgb],
595
+ api_name="send_btn_single",
596
+ show_progress = "full"
597
+ ).then(
598
+ enable_buttons,
599
+ None,
600
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
601
+ )
602
+
603
+ clear_btn.click(
604
+ clear_i2s_history,
605
+ None,
606
+ [state, imagebox, normal, rgb],
607
+ api_name="clear_history_single",
608
+ show_progress="full"
609
+ ).then(
610
+ disable_buttons,
611
+ None,
612
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
613
+ )
614
+
615
+ regenerate_btn.click(
616
+ gen_func,
617
+ [state, imagebox, model_selector],
618
+ [state, normal, rgb],
619
+ api_name="regenerate_btn_single",
620
+ show_progress = "full"
621
+ ).then(
622
+ enable_buttons,
623
+ None,
624
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
625
+ )
626
+
serve/gradio_web_t2i_anony.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import partial
3
+
4
+ from .utils import *
5
+ from .vote_utils import (
6
+ upvote_last_response_t2s as upvote_last_response,
7
+ downvote_last_response_t2s as downvote_last_response,
8
+ flag_last_response_t2s as flag_last_response,
9
+ leftvote_last_response_t2s_multi as leftvote_last_response,
10
+ rightvote_last_response_t2s_multi as rightvote_last_response,
11
+ tievote_last_response_t2s_multi as tievote_last_response,
12
+ bothbad_vote_last_response_t2s_multi as bothbad_vote_last_response,
13
+ share_click_t2s_multi as share_click,
14
+ share_js
15
+ )
16
+ from .inference import(
17
+ sample_model,
18
+ sample_prompt,
19
+ generate_t2s,
20
+ generate_t2s_multi,
21
+ generate_t2s_multi_annoy
22
+ )
23
+ from .constants import TEXT_PROMPT_PATH
24
+
25
+ with open(TEXT_PROMPT_PATH, 'r') as f:
26
+ prompt_list = json.load(f)
27
+
28
+ def build_side_by_side_ui_anony(models):
29
+ notice_markdown = """
30
+ # ⚔️ GenAI-Arena ⚔️ : Benchmarking Text-to-3D generative models
31
+ ## 📜 Rules
32
+ - Input prompt to two anonymous models in same area and vote for the better one!
33
+ - When the results are ready, click the button below to vote.
34
+ - Vote won't be counted if model identity is revealed during conversation.
35
+ - Click "Clear" to start a new round.
36
+
37
+ ## 🏆 Arena Elo
38
+ Find out who is the 🥇conditional image generation models! More models are going to be supported.
39
+
40
+ ## 👇 Generating now!
41
+
42
+ """
43
+ model_list = models.get_t2s_models()
44
+ gen_func = partial(generate_t2s_multi_annoy, models.inference_parallel, models.render_parallel)
45
+
46
+ state_0 = gr.State()
47
+ state_1 = gr.State()
48
+
49
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
50
+
51
+ with gr.Group(elem_id="share-region-anony"):
52
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
53
+ model_description_md = get_model_description_md(model_list)
54
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
55
+ with gr.Row():
56
+ with gr.Column():
57
+ normal_left = gr.Image(width=512, label = "Model A", show_copy_button=True)
58
+ rgb_left = gr.Image(width=512, label = "Model A", show_copy_button=True)
59
+ with gr.Column():
60
+ normal_right = gr.Image(width=512, label = "Model B", show_copy_button=True,)
61
+ rgb_right = gr.Image(width=512, label = "Model B", show_copy_button=True,)
62
+
63
+ with gr.Row():
64
+ with gr.Column():
65
+ model_selector_left =gr.Markdown("", visible=False)
66
+ with gr.Column():
67
+ model_selector_right = gr.Markdown("", visible=False)
68
+ with gr.Row():
69
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
70
+
71
+ with gr.Row(elem_id="Geometry Quality"):
72
+ geo_leftvote_btn = gr.Button(
73
+ value="👈 A is better", visible=False, interactive=False
74
+ )
75
+ geo_rightvote_btn = gr.Button(
76
+ value="👉 B is better", visible=False, interactive=False
77
+ )
78
+ geo_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
79
+ geo_bothbad_btn = gr.Button(
80
+ value="👎 Both are bad", visible=False, interactive=False
81
+ )
82
+
83
+ with gr.Row(elem_id="Texture Quality"):
84
+ text_leftvote_btn = gr.Button(
85
+ value="👈 A is better", visible=False, interactive=False
86
+ )
87
+ text_rightvote_btn = gr.Button(
88
+ value="👉 B is better", visible=False, interactive=False
89
+ )
90
+ text_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
91
+ text_bothbad_btn = gr.Button(
92
+ value="👎 Both are bad", visible=False, interactive=False
93
+ )
94
+
95
+ with gr.Row(elem_id="Alignment Quality"):
96
+ align_leftvote_btn = gr.Button(
97
+ value="👈 A is better", visible=False, interactive=False
98
+ )
99
+ align_rightvote_btn = gr.Button(
100
+ value="👉 B is better", visible=False, interactive=False
101
+ )
102
+ align_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
103
+ align_bothbad_btn = gr.Button(
104
+ value="👎 Both are bad", visible=False, interactive=False
105
+ )
106
+
107
+ with gr.Row():
108
+ textbox = gr.Textbox(
109
+ show_label=False,
110
+ placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
111
+ container=True,
112
+ elem_id="input_box",
113
+ )
114
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
115
+ send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
116
+
117
+ with gr.Row():
118
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
119
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
120
+ share_btn = gr.Button(value="📷 Share")
121
+
122
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
123
+
124
+
125
+ geo_btn_list = [geo_leftvote_btn, geo_rightvote_btn, geo_tie_btn, geo_bothbad_btn]
126
+ text_btn_list = [text_leftvote_btn, text_rightvote_btn, text_tie_btn, text_bothbad_btn]
127
+ align_btn_list = [align_leftvote_btn, align_rightvote_btn, align_tie_btn, align_bothbad_btn]
128
+ states = [state_0, state_1]
129
+ model_selectors = [model_selector_left, model_selector_right]
130
+ results = [normal_left, rgb_left, normal_right, rgb_right]
131
+
132
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
133
+ leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
134
+
135
+ leftvote_btn.click(
136
+ leftvote_last_response,
137
+ states + model_selectors,
138
+ [textbox] + btn_list + model_selectors
139
+ )
140
+ rightvote_btn.click(
141
+ rightvote_last_response,
142
+ states + model_selectors,
143
+ [textbox] + btn_list + model_selectors
144
+ )
145
+ tie_btn.click(
146
+ tievote_last_response,
147
+ states + model_selectors,
148
+ [textbox] + btn_list + model_selectors
149
+ )
150
+ bothbad_btn.click(
151
+ bothbad_vote_last_response,
152
+ states + model_selectors,
153
+ [textbox] + btn_list + model_selectors
154
+ )
155
+
156
+ sample_btn.click(
157
+ sample_prompt,
158
+ states + model_selectors + [prompt_list],
159
+ states + [textbox],
160
+ api_name="sample_btn_anony"
161
+ )
162
+
163
+ textbox.submit(
164
+ sample_model,
165
+ states + [model_list, False],
166
+ states + model_selectors
167
+ ).then(
168
+ gen_func,
169
+ states + [textbox] + model_selectors + [prompt_list],
170
+ states + results + model_selectors,
171
+ api_name="submit_btn_anony"
172
+ ).then(
173
+ enable_buttons_side_by_side,
174
+ None,
175
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
176
+ )
177
+
178
+ send_btn.click(
179
+ sample_model,
180
+ states + [model_list, False],
181
+ states + model_selectors
182
+ ).then(
183
+ gen_func,
184
+ states + [textbox] + model_selectors + [prompt_list],
185
+ states + results + model_selectors,
186
+ api_name="send_btn_anony"
187
+ ).then(
188
+ enable_buttons_side_by_side,
189
+ None,
190
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
191
+ )
192
+
193
+ clear_btn.click(
194
+ clear_history_side_by_side_anony,
195
+ None,
196
+ states + [textbox] + results + model_selectors,
197
+ api_name="clear_btn_anony"
198
+ ).then(
199
+ disable_buttons_side_by_side,
200
+ None,
201
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
202
+ )
203
+
204
+ regenerate_btn.click(
205
+ sample_model,
206
+ states + [model_list, False],
207
+ states + model_selectors
208
+ ).then(
209
+ gen_func,
210
+ states + [textbox] + model_selectors + [prompt_list],
211
+ states + results + model_selectors,
212
+ api_name="regenerate_btn_anony"
213
+ ).then(
214
+ enable_buttons_side_by_side,
215
+ None,
216
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
217
+ )
218
+
219
+ share_btn.click(
220
+ share_click,
221
+ states + model_selectors,
222
+ [],
223
+ js=share_js
224
+ )
225
+
serve/gradio_web_t2i_named.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import partial
3
+
4
+ from .utils import *
5
+ from .vote_utils import (
6
+ upvote_last_response_t2s as upvote_last_response,
7
+ downvote_last_response_t2s as downvote_last_response,
8
+ flag_last_response_t2s as flag_last_response,
9
+ leftvote_last_response_t2s_multi as leftvote_last_response,
10
+ rightvote_last_response_t2s_multi as rightvote_last_response,
11
+ tievote_last_response_t2s_multi as tievote_last_response,
12
+ bothbad_vote_last_response_t2s_multi as bothbad_vote_last_response,
13
+ share_click_t2s_multi as share_click,
14
+ share_js
15
+ )
16
+ from .inference import(
17
+ sample_prompt,
18
+ generate_t2s_multi
19
+ )
20
+ from .constants import TEXT_PROMPT_PATH
21
+
22
+ with open(TEXT_PROMPT_PATH, 'r') as f:
23
+ prompt_list = json.load(f)
24
+
25
+ def build_side_by_side_ui_named(models):
26
+ notice_markdown = """
27
+ # ⚔️ GenAI-Arena ⚔️ : Benchmarking Text-to-3D generative models
28
+ ## 📜 Rules
29
+ - Generate with any two selected models side-by-side and vote!
30
+ - Sample or Input prompt you want to generate.
31
+ - Click "Send" to submit the prompt.
32
+ - Click "Clear" to start a new round.
33
+
34
+ ## 🏆 Arena Elo
35
+ Find out who is the 🥇conditional image generation models! More models are going to be supported.
36
+
37
+ ## 👇 Generating now!
38
+
39
+ """
40
+ model_list = models.get_t2s_models()
41
+ gen_func = partial(generate_t2s_multi, models.inference_parallel, models.render_parallel)
42
+
43
+ state_0 = gr.State()
44
+ state_1 = gr.State()
45
+
46
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
47
+
48
+ with gr.Group(elem_id="share-region-named"):
49
+ with gr.Row():
50
+ with gr.Column():
51
+ model_selector_left = gr.Dropdown(
52
+ choices=model_list,
53
+ value=model_list[0] if len(model_list) > 0 else "",
54
+ interactive=True,
55
+ show_label=False,
56
+ container=False,
57
+ )
58
+ with gr.Column():
59
+ model_selector_right = gr.Dropdown(
60
+ choices=model_list,
61
+ value=model_list[1] if len(model_list) > 1 else "",
62
+ interactive=True,
63
+ show_label=False,
64
+ container=False,
65
+ )
66
+
67
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
68
+ model_description_md = get_model_description_md(model_list)
69
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
70
+
71
+ with gr.Row():
72
+ with gr.Column():
73
+ normal_left = gr.Image(width=512, label = "Model A", show_copy_button=True)
74
+ rgb_left = gr.Image(width=512, label = "Model A", show_copy_button=True)
75
+ with gr.Column():
76
+ normal_right = gr.Image(width=512, label = "Model B", show_copy_button=True,)
77
+ rgb_right = gr.Image(width=512, label = "Model B", show_copy_button=True,)
78
+
79
+ with gr.Row():
80
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
81
+
82
+ with gr.Row(elem_id="Geometry Quality"):
83
+ geo_leftvote_btn = gr.Button(
84
+ value="👈 A is better", visible=False, interactive=False
85
+ )
86
+ geo_rightvote_btn = gr.Button(
87
+ value="👉 B is better", visible=False, interactive=False
88
+ )
89
+ geo_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
90
+ geo_bothbad_btn = gr.Button(
91
+ value="👎 Both are bad", visible=False, interactive=False
92
+ )
93
+
94
+ with gr.Row(elem_id="Texture Quality"):
95
+ text_leftvote_btn = gr.Button(
96
+ value="👈 A is better", visible=False, interactive=False
97
+ )
98
+ text_rightvote_btn = gr.Button(
99
+ value="👉 B is better", visible=False, interactive=False
100
+ )
101
+ text_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
102
+ text_bothbad_btn = gr.Button(
103
+ value="👎 Both are bad", visible=False, interactive=False
104
+ )
105
+
106
+ with gr.Row(elem_id="Alignment Quality"):
107
+ align_leftvote_btn = gr.Button(
108
+ value="👈 A is better", visible=False, interactive=False
109
+ )
110
+ align_rightvote_btn = gr.Button(
111
+ value="👉 B is better", visible=False, interactive=False
112
+ )
113
+ align_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
114
+ align_bothbad_btn = gr.Button(
115
+ value="👎 Both are bad", visible=False, interactive=False
116
+ )
117
+
118
+ with gr.Row():
119
+ textbox = gr.Textbox(
120
+ show_label=False,
121
+ placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
122
+ container=True,
123
+ elem_id="input_box",
124
+ )
125
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
126
+ send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
127
+
128
+ with gr.Row():
129
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
130
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
131
+ share_btn = gr.Button(value="📷 Share")
132
+
133
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
134
+
135
+ geo_btn_list = [geo_leftvote_btn, geo_rightvote_btn, geo_tie_btn, geo_bothbad_btn]
136
+ text_btn_list = [text_leftvote_btn, text_rightvote_btn, text_tie_btn, text_bothbad_btn]
137
+ align_btn_list = [align_leftvote_btn, align_rightvote_btn, align_tie_btn, align_bothbad_btn]
138
+ states = [state_0, state_1]
139
+ model_selectors = [model_selector_left, model_selector_right]
140
+ results = [normal_left, rgb_left, normal_right, rgb_right]
141
+
142
+ model_selector_left.change(
143
+ clear_history_side_by_side,
144
+ None,
145
+ states + [textbox] + results,
146
+ api_name="model_selector_left"
147
+ )
148
+ model_selector_right.change(
149
+ clear_history_side_by_side,
150
+ None,
151
+ states + [textbox] + results,
152
+ api_name="model_selector_right"
153
+ )
154
+
155
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
156
+ leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
157
+
158
+ leftvote_btn.click(
159
+ leftvote_last_response,
160
+ states + model_selectors,
161
+ [textbox] + btn_list + model_selectors
162
+ )
163
+ rightvote_btn.click(
164
+ rightvote_last_response,
165
+ states + model_selectors,
166
+ [textbox] + btn_list + model_selectors
167
+ )
168
+ tie_btn.click(
169
+ tievote_last_response,
170
+ states + model_selectors,
171
+ [textbox] + btn_list + model_selectors
172
+ )
173
+ bothbad_btn.click(
174
+ bothbad_vote_last_response,
175
+ states + model_selectors,
176
+ [textbox] + btn_list + model_selectors
177
+ )
178
+
179
+ sample_btn.click(
180
+ sample_prompt,
181
+ states + model_selectors + [prompt_list],
182
+ states + [textbox],
183
+ api_name="sample_btn_named"
184
+ )
185
+
186
+ textbox.then(
187
+ gen_func,
188
+ states + [textbox] + model_selectors + [prompt_list],
189
+ states + results + model_selectors,
190
+ api_name="submit_btn_named"
191
+ ).then(
192
+ enable_buttons_side_by_side,
193
+ None,
194
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
195
+ )
196
+
197
+ send_btn.click(
198
+ gen_func,
199
+ states + [textbox] + model_selectors + [prompt_list],
200
+ states + results + model_selectors,
201
+ api_name="send_btn_named"
202
+ ).then(
203
+ enable_buttons_side_by_side,
204
+ None,
205
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
206
+ )
207
+
208
+ clear_btn.click(
209
+ clear_history_side_by_side_anony,
210
+ None,
211
+ states + [textbox] + results + model_selectors,
212
+ api_name="clear_btn_named"
213
+ ).then(
214
+ disable_buttons_side_by_side,
215
+ None,
216
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
217
+ )
218
+
219
+ regenerate_btn.click(
220
+ gen_func,
221
+ states + [textbox] + model_selectors + [prompt_list],
222
+ states + results + model_selectors,
223
+ api_name="regenerate_btn_named"
224
+ ).then(
225
+ enable_buttons_side_by_side,
226
+ None,
227
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
228
+ )
229
+
230
+ share_btn.click(
231
+ share_click,
232
+ states + model_selectors,
233
+ [],
234
+ js=share_js
235
+ )
236
+
serve/gradio_web_t2i_single.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import partial
3
+
4
+ from .utils import *
5
+ from .vote_utils import (
6
+ upvote_last_response_t2s as upvote_last_response,
7
+ downvote_last_response_t2s as downvote_last_response,
8
+ flag_last_response_t2s as flag_last_response,
9
+ )
10
+ from .inference import(
11
+ sample_prompt,
12
+ generate_t2s
13
+ )
14
+ from .constants import TEXT_PROMPT_PATH
15
+
16
+ with open(TEXT_PROMPT_PATH, 'r') as f:
17
+ prompt_list = json.load(f)
18
+
19
+ def build_single_model_ui(models):
20
+ notice_markdown = """
21
+ # 🏔️ Play with Image Generation Models
22
+ {promotion}
23
+
24
+ ## 🤖 Choose any model to generate
25
+
26
+ """
27
+ model_list = models.get_t2s_models()
28
+ gen_func = partial(generate_t2s, models.inference_parallel, models.render_parallel)
29
+
30
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
31
+
32
+ with gr.Row(elem_id="model_selector_row"):
33
+ model_selector = gr.Dropdown(
34
+ choices=model_list,
35
+ value=model_list[0] if len(model_list) > 0 else "",
36
+ interactive=True,
37
+ show_label=False
38
+ )
39
+
40
+ with gr.Row():
41
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
42
+ model_description_md = get_model_description_md(model_list)
43
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
44
+
45
+ with gr.Row():
46
+ textbox = gr.Textbox(
47
+ show_label=False,
48
+ placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
49
+ container=True,
50
+ elem_id="input_box",
51
+ )
52
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
53
+ send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
54
+
55
+ with gr.Row():
56
+ normal = gr.Image(width=512, label = "Normal", show_copy_button=True)
57
+ rgb = gr.Image(width=512, label = "RGB", show_copy_button=True,)
58
+
59
+ with gr.Row():
60
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
61
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
62
+
63
+ with gr.Row(elem_id="Geometry Quality"):
64
+ geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
65
+ geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
66
+ geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
67
+
68
+ with gr.Row(elem_id="Texture Quality"):
69
+ text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
70
+ text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
71
+ text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
72
+
73
+ with gr.Row(elem_id="Alignment Quality"):
74
+ align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
75
+ align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
76
+ align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
77
+
78
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
79
+
80
+ state = gr.State()
81
+ geo_btn_list = [geo_upvote_btn, geo_downvote_btn, geo_flag_btn]
82
+ text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
83
+ align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
84
+
85
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
86
+ upvote_btn, downvote_btn, flag_btn = btn_list
87
+
88
+ upvote_btn.click(
89
+ upvote_last_response,
90
+ [state, model_selector],
91
+ [textbox] + btn_list
92
+ )
93
+
94
+ downvote_btn.click(
95
+ downvote_last_response,
96
+ [state, model_selector],
97
+ [textbox] + btn_list
98
+ )
99
+ flag_btn.click(
100
+ flag_last_response,
101
+ [state, model_selector],
102
+ [textbox] + btn_list
103
+ )
104
+
105
+ sample_btn.click(
106
+ sample_prompt,
107
+ [state, model_selector, prompt_list],
108
+ state + [textbox],
109
+ api_name="sample_btn_single"
110
+ )
111
+
112
+ textbox.submit(
113
+ gen_func,
114
+ [state, textbox, model_selector, prompt_list],
115
+ [state, normal, rgb],
116
+ api_name="submit_btn_single",
117
+ show_progress = "full"
118
+ ).then(
119
+ enable_buttons,
120
+ None,
121
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
122
+ )
123
+
124
+ send_btn.click(
125
+ gen_func,
126
+ [state, textbox, model_selector, prompt_list],
127
+ [state, normal, rgb],
128
+ api_name="send_btn_single",
129
+ show_progress = "full"
130
+ ).then(
131
+ enable_buttons,
132
+ None,
133
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
134
+ )
135
+
136
+ clear_btn.click(
137
+ clear_history,
138
+ None,
139
+ [state, textbox, normal, rgb],
140
+ api_name="clear_history_single",
141
+ show_progress="full"
142
+ ).then(
143
+ disable_buttons,
144
+ None,
145
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
146
+ )
147
+
148
+ regenerate_btn.click(
149
+ gen_func,
150
+ [state, textbox, model_selector, prompt_list],
151
+ [state, normal, rgb],
152
+ api_name="regenerate_btn_single",
153
+ show_progress = "full"
154
+ ).then(
155
+ enable_buttons,
156
+ None,
157
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
158
+ )
159
+
serve/gradio_web_t2s.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import partial
3
+
4
+ from .utils import *
5
+ from .vote_utils import (
6
+ upvote_last_response_t2s as upvote_last_response,
7
+ downvote_last_response_t2s as downvote_last_response,
8
+ flag_last_response_t2s as flag_last_response,
9
+ leftvote_last_response_t2s_anony as leftvote_last_response_anony,
10
+ rightvote_last_response_t2s_anony as rightvote_last_response_anony,
11
+ tievote_last_response_t2s_anony as tievote_last_response_anony,
12
+ bothbad_vote_last_response_t2s_anony as bothbad_vote_last_response_anony,
13
+ leftvote_last_response_t2s_named as leftvote_last_response_named,
14
+ rightvote_last_response_t2s_named as rightvote_last_response_named,
15
+ tievote_last_response_t2s_named as tievote_last_response_named,
16
+ bothbad_vote_last_response_t2s_named as bothbad_vote_last_response_named,
17
+ share_click_t2s_multi as share_click,
18
+ share_js
19
+ )
20
+ from .inference import(
21
+ sample_t2s_model as sample_model,
22
+ sample_prompt,
23
+ sample_prompt_side_by_side,
24
+ generate_t2s,
25
+ generate_t2s_multi,
26
+ generate_t2s_multi_annoy
27
+ )
28
+
29
+
30
+ def build_t2s_ui_side_by_side_anony(models):
31
+ notice_markdown = """
32
+ # ⚔️ GenAI-Arena ⚔️ : Benchmarking Text-to-3D generative models
33
+ ## 📜 Rules
34
+ - Input prompt to two anonymous models in same area and vote for the better one!
35
+ - When the results are ready, click the button below to vote.
36
+ - Vote won't be counted if model identity is revealed during conversation.
37
+ - Click "Clear" to start a new round.
38
+
39
+ ## 🏆 Arena Elo
40
+ Find out who is the 🥇conditional image generation models! More models are going to be supported.
41
+
42
+ ## 👇 Generating now!
43
+
44
+ """
45
+ model_list = models.get_t2s_models()
46
+ gen_func = partial(generate_t2s_multi_annoy, models.inference_parallel, models.render_parallel)
47
+
48
+
49
+ state_0 = gr.State()
50
+ state_1 = gr.State()
51
+
52
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
53
+ model_str = gr.Markdown(str(model_list), visible=False, elem_id="model list string")
54
+
55
+ with gr.Group(elem_id="share-region-anony"):
56
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
57
+ model_description_md = get_model_description_md(model_list)
58
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
59
+ with gr.Row():
60
+ with gr.Column():
61
+ normal_left = gr.Image(width=512, label = "Model A", show_download_button=True)
62
+ rgb_left = gr.Image(width=512, label = "Model A", show_download_button=True)
63
+ with gr.Column():
64
+ normal_right = gr.Image(width=512, label = "Model B", show_download_button=True,)
65
+ rgb_right = gr.Image(width=512, label = "Model B", show_download_button=True,)
66
+
67
+ with gr.Row():
68
+ with gr.Column():
69
+ model_selector_left =gr.Markdown("", visible=False)
70
+ with gr.Column():
71
+ model_selector_right = gr.Markdown("", visible=False)
72
+ with gr.Row():
73
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
74
+
75
+ with gr.Row(elem_id="Geometry Quality"):
76
+ geo_md = gr.Markdown("Geometry Quality: ", visible=False, elem_id="evaldim_markdown")
77
+ geo_leftvote_btn = gr.Button(
78
+ value="👈 A is better", visible=False, interactive=False
79
+ )
80
+ geo_rightvote_btn = gr.Button(
81
+ value="👉 B is better", visible=False, interactive=False
82
+ )
83
+ geo_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
84
+ geo_bothbad_btn = gr.Button(
85
+ value="👎 Both are bad", visible=False, interactive=False
86
+ )
87
+
88
+ with gr.Row(elem_id="Texture Quality"):
89
+ text_md = gr.Markdown("Texture Quality: ", visible=False, elem_id="evaldim_markdown")
90
+ text_leftvote_btn = gr.Button(
91
+ value="👈 A is better", visible=False, interactive=False
92
+ )
93
+ text_rightvote_btn = gr.Button(
94
+ value="👉 B is better", visible=False, interactive=False
95
+ )
96
+ text_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
97
+ text_bothbad_btn = gr.Button(
98
+ value="👎 Both are bad", visible=False, interactive=False
99
+ )
100
+
101
+ with gr.Row(elem_id="Alignment Quality"):
102
+ align_md = gr.Markdown("Alignment Quality: ", visible=False, elem_id="evaldim_markdown")
103
+ align_leftvote_btn = gr.Button(
104
+ value="👈 A is better", visible=False, interactive=False
105
+ )
106
+ align_rightvote_btn = gr.Button(
107
+ value="👉 B is better", visible=False, interactive=False
108
+ )
109
+ align_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
110
+ align_bothbad_btn = gr.Button(
111
+ value="👎 Both are bad", visible=False, interactive=False
112
+ )
113
+
114
+ with gr.Row():
115
+ textbox = gr.Textbox(
116
+ show_label=False,
117
+ placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
118
+ container=True,
119
+ elem_id="input_box",
120
+ )
121
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
122
+ send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
123
+
124
+ with gr.Row():
125
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
126
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
127
+ share_btn = gr.Button(value="📷 Share")
128
+
129
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
130
+
131
+ geo_btn_list = [geo_leftvote_btn, geo_rightvote_btn, geo_tie_btn, geo_bothbad_btn]
132
+ text_btn_list = [text_leftvote_btn, text_rightvote_btn, text_tie_btn, text_bothbad_btn]
133
+ align_btn_list = [align_leftvote_btn, align_rightvote_btn, align_tie_btn, align_bothbad_btn]
134
+ states = [state_0, state_1]
135
+ model_selectors = [model_selector_left, model_selector_right]
136
+ results = [normal_left, rgb_left, normal_right, rgb_right]
137
+
138
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
139
+ leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
140
+
141
+ leftvote_btn.click(
142
+ leftvote_last_response_anony,
143
+ states + model_selectors,
144
+ [textbox] + btn_list + model_selectors
145
+ )
146
+ rightvote_btn.click(
147
+ rightvote_last_response_anony,
148
+ states + model_selectors,
149
+ [textbox] + btn_list + model_selectors
150
+ )
151
+ tie_btn.click(
152
+ tievote_last_response_anony,
153
+ states + model_selectors,
154
+ [textbox] + btn_list + model_selectors
155
+ )
156
+ bothbad_btn.click(
157
+ bothbad_vote_last_response_anony,
158
+ states + model_selectors,
159
+ [textbox] + btn_list + model_selectors
160
+ )
161
+
162
+ sample_btn.click(
163
+ sample_prompt_side_by_side,
164
+ states + model_selectors,
165
+ states + [textbox],
166
+ api_name="sample_btn_anony"
167
+ )
168
+
169
+ textbox.submit(
170
+ sample_model,
171
+ states + [model_str],
172
+ states + model_selectors
173
+ ).then(
174
+ gen_func,
175
+ states + [textbox] + model_selectors,
176
+ states + results + model_selectors,
177
+ api_name="submit_btn_anony"
178
+ ).then(
179
+ enable_mds,
180
+ None,
181
+ [geo_md, text_md, align_md]
182
+ ).then(
183
+ enable_buttons_side_by_side,
184
+ None,
185
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
186
+ )
187
+
188
+ send_btn.click(
189
+ sample_model,
190
+ states + [model_str],
191
+ states + model_selectors
192
+ ).then(
193
+ gen_func,
194
+ states + [textbox] + model_selectors,
195
+ states + results + model_selectors,
196
+ api_name="send_btn_anony"
197
+ ).then(
198
+ enable_mds,
199
+ None,
200
+ [geo_md, text_md, align_md]
201
+ ).then(
202
+ enable_buttons_side_by_side,
203
+ None,
204
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
205
+ )
206
+
207
+ clear_btn.click(
208
+ clear_t2s_history_side_by_side_anony,
209
+ None,
210
+ states + [textbox] + results + model_selectors,
211
+ api_name="clear_btn_anony"
212
+ ).then(
213
+ disable_mds,
214
+ None,
215
+ [geo_md, text_md, align_md]
216
+ ).then(
217
+ disable_buttons_side_by_side,
218
+ None,
219
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
220
+ )
221
+
222
+ regenerate_btn.click(
223
+ sample_model,
224
+ states + [model_str],
225
+ states + model_selectors
226
+ ).then(
227
+ gen_func,
228
+ states + [textbox] + model_selectors,
229
+ states + results + model_selectors,
230
+ api_name="regenerate_btn_anony"
231
+ ).then(
232
+ enable_mds,
233
+ None,
234
+ [geo_md, text_md, align_md]
235
+ ).then(
236
+ enable_buttons_side_by_side,
237
+ None,
238
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
239
+ )
240
+
241
+ share_btn.click(
242
+ share_click,
243
+ states + model_selectors,
244
+ [],
245
+ js=share_js
246
+ )
247
+
248
+
249
+ def build_t2s_ui_side_by_side_named(models):
250
+ notice_markdown = """
251
+ # ⚔️ GenAI-Arena ⚔️ : Benchmarking Text-to-3D generative models
252
+ ## 📜 Rules
253
+ - Generate with any two selected models side-by-side and vote!
254
+ - Sample or Input prompt you want to generate.
255
+ - Click "Send" to submit the prompt.
256
+ - Click "Clear" to start a new round.
257
+
258
+ ## 🏆 Arena Elo
259
+ Find out who is the 🥇conditional image generation models! More models are going to be supported.
260
+
261
+ ## 👇 Generating now!
262
+
263
+ """
264
+ model_list = models.get_t2s_models()
265
+ gen_func = partial(generate_t2s_multi, models.inference_parallel, models.render_parallel)
266
+
267
+ state_0 = gr.State()
268
+ state_1 = gr.State()
269
+
270
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
271
+
272
+ with gr.Group(elem_id="share-region-named"):
273
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
274
+ model_description_md = get_model_description_md(model_list)
275
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
276
+
277
+ with gr.Row():
278
+ with gr.Column():
279
+ model_selector_left = gr.Dropdown(
280
+ choices=model_list,
281
+ value=model_list[0] if len(model_list) > 0 else "",
282
+ interactive=True,
283
+ show_label=False,
284
+ container=False,
285
+ )
286
+ with gr.Column():
287
+ model_selector_right = gr.Dropdown(
288
+ choices=model_list,
289
+ value=model_list[1] if len(model_list) > 1 else "",
290
+ interactive=True,
291
+ show_label=False,
292
+ container=False,
293
+ )
294
+
295
+ with gr.Row():
296
+ with gr.Column():
297
+ normal_left = gr.Image(width=512, label = "Model A", show_download_button=True)
298
+ rgb_left = gr.Image(width=512, label = "Model A", show_download_button=True)
299
+ with gr.Column():
300
+ normal_right = gr.Image(width=512, label = "Model B", show_download_button=True,)
301
+ rgb_right = gr.Image(width=512, label = "Model B", show_download_button=True,)
302
+
303
+ with gr.Row():
304
+ slow_warning = gr.Markdown("", elem_id="notice_markdown")
305
+
306
+ with gr.Row(elem_id="Geometry Quality"):
307
+ geo_md = gr.Markdown("Geometry Quality: ", visible=False, elem_id="evaldim_markdown")
308
+ geo_leftvote_btn = gr.Button(
309
+ value="👈 A is better", visible=False, interactive=False
310
+ )
311
+ geo_rightvote_btn = gr.Button(
312
+ value="👉 B is better", visible=False, interactive=False
313
+ )
314
+ geo_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
315
+ geo_bothbad_btn = gr.Button(
316
+ value="👎 Both are bad", visible=False, interactive=False
317
+ )
318
+
319
+ with gr.Row(elem_id="Texture Quality"):
320
+ text_md = gr.Markdown("Texture Quality: ", visible=False, elem_id="evaldim_markdown")
321
+ text_leftvote_btn = gr.Button(
322
+ value="👈 A is better", visible=False, interactive=False
323
+ )
324
+ text_rightvote_btn = gr.Button(
325
+ value="👉 B is better", visible=False, interactive=False
326
+ )
327
+ text_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
328
+ text_bothbad_btn = gr.Button(
329
+ value="👎 Both are bad", visible=False, interactive=False
330
+ )
331
+
332
+ with gr.Row(elem_id="Alignment Quality"):
333
+ align_md = gr.Markdown("Alignment Quality: ", visible=False, elem_id="evaldim_markdown")
334
+ align_leftvote_btn = gr.Button(
335
+ value="👈 A is better", visible=False, interactive=False
336
+ )
337
+ align_rightvote_btn = gr.Button(
338
+ value="👉 B is better", visible=False, interactive=False
339
+ )
340
+ align_tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
341
+ align_bothbad_btn = gr.Button(
342
+ value="👎 Both are bad", visible=False, interactive=False
343
+ )
344
+
345
+ with gr.Row():
346
+ textbox = gr.Textbox(
347
+ show_label=False,
348
+ placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
349
+ container=True,
350
+ elem_id="input_box",
351
+ )
352
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
353
+ send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
354
+
355
+ with gr.Row():
356
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
357
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
358
+ share_btn = gr.Button(value="📷 Share")
359
+
360
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
361
+
362
+ geo_btn_list = [geo_leftvote_btn, geo_rightvote_btn, geo_tie_btn, geo_bothbad_btn]
363
+ text_btn_list = [text_leftvote_btn, text_rightvote_btn, text_tie_btn, text_bothbad_btn]
364
+ align_btn_list = [align_leftvote_btn, align_rightvote_btn, align_tie_btn, align_bothbad_btn]
365
+ states = [state_0, state_1]
366
+ model_selectors = [model_selector_left, model_selector_right]
367
+ results = [normal_left, rgb_left, normal_right, rgb_right]
368
+
369
+ model_selector_left.change(
370
+ clear_t2s_history_side_by_side,
371
+ None,
372
+ states + [textbox] + results,
373
+ api_name="model_selector_left"
374
+ )
375
+ model_selector_right.change(
376
+ clear_t2s_history_side_by_side,
377
+ None,
378
+ states + [textbox] + results,
379
+ api_name="model_selector_right"
380
+ )
381
+
382
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
383
+ leftvote_btn, rightvote_btn, tie_btn, bothbad_btn = btn_list
384
+
385
+ leftvote_btn.click(
386
+ leftvote_last_response_named,
387
+ states + model_selectors,
388
+ [textbox] + btn_list
389
+ )
390
+ rightvote_btn.click(
391
+ rightvote_last_response_named,
392
+ states + model_selectors,
393
+ [textbox] + btn_list
394
+ )
395
+ tie_btn.click(
396
+ tievote_last_response_named,
397
+ states + model_selectors,
398
+ [textbox] + btn_list
399
+ )
400
+ bothbad_btn.click(
401
+ bothbad_vote_last_response_named,
402
+ states + model_selectors,
403
+ [textbox] + btn_list
404
+ )
405
+
406
+ sample_btn.click(
407
+ sample_prompt_side_by_side,
408
+ states + model_selectors,
409
+ states + [textbox],
410
+ api_name="sample_btn_named"
411
+ )
412
+
413
+ textbox.submit(
414
+ gen_func,
415
+ states + [textbox] + model_selectors,
416
+ states + results + model_selectors,
417
+ api_name="submit_btn_named"
418
+ ).then(
419
+ enable_mds,
420
+ None,
421
+ [geo_md, text_md, align_md]
422
+ ).then(
423
+ enable_buttons_side_by_side,
424
+ None,
425
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
426
+ )
427
+
428
+ send_btn.click(
429
+ gen_func,
430
+ states + [textbox] + model_selectors,
431
+ states + results + model_selectors,
432
+ api_name="send_btn_named"
433
+ ).then(
434
+ enable_mds,
435
+ None,
436
+ [geo_md, text_md, align_md]
437
+ ).then(
438
+ enable_buttons_side_by_side,
439
+ None,
440
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
441
+ )
442
+
443
+ clear_btn.click(
444
+ clear_t2s_history_side_by_side,
445
+ None,
446
+ states + [textbox] + results,
447
+ api_name="clear_btn_named"
448
+ ).then(
449
+ disable_mds,
450
+ None,
451
+ [geo_md, text_md, align_md]
452
+ ).then(
453
+ disable_buttons_side_by_side,
454
+ None,
455
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
456
+ )
457
+
458
+ regenerate_btn.click(
459
+ gen_func,
460
+ states + [textbox] + model_selectors,
461
+ states + results + model_selectors,
462
+ api_name="regenerate_btn_named"
463
+ ).then(
464
+ enable_mds,
465
+ None,
466
+ [geo_md, text_md, align_md]
467
+ ).then(
468
+ enable_buttons_side_by_side,
469
+ None,
470
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
471
+ )
472
+
473
+ share_btn.click(
474
+ share_click,
475
+ states + model_selectors,
476
+ [],
477
+ js=share_js
478
+ )
479
+
480
+
481
+ def build_t2s_ui_single_model(models):
482
+ notice_markdown = """
483
+ # 🏔️ Play with Image Generation Models
484
+ {promotion}
485
+
486
+ ## 🤖 Choose any model to generate
487
+
488
+ """
489
+ model_list = models.get_t2s_models()
490
+ gen_func = partial(generate_t2s, models.inference_parallel, models.render_parallel)
491
+
492
+ gr.Markdown(notice_markdown, elem_id="notice_markdown")
493
+
494
+ with gr.Row():
495
+ with gr.Accordion("🔍 Expand to see all Arena players", open=False):
496
+ model_description_md = get_model_description_md(model_list)
497
+ gr.Markdown(model_description_md, elem_id="model_description_markdown")
498
+
499
+ with gr.Row(elem_id="model_selector_row"):
500
+ model_selector = gr.Dropdown(
501
+ choices=model_list,
502
+ value=model_list[0] if len(model_list) > 0 else "",
503
+ interactive=True,
504
+ show_label=False
505
+ )
506
+
507
+ with gr.Row():
508
+ normal = gr.Image(width=512, label = "Normal", show_download_button=True)
509
+ rgb = gr.Image(width=512, label = "RGB", show_download_button=True,)
510
+
511
+ with gr.Row():
512
+ textbox = gr.Textbox(
513
+ show_label=False,
514
+ placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
515
+ container=True,
516
+ elem_id="input_box",
517
+ )
518
+ sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
519
+ send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)
520
+
521
+ with gr.Row():
522
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
523
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
524
+
525
+
526
+ with gr.Row(elem_id="Geometry Quality"):
527
+ gr.Markdown("Geometry Quality: ", elem_id="evaldim_markdown")
528
+ geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
529
+ geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
530
+ geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
531
+
532
+
533
+ with gr.Row(elem_id="Texture Quality"):
534
+ gr.Markdown("Texture Quality: ", elem_id="evaldim_markdown")
535
+ text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
536
+ text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
537
+ text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
538
+
539
+ with gr.Row(elem_id="Alignment Quality"):
540
+ gr.Markdown("Alignment Quality: ", elem_id="evaldim_markdown")
541
+ align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
542
+ align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
543
+ align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
544
+
545
+ gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
546
+
547
+ state = gr.State()
548
+ geo_btn_list = [geo_upvote_btn, geo_downvote_btn, geo_flag_btn]
549
+ text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
550
+ align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
551
+
552
+ for btn_list in [geo_btn_list, text_btn_list, align_btn_list]:
553
+ upvote_btn, downvote_btn, flag_btn = btn_list
554
+
555
+ upvote_btn.click(
556
+ upvote_last_response,
557
+ [state, model_selector],
558
+ [textbox] + btn_list
559
+ )
560
+
561
+ downvote_btn.click(
562
+ downvote_last_response,
563
+ [state, model_selector],
564
+ [textbox] + btn_list
565
+ )
566
+ flag_btn.click(
567
+ flag_last_response,
568
+ [state, model_selector],
569
+ [textbox] + btn_list
570
+ )
571
+
572
+ sample_btn.click(
573
+ sample_prompt,
574
+ [state, model_selector],
575
+ [state, textbox],
576
+ api_name="sample_btn_single"
577
+ )
578
+
579
+ textbox.submit(
580
+ gen_func,
581
+ [state, textbox, model_selector],
582
+ [state, normal, rgb],
583
+ api_name="submit_btn_single",
584
+ show_progress = "full"
585
+ ).then(
586
+ enable_buttons,
587
+ None,
588
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
589
+ )
590
+
591
+ send_btn.click(
592
+ gen_func,
593
+ [state, textbox, model_selector],
594
+ [state, normal, rgb],
595
+ api_name="send_btn_single",
596
+ show_progress = "full"
597
+ ).then(
598
+ enable_buttons,
599
+ None,
600
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
601
+ )
602
+
603
+ clear_btn.click(
604
+ clear_t2s_history,
605
+ None,
606
+ [state, textbox, normal, rgb],
607
+ api_name="clear_history_single",
608
+ show_progress="full"
609
+ ).then(
610
+ disable_buttons,
611
+ None,
612
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
613
+ )
614
+
615
+ regenerate_btn.click(
616
+ gen_func,
617
+ [state, textbox, model_selector],
618
+ [state, normal, rgb],
619
+ api_name="regenerate_btn_single",
620
+ show_progress = "full"
621
+ ).then(
622
+ enable_buttons,
623
+ None,
624
+ geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]
625
+ )
626
+
serve/inference.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## All Generation Gradio Interface
2
+ import uuid
3
+ import time
4
+
5
+ from .utils import *
6
+ from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger
7
+ from .constants import IMAGE_DIR, OFFLINE_DIR, TEXT_PROMPT_PATH
8
+
9
+ with open(TEXT_PROMPT_PATH, 'r') as f:
10
+ prompt_list = json.load(f)
11
+
12
+
13
+ class State:
14
+ def __init__(self,
15
+ model_name, i2s_mode=False, offline=False,
16
+ prompt=None, image=None, offline_idx=None,
17
+ normal_video=None , rgb_video=None):
18
+ self.conv_id = uuid.uuid4().hex
19
+ self.model_name = model_name
20
+ self.i2s_mode = i2s_mode
21
+ self.offline = offline
22
+
23
+ self.prompt = prompt
24
+ self.image = image
25
+ self.offline_idx = offline_idx
26
+ # self.output = None
27
+ self.normal_video = normal_video
28
+ self.rgb_video = rgb_video
29
+
30
+ def dict(self):
31
+ base = {
32
+ "conv_id": self.conv_id,
33
+ "model_name": self.model_name,
34
+ "i2s_mode": self.i2s_mode,
35
+ "offline": self.offline,
36
+ "prompt": self.prompt
37
+ }
38
+ if not self.offline and not self.offline_idx:
39
+ base['offline_idx'] = self.offline_idx
40
+ return base
41
+
42
+ # class StateI2S:
43
+ # def __init__(self, model_name):
44
+ # self.conv_id = uuid.uuid4().hex
45
+ # self.model_name = model_name
46
+ # self.image = None
47
+ # self.output = None
48
+
49
+ # def dict(self):
50
+ # base = {
51
+ # "conv_id": self.conv_id,
52
+ # "model_name": self.model_name,
53
+
54
+ # }
55
+ # return base
56
+
57
+ def sample_t2s_model(state_0, state_1, model_list):
58
+ model_name_0, model_name_1 = random.sample(eval(model_list), 2)
59
+
60
+ if state_0 is None:
61
+ state_0 = State(model_name_0, i2s_mode=False)
62
+ if state_1 is None:
63
+ state_1 = State(model_name_1, i2s_mode=False)
64
+
65
+ state_0.model_name = model_name_0
66
+ state_0.i2s_mode = False
67
+ state_1.model_name = model_name_1
68
+ state_1.i2s_mode = False
69
+ return state_0, state_1, model_name_0, model_name_1
70
+
71
+ def sample_i2s_model(state_0, state_1, model_list):
72
+ model_name_0, model_name_1 = random.sample(eval(model_list), 2)
73
+
74
+ if state_0 is None:
75
+ state_0 = State(model_name_0, i2s_mode=True)
76
+ if state_1 is None:
77
+ state_1 = State(model_name_1, i2s_mode=True)
78
+
79
+ state_0.model_name = model_name_0
80
+ state_0.i2s_mode = True
81
+ state_1.model_name = model_name_1
82
+ state_1.i2s_mode = True
83
+ return state_0, state_1, model_name_0, model_name_1
84
+
85
+ def sample_prompt(state, model_name):
86
+ if state is None:
87
+ state = State(model_name)
88
+
89
+ idx = random.randint(0, len(prompt_list)-1)
90
+ prompt = prompt_list[idx]
91
+
92
+ state.model_name = model_name
93
+ state.prompt = prompt
94
+ return state, prompt
95
+
96
+ def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1):
97
+ if state_0 is None:
98
+ state_0 = State(model_name_0)
99
+ if state_1 is None:
100
+ state_1 = State(model_name_1)
101
+
102
+ idx = random.randint(0, len(prompt_list)-1)
103
+ prompt = prompt_list[idx]
104
+
105
+ state_0.offline, state_1.offline = True, True
106
+ state_0.offline_idx, state_1.offline_idx = idx, idx
107
+ state_0.prompt, state_1.prompt = prompt, prompt
108
+ return state_0, state_1, prompt
109
+
110
+ def sample_image(state, model_name):
111
+ if state is None:
112
+ state = State(model_name)
113
+
114
+ idx = random.randint(0, len(prompt_list)-1)
115
+ prompt = prompt_list[idx]
116
+
117
+ state.model_name = model_name
118
+ state.prompt = prompt
119
+ return state, prompt
120
+
121
+ def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
122
+ if state_0 is None:
123
+ state_0 = State(model_name_0)
124
+ if state_1 is None:
125
+ state_1 = State(model_name_1)
126
+
127
+ idx = random.randint(0, len(prompt_list)-1)
128
+ prompt = prompt_list[idx]
129
+
130
+ state_0.offline, state_1.offline = True, True
131
+ state_0.offline_idx, state_1.offline_idx = idx, idx
132
+ state_0.prompt, state_1.prompt = prompt, prompt
133
+ return state_0, state_1, prompt
134
+
135
+ def generate_t2s(gen_func, render_func,
136
+ state,
137
+ text,
138
+ model_name,
139
+ request: gr.Request):
140
+ if not text:
141
+ raise gr.Warning("Prompt cannot be empty.")
142
+ if not model_name:
143
+ raise gr.Warning("Model name cannot be empty.")
144
+
145
+ if state is None:
146
+ state = State(model_name, i2s_mode=False, offline=False)
147
+
148
+ ip = get_ip(request)
149
+ t2s_logger.info(f"generate. ip: {ip}")
150
+
151
+ state.model_name = model_name
152
+ state.prompt = text
153
+ try:
154
+ idx = prompt_list.index(text)
155
+ state.offline = True
156
+ state.offline_idx = idx
157
+ except:
158
+ state.offline = False
159
+ state.offline_idx = None
160
+
161
+ if not state.offline and not state.offline_idx:
162
+ start_time = time.time()
163
+ normal_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "normal", f"{state.offline_idx}.mp4")
164
+ rgb_video = os.path.join(OFFLINE_DIR, "text2shape", model_name, "rgb", f"{state.offline_idx}.mp4")
165
+
166
+ state.normal_video = normal_video
167
+ state.rgb_video = rgb_video
168
+ yield state, normal_video, rgb_video
169
+
170
+ # logger.info(f"===output===: {output}")
171
+ data = {
172
+ "ip": ip,
173
+ "model": model_name,
174
+ "type": "offline",
175
+ "gen_params": {},
176
+ "state": state.dict(),
177
+ "start": round(start_time, 4),
178
+ }
179
+ else:
180
+ start_time = time.time()
181
+ shape = gen_func(text, model_name)
182
+ generate_time = time.time() - start_time
183
+
184
+ normal_video, rgb_video = render_func(shape, model_name)
185
+ finish_time = time.time()
186
+ render_time = finish_time - start_time - generate_time
187
+
188
+ state.normal_video = normal_video
189
+ state.rgb_video = rgb_video
190
+ yield state, normal_video, rgb_video
191
+
192
+ # logger.info(f"===output===: {output}")
193
+ data = {
194
+ "ip": ip,
195
+ "model": model_name,
196
+ "type": "online",
197
+ "gen_params": {},
198
+ "state": state.dict(),
199
+ "start": round(start_time, 4),
200
+ "time": round(finish_time - start_time, 4),
201
+ "generate_time": round(generate_time, 4),
202
+ "render_time": round(render_time, 4),
203
+ }
204
+
205
+ with open(get_conv_log_filename(), "a") as fout:
206
+ fout.write(json.dumps(data) + "\n")
207
+ append_json_item_on_log_server(data, get_conv_log_filename())
208
+
209
+ # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png'
210
+ # os.makedirs(os.path.dirname(output_file), exist_ok=True)
211
+ # with open(output_file, 'w') as f:
212
+ # state.output.save(f, 'PNG')
213
+ # save_image_file_on_log_server(output_file)
214
+
215
+ def generate_t2s_multi(gen_func, render_func,
216
+ state_0, state_1,
217
+ text,
218
+ model_name_0, model_name_1,
219
+ request: gr.Request):
220
+ if not text:
221
+ raise gr.Warning("Prompt cannot be empty.")
222
+ if not model_name_0:
223
+ raise gr.Warning("Model name A cannot be empty.")
224
+ if not model_name_1:
225
+ raise gr.Warning("Model name B cannot be empty.")
226
+
227
+ if state_0 is None:
228
+ state_0 = State(model_name_0, i2s_mode=False, offline=False)
229
+ if state_1 is None:
230
+ state_1 = State(model_name_1, i2s_mode=False, offline=False)
231
+
232
+ ip = get_ip(request)
233
+ t2s_multi_logger.info(f"generate. ip: {ip}")
234
+
235
+ state_0.model_name, state_1.model_name = model_name_0, model_name_1
236
+ state_0.prompt, state_1.prompt = text, text
237
+ try:
238
+ idx = prompt_list.index(text)
239
+ state_0.offline, state_1.offline = True, True
240
+ state_0.offline_idx, state_1.offline_idx = idx, idx
241
+ except:
242
+ state_0.offline, state_1.offline = False, False
243
+ state_0.offline_idx, state_1.offline_idx = None, None
244
+
245
+ if not state_0.offline and not state_0.offline_idx:
246
+ start_time = time.time()
247
+ normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
248
+ rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
249
+ normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
250
+ rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
251
+
252
+ state_0.normal_video = normal_video_0
253
+ state_0.rgb_video = rgb_video_0
254
+ state_1.normal_video = normal_video_1
255
+ state_1.rgb_video = rgb_video_1
256
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1
257
+
258
+ # logger.info(f"===output===: {output}")
259
+ data_0 = {
260
+ "ip": get_ip(request),
261
+ "model": model_name_0,
262
+ "type": "offline",
263
+ "gen_params": {},
264
+ "state": state_0.dict(),
265
+ "start": round(start_time, 4),
266
+ }
267
+ data_1 = {
268
+ "ip": get_ip(request),
269
+ "model": model_name_1,
270
+ "type": "offline",
271
+ "gen_params": {},
272
+ "state": state_1.dict(),
273
+ "start": round(start_time, 4),
274
+ }
275
+ else:
276
+ start_time = time.time()
277
+ shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
278
+ generate_time = time.time() - start_time
279
+
280
+ normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
281
+ shape_1, model_name_1)
282
+ finish_time = time.time()
283
+ render_time = finish_time - start_time - generate_time
284
+
285
+ state_0.normal_video = normal_video_0
286
+ state_0.rgb_video = rgb_video_0
287
+ state_1.normal_video = normal_video_1
288
+ state_1.rgb_video = rgb_video_1
289
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1
290
+
291
+ # logger.info(f"===output===: {output}")
292
+ data_0 = {
293
+ "ip": get_ip(request),
294
+ "model": model_name_0,
295
+ "type": "online",
296
+ "gen_params": {},
297
+ "state": state_0.dict(),
298
+ "start": round(start_time, 4),
299
+ "time": round(finish_time - start_time, 4),
300
+ "generate_time": round(generate_time, 4),
301
+ "render_time": round(render_time, 4),
302
+ }
303
+ data_1 = {
304
+ "ip": get_ip(request),
305
+ "model": model_name_1,
306
+ "type": "online",
307
+ "gen_params": {},
308
+ "state": state_1.dict(),
309
+ "start": round(start_time, 4),
310
+ "time": round(finish_time - start_time, 4),
311
+ "generate_time": round(generate_time, 4),
312
+ "render_time": round(render_time, 4),
313
+ }
314
+
315
+ with open(get_conv_log_filename(), "a") as fout:
316
+ fout.write(json.dumps(data_0) + "\n")
317
+ fout.write(json.dumps(data_1) + "\n")
318
+ append_json_item_on_log_server(data_0, get_conv_log_filename())
319
+ append_json_item_on_log_server(data_1, get_conv_log_filename())
320
+
321
+ # for i, state in enumerate([state_0, state_1]):
322
+ # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png'
323
+ # os.makedirs(os.path.dirname(output_file), exist_ok=True)
324
+ # with open(output_file, 'w') as f:
325
+ # state.output.save(f, 'PNG')
326
+ # save_image_file_on_log_server(output_file)
327
+
328
+ def generate_t2s_multi_annoy(gen_func, render_func,
329
+ state_0, state_1,
330
+ text,
331
+ model_name_0, model_name_1,
332
+ request: gr.Request):
333
+ if not text:
334
+ raise gr.Warning("Prompt cannot be empty.")
335
+ if state_0 is None:
336
+ state_0 = State(model_name_0, i2s_mode=False, offline=False)
337
+ if state_1 is None:
338
+ state_1 = State(model_name_1, i2s_mode=False, offline=False)
339
+
340
+ ip = get_ip(request)
341
+ t2s_multi_logger.info(f"generate. ip: {ip}")
342
+
343
+ state_0.model_name, state_1.model_name = model_name_0, model_name_1
344
+ state_0.prompt, state_1.prompt = text, text
345
+ try:
346
+ idx = prompt_list.index(text)
347
+ state_0.offline, state_1.offline = True, True
348
+ state_0.offline_idx, state_1.offline_idx = idx, idx
349
+ except:
350
+ state_0.offline, state_1.offline = False, False
351
+ state_0.offline_idx, state_1.offline_idx = None, None
352
+
353
+ if not state_0.offline and not state_0.offline_idx:
354
+ start_time = time.time()
355
+ normal_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
356
+ rgb_video_0 = os.path.join(OFFLINE_DIR, "text2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
357
+ normal_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
358
+ rgb_video_1 = os.path.join(OFFLINE_DIR, "text2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
359
+
360
+ state_0.normal_video = normal_video_0
361
+ state_0.rgb_video = rgb_video_0
362
+ state_1.normal_video = normal_video_1
363
+ state_1.rgb_video = rgb_video_1
364
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_1, rgb_video_1, \
365
+ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
366
+
367
+ # logger.info(f"===output===: {output}")
368
+ data_0 = {
369
+ "ip": get_ip(request),
370
+ "model": model_name_0,
371
+ "type": "offline",
372
+ "gen_params": {},
373
+ "state": state_0.dict(),
374
+ "start": round(start_time, 4),
375
+ }
376
+ data_1 = {
377
+ "ip": get_ip(request),
378
+ "model": model_name_1,
379
+ "type": "offline",
380
+ "gen_params": {},
381
+ "state": state_1.dict(),
382
+ "start": round(start_time, 4),
383
+ }
384
+ else:
385
+ start_time = time.time()
386
+ shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
387
+ generate_time = time.time() - start_time
388
+
389
+ normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
390
+ shape_1, model_name_1)
391
+ finish_time = time.time()
392
+ render_time = finish_time - start_time - generate_time
393
+
394
+ state_0.normal_video = normal_video_0
395
+ state_0.rgb_video = rgb_video_0
396
+ state_1.normal_video = normal_video_1
397
+ state_1.rgb_video = rgb_video_1
398
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
399
+ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
400
+
401
+ # logger.info(f"===output===: {output}")
402
+ data_0 = {
403
+ "ip": get_ip(request),
404
+ "model": model_name_0,
405
+ "type": "online",
406
+ "gen_params": {},
407
+ "state": state_0.dict(),
408
+ "start": round(start_time, 4),
409
+ "time": round(finish_time - start_time, 4),
410
+ "generate_time": round(generate_time, 4),
411
+ "render_time": round(render_time, 4),
412
+ }
413
+ data_1 = {
414
+ "ip": get_ip(request),
415
+ "model": model_name_1,
416
+ "type": "online",
417
+ "gen_params": {},
418
+ "state": state_1.dict(),
419
+ "start": round(start_time, 4),
420
+ "time": round(finish_time - start_time, 4),
421
+ "generate_time": round(generate_time, 4),
422
+ "render_time": round(render_time, 4),
423
+ }
424
+
425
+ with open(get_conv_log_filename(), "a") as fout:
426
+ fout.write(json.dumps(data_0) + "\n")
427
+ fout.write(json.dumps(data_1) + "\n")
428
+ append_json_item_on_log_server(data_0, get_conv_log_filename())
429
+ append_json_item_on_log_server(data_1, get_conv_log_filename())
430
+
431
+ # for i, state in enumerate([state_0, state_1]):
432
+ # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png'
433
+ # os.makedirs(os.path.dirname(output_file), exist_ok=True)
434
+ # with open(output_file, 'w') as f:
435
+ # state.output.save(f, 'PNG')
436
+ # save_image_file_on_log_server(output_file)
437
+
438
+
439
+ def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request):
440
+ if not image:
441
+ raise gr.Warning("Image cannot be empty.")
442
+ if not model_name:
443
+ raise gr.Warning("Model name cannot be empty.")
444
+ if state is None:
445
+ state = State(model_name, i2s_mode=True, offline=False)
446
+
447
+ ip = get_ip(request)
448
+ t2s_logger.info(f"generate. ip: {ip}")
449
+
450
+ state.model_name = model_name
451
+ state.image = image
452
+
453
+ if not state.offline and not state.offline_idx:
454
+ start_time = time.time()
455
+ normal_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "normal", f"{state.offline_idx}.mp4")
456
+ rgb_video = os.path.join(OFFLINE_DIR, "image2shape", model_name, "rgb", f"{state.offline_idx}.mp4")
457
+
458
+ state.normal_video = normal_video
459
+ state.rgb_video = rgb_video
460
+ yield state, normal_video, rgb_video
461
+
462
+ # logger.info(f"===output===: {output}")
463
+ data = {
464
+ "ip": ip,
465
+ "model": model_name,
466
+ "type": "offline",
467
+ "gen_params": {},
468
+ "state": state.dict(),
469
+ "start": round(start_time, 4),
470
+ }
471
+ else:
472
+ start_time = time.time()
473
+ shape = gen_func(image, model_name)
474
+ generate_time = time.time() - start_time
475
+
476
+ normal_video, rgb_video = render_func(shape, model_name)
477
+ finish_time = time.time()
478
+ render_time = finish_time - start_time - generate_time
479
+
480
+ state.normal_video = normal_video
481
+ state.rgb_video = rgb_video
482
+ yield state, normal_video, rgb_video
483
+
484
+ # logger.info(f"===output===: {output}")
485
+ data = {
486
+ "ip": ip,
487
+ "model": model_name,
488
+ "type": "online",
489
+ "gen_params": {},
490
+ "state": state.dict(),
491
+ "start": round(start_time, 4),
492
+ "time": round(finish_time - start_time, 4),
493
+ "generate_time": round(generate_time, 4),
494
+ "render_time": round(render_time, 4),
495
+ }
496
+
497
+ with open(get_conv_log_filename(), "a") as fout:
498
+ fout.write(json.dumps(data) + "\n")
499
+ append_json_item_on_log_server(data, get_conv_log_filename())
500
+
501
+ # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png'
502
+ # os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
503
+ # with open(src_img_file, 'w') as f:
504
+ # state.source_image.save(f, 'PNG')
505
+ # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png'
506
+ # with open(output_file, 'w') as f:
507
+ # state.output.save(f, 'PNG')
508
+ # save_image_file_on_log_server(src_img_file)
509
+ # save_image_file_on_log_server(output_file)
510
+
511
+ def generate_i2s_multi(gen_func, render_func,
512
+ state_0, state_1,
513
+ image,
514
+ model_name_0, model_name_1,
515
+ request: gr.Request):
516
+ if not image:
517
+ raise gr.Warning("Image cannot be empty.")
518
+ if not model_name_0:
519
+ raise gr.Warning("Model name A cannot be empty.")
520
+ if not model_name_1:
521
+ raise gr.Warning("Model name B cannot be empty.")
522
+
523
+ if state_0 is None:
524
+ state_0 = State(model_name_0, i2s_mode=True, offline=False)
525
+ if state_1 is None:
526
+ state_1 = State(model_name_1, i2s_mode=True, offline=False)
527
+
528
+ ip = get_ip(request)
529
+ t2s_multi_logger.info(f"generate. ip: {ip}")
530
+
531
+ state_0.model_name, state_1.model_name = model_name_0, model_name_1
532
+ state_0.image, state_1.image = image, image
533
+
534
+ if not state_0.offline and not state_0.offline_idx and \
535
+ not state_1.offline and not state_1.offline_idx:
536
+ start_time = time.time()
537
+ normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
538
+ rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
539
+ normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
540
+ rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
541
+
542
+ state_0.normal_video = normal_video_0
543
+ state_0.rgb_video = rgb_video_0
544
+ state_1.normal_video = normal_video_1
545
+ state_1.rgb_video = rgb_video_1
546
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
547
+ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
548
+
549
+ # logger.info(f"===output===: {output}")
550
+ data_0 = {
551
+ "ip": get_ip(request),
552
+ "model": model_name_0,
553
+ "type": "offline",
554
+ "gen_params": {},
555
+ "state": state_0.dict(),
556
+ "start": round(start_time, 4),
557
+ }
558
+ data_1 = {
559
+ "ip": get_ip(request),
560
+ "model": model_name_1,
561
+ "type": "offline",
562
+ "gen_params": {},
563
+ "state": state_1.dict(),
564
+ "start": round(start_time, 4),
565
+ }
566
+ else:
567
+ start_time = time.time()
568
+ shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
569
+ generate_time = time.time() - start_time
570
+
571
+ normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
572
+ shape_1, model_name_1)
573
+ finish_time = time.time()
574
+ render_time = finish_time - start_time - generate_time
575
+
576
+ state_0.normal_video = normal_video_0
577
+ state_0.rgb_video = rgb_video_0
578
+ state_1.normal_video = normal_video_1
579
+ state_1.rgb_video = rgb_video_1
580
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1
581
+
582
+ # logger.info(f"===output===: {output}")
583
+ data_0 = {
584
+ "ip": get_ip(request),
585
+ "model": model_name_0,
586
+ "type": "online",
587
+ "gen_params": {},
588
+ "state": state_0.dict(),
589
+ "start": round(start_time, 4),
590
+ "time": round(finish_time - start_time, 4),
591
+ "generate_time": round(generate_time, 4),
592
+ "render_time": round(render_time, 4),
593
+ }
594
+ data_1 = {
595
+ "ip": get_ip(request),
596
+ "model": model_name_1,
597
+ "type": "online",
598
+ "gen_params": {},
599
+ "state": state_1.dict(),
600
+ "start": round(start_time, 4),
601
+ "time": round(finish_time - start_time, 4),
602
+ "generate_time": round(generate_time, 4),
603
+ "render_time": round(render_time, 4),
604
+ }
605
+
606
+ with open(get_conv_log_filename(), "a") as fout:
607
+ fout.write(json.dumps(data_0) + "\n")
608
+ fout.write(json.dumps(data_1) + "\n")
609
+ append_json_item_on_log_server(data_0, get_conv_log_filename())
610
+ append_json_item_on_log_server(data_1, get_conv_log_filename())
611
+
612
+ # for i, state in enumerate([state_0, state_1]):
613
+ # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png'
614
+ # os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
615
+ # with open(src_img_file, 'w') as f:
616
+ # state.source_image.save(f, 'PNG')
617
+ # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png'
618
+ # with open(output_file, 'w') as f:
619
+ # state.output.save(f, 'PNG')
620
+ # save_image_file_on_log_server(src_img_file)
621
+ # save_image_file_on_log_server(output_file)
622
+
623
+
624
+ def generate_i2s_multi_annoy(gen_func,
625
+ state_0, state_1,
626
+ image,
627
+ model_name_0, model_name_1,
628
+ request: gr.Request):
629
+ if not image:
630
+ raise gr.Warning("Image cannot be empty.")
631
+ if state_0 is None:
632
+ state_0 = State(model_name_0, i2s_mode=True, offline=False)
633
+ if state_1 is None:
634
+ state_1 = State(model_name_1, i2s_mode=True, offline=False)
635
+
636
+ ip = get_ip(request)
637
+ t2s_multi_logger.info(f"generate. ip: {ip}")
638
+
639
+ state_0.model_name, state_1.model_name = model_name_0, model_name_1
640
+ state_0.image, state_1.image = image, image
641
+
642
+ if not state_0.offline and not state_0.offline_idx and \
643
+ not state_1.offline and not state_1.offline_idx:
644
+ start_time = time.time()
645
+ normal_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "normal", f"{state_0.offline_idx}.mp4")
646
+ rgb_video_0 = os.path.join(OFFLINE_DIR, "image2shape", model_name_0, "rgb", f"{state_0.offline_idx}.mp4")
647
+ normal_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "normal", f"{state_1.offline_idx}.mp4")
648
+ rgb_video_1 = os.path.join(OFFLINE_DIR, "image2shape", model_name_1, "rgb", f"{state_1.offline_idx}.mp4")
649
+
650
+ state_0.normal_video = normal_video_0
651
+ state_0.rgb_video = rgb_video_0
652
+ state_1.normal_video = normal_video_1
653
+ state_1.rgb_video = rgb_video_1
654
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
655
+ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
656
+
657
+ # logger.info(f"===output===: {output}")
658
+ data_0 = {
659
+ "ip": get_ip(request),
660
+ "model": model_name_0,
661
+ "type": "offline",
662
+ "gen_params": {},
663
+ "state": state_0.dict(),
664
+ "start": round(start_time, 4),
665
+ }
666
+ data_1 = {
667
+ "ip": get_ip(request),
668
+ "model": model_name_1,
669
+ "type": "offline",
670
+ "gen_params": {},
671
+ "state": state_1.dict(),
672
+ "start": round(start_time, 4),
673
+ }
674
+ else:
675
+ start_time = time.time()
676
+ shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
677
+ generate_time = time.time() - start_time
678
+
679
+ normal_video_0, rgb_video_0, normal_video_1, rgb_video_1 = render_func(shape_0, model_name_0,
680
+ shape_1, model_name_1)
681
+ finish_time = time.time()
682
+ render_time = finish_time - start_time - generate_time
683
+
684
+ state_0.normal_video = normal_video_0
685
+ state_0.rgb_video = rgb_video_0
686
+ state_1.normal_video = normal_video_1
687
+ state_1.rgb_video = rgb_video_1
688
+ yield state_0, state_1, normal_video_0, rgb_video_0, normal_video_0, rgb_video_1, \
689
+ gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
690
+
691
+ # logger.info(f"===output===: {output}")
692
+ data_0 = {
693
+ "ip": get_ip(request),
694
+ "model": model_name_0,
695
+ "type": "online",
696
+ "gen_params": {},
697
+ "state": state_0.dict(),
698
+ "start": round(start_time, 4),
699
+ "time": round(finish_time - start_time, 4),
700
+ "generate_time": round(generate_time, 4),
701
+ "render_time": round(render_time, 4),
702
+ }
703
+ data_1 = {
704
+ "ip": get_ip(request),
705
+ "model": model_name_1,
706
+ "type": "online",
707
+ "gen_params": {},
708
+ "state": state_1.dict(),
709
+ "start": round(start_time, 4),
710
+ "time": round(finish_time - start_time, 4),
711
+ "generate_time": round(generate_time, 4),
712
+ "render_time": round(render_time, 4),
713
+ }
714
+
715
+ with open(get_conv_log_filename(), "a") as fout:
716
+ fout.write(json.dumps(data_0) + "\n")
717
+ fout.write(json.dumps(data_1) + "\n")
718
+ append_json_item_on_log_server(data_0, get_conv_log_filename())
719
+ append_json_item_on_log_server(data_1, get_conv_log_filename())
720
+
721
+ # for i, state in enumerate([state_0, state_1]):
722
+ # src_img_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_src.png'
723
+ # os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
724
+ # with open(src_img_file, 'w') as f:
725
+ # state.source_image.save(f, 'PNG')
726
+ # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_out.png'
727
+ # with open(output_file, 'w') as f:
728
+ # state.output.save(f, 'PNG')
729
+ # save_image_file_on_log_server(src_img_file)
730
+ # save_image_file_on_log_server(output_file)
serve/leaderboard.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Live monitor of the website statistics and leaderboard.
3
+
4
+ Dependency:
5
+ sudo apt install pkg-config libicu-dev
6
+ pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate
7
+ """
8
+
9
+ import argparse
10
+ import ast
11
+ import pickle
12
+ import os
13
+ import threading
14
+ import time
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+
21
+ basic_component_values = [None] * 6
22
+ leader_component_values = [None] * 5
23
+
24
+
25
+ # def make_leaderboard_md(elo_results):
26
+ # leaderboard_md = f"""
27
+ # # 🏆 Chatbot Arena Leaderboard
28
+ # | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
29
+
30
+ # This leaderboard is based on the following three benchmarks.
31
+ # - [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings.
32
+ # - [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses.
33
+ # - [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks.
34
+
35
+ # 💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023.
36
+ # """
37
+ # return leaderboard_md
38
+
39
+ def make_leaderboard_md(elo_results):
40
+ leaderboard_md = f"""
41
+ # 🏆 GenAI-Arena Leaderboard
42
+ | [GitHub](https://github.com/TIGER-AI-Lab/ImagenHub) | [Dataset](https://huggingface.co/ImagenHub) | [Twitter](https://twitter.com/TianleLI123/status/1757245259149422752) |
43
+
44
+ """
45
+ return leaderboard_md
46
+
47
+
48
+ def make_leaderboard_md_live(elo_results):
49
+ leaderboard_md = f"""
50
+ # Leaderboard
51
+ Last updated: {elo_results["last_updated_datetime"]}
52
+ {elo_results["leaderboard_table"]}
53
+ """
54
+ return leaderboard_md
55
+
56
+
57
+ def model_hyperlink(model_name, link):
58
+ return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
59
+
60
+
61
+ def load_leaderboard_table_csv(filename, add_hyperlink=True):
62
+ df = pd.read_csv(filename)
63
+ for col in df.columns:
64
+ if "Arena Elo rating" in col:
65
+ df[col] = df[col].apply(lambda x: int(x) if x != "-" else np.nan)
66
+ elif col == "MMLU":
67
+ df[col] = df[col].apply(lambda x: round(x * 100, 1) if x != "-" else np.nan)
68
+ elif col == "MT-bench (win rate %)":
69
+ df[col] = df[col].apply(lambda x: round(x, 1) if x != "-" else np.nan)
70
+ elif col == "MT-bench (score)":
71
+ df[col] = df[col].apply(lambda x: round(x, 2) if x != "-" else np.nan)
72
+
73
+ if add_hyperlink and col == "Model":
74
+ df[col] = df.apply(lambda row: model_hyperlink(row[col], row["Link"]), axis=1)
75
+ return df
76
+
77
+
78
+
79
+ def build_basic_stats_tab():
80
+ empty = "Loading ..."
81
+ basic_component_values[:] = [empty, None, empty, empty, empty, empty]
82
+
83
+ md0 = gr.Markdown(empty)
84
+ gr.Markdown("#### Figure 1: Number of model calls and votes")
85
+ plot_1 = gr.Plot(show_label=False)
86
+ with gr.Row():
87
+ with gr.Column():
88
+ md1 = gr.Markdown(empty)
89
+ with gr.Column():
90
+ md2 = gr.Markdown(empty)
91
+ with gr.Row():
92
+ with gr.Column():
93
+ md3 = gr.Markdown(empty)
94
+ with gr.Column():
95
+ md4 = gr.Markdown(empty)
96
+ return [md0, plot_1, md1, md2, md3, md4]
97
+
98
+
99
+ def get_full_table(anony_arena_df, full_arena_df, model_table_df):
100
+ values = []
101
+ for i in range(len(model_table_df)):
102
+ row = []
103
+ model_key = model_table_df.iloc[i]["key"]
104
+ model_name = model_table_df.iloc[i]["Model"]
105
+ # model display name
106
+ row.append(model_name)
107
+ if model_key in anony_arena_df.index:
108
+ idx = anony_arena_df.index.get_loc(model_key)
109
+ row.append(round(anony_arena_df.iloc[idx]["rating"]))
110
+ else:
111
+ row.append(np.nan)
112
+ if model_key in full_arena_df.index:
113
+ idx = full_arena_df.index.get_loc(model_key)
114
+ row.append(round(full_arena_df.iloc[idx]["rating"]))
115
+ else:
116
+ row.append(np.nan)
117
+ # row.append(model_table_df.iloc[i]["MT-bench (score)"])
118
+ # row.append(model_table_df.iloc[i]["Num Battles"])
119
+ # row.append(model_table_df.iloc[i]["MMLU"])
120
+ # Organization
121
+ row.append(model_table_df.iloc[i]["Organization"])
122
+ # license
123
+ row.append(model_table_df.iloc[i]["License"])
124
+
125
+ values.append(row)
126
+ values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9)
127
+ return values
128
+
129
+
130
+ def get_arena_table(arena_df, model_table_df):
131
+ # sort by rating
132
+ arena_df = arena_df.sort_values(by=["rating"], ascending=False)
133
+ values = []
134
+ for i in range(len(arena_df)):
135
+ row = []
136
+ model_key = arena_df.index[i]
137
+ model_name = model_table_df[model_table_df["key"] == model_key]["Model"].values[
138
+ 0
139
+ ]
140
+
141
+ # rank
142
+ row.append(i + 1)
143
+ # model display name
144
+ row.append(model_name)
145
+ # elo rating
146
+ row.append(round(arena_df.iloc[i]["rating"]))
147
+ upper_diff = round(arena_df.iloc[i]["rating_q975"] - arena_df.iloc[i]["rating"])
148
+ lower_diff = round(arena_df.iloc[i]["rating"] - arena_df.iloc[i]["rating_q025"])
149
+ row.append(f"+{upper_diff}/-{lower_diff}")
150
+ # num battles
151
+ row.append(round(arena_df.iloc[i]["num_battles"]))
152
+ # Organization
153
+ row.append(
154
+ model_table_df[model_table_df["key"] == model_key]["Organization"].values[0]
155
+ )
156
+ # license
157
+ row.append(
158
+ model_table_df[model_table_df["key"] == model_key]["License"].values[0]
159
+ )
160
+
161
+ values.append(row)
162
+ return values
163
+
164
+ def make_arena_leaderboard_md(elo_results):
165
+ arena_df = elo_results["leaderboard_table_df"]
166
+ last_updated = elo_results["last_updated_datetime"]
167
+ total_votes = sum(arena_df["num_battles"]) // 2
168
+ total_models = len(arena_df)
169
+
170
+ leaderboard_md = f"""
171
+
172
+
173
+ Total #models: **{total_models}**(anonymous). Total #votes: **{total_votes}**. Last updated: {last_updated}.
174
+ (Note: Only anonymous votes are considered here. Check the full leaderboard for all votes.)
175
+
176
+ Contribute the votes 🗳️ at [GenAI-Arena](https://huggingface.co/spaces/TIGER-Lab/GenAI-Arena)!
177
+
178
+ If you want to see more models, please help us [add them](https://github.com/TIGER-AI-Lab/ImagenHub?tab=readme-ov-file#-contributing-).
179
+ """
180
+ return leaderboard_md
181
+
182
+ def make_full_leaderboard_md(elo_results):
183
+ arena_df = elo_results["leaderboard_table_df"]
184
+ last_updated = elo_results["last_updated_datetime"]
185
+ total_votes = sum(arena_df["num_battles"]) // 2
186
+ total_models = len(arena_df)
187
+
188
+ leaderboard_md = f"""
189
+ Total #models: **{total_models}**(full:anonymous+open). Total #votes: **{total_votes}**. Last updated: {last_updated}.
190
+
191
+ Contribute your vote 🗳️ at [vision-arena](https://huggingface.co/spaces/WildVision/vision-arena)!
192
+ """
193
+ return leaderboard_md
194
+
195
+ def build_leaderboard_tab(elo_results_file, leaderboard_table_file, show_plot=False):
196
+ if elo_results_file is None: # Do live update
197
+ md = "Loading ..."
198
+ p1 = p2 = p3 = p4 = None
199
+ else:
200
+ with open(elo_results_file, "rb") as fin:
201
+ elo_results = pickle.load(fin)
202
+
203
+ anony_elo_results = elo_results["anony"]
204
+ full_elo_results = elo_results["full"]
205
+ anony_arena_df = anony_elo_results["leaderboard_table_df"]
206
+ full_arena_df = full_elo_results["leaderboard_table_df"]
207
+ p1 = anony_elo_results["win_fraction_heatmap"]
208
+ p2 = anony_elo_results["battle_count_heatmap"]
209
+ p3 = anony_elo_results["bootstrap_elo_rating"]
210
+ p4 = anony_elo_results["average_win_rate_bar"]
211
+
212
+ md = make_leaderboard_md(anony_elo_results)
213
+
214
+ md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
215
+
216
+ if leaderboard_table_file:
217
+ model_table_df = load_leaderboard_table_csv(leaderboard_table_file)
218
+ with gr.Tabs() as tabs:
219
+ # arena table
220
+ arena_table_vals = get_arena_table(anony_arena_df, model_table_df)
221
+ with gr.Tab("Arena Elo", id=0):
222
+ md = make_arena_leaderboard_md(anony_elo_results)
223
+ gr.Markdown(md, elem_id="leaderboard_markdown")
224
+ gr.Dataframe(
225
+ headers=[
226
+ "Rank",
227
+ "🤖 Model",
228
+ "⭐ Arena Elo",
229
+ "📊 95% CI",
230
+ "🗳️ Votes",
231
+ "Organization",
232
+ "License",
233
+ ],
234
+ datatype=[
235
+ "str",
236
+ "markdown",
237
+ "number",
238
+ "str",
239
+ "number",
240
+ "str",
241
+ "str",
242
+ ],
243
+ value=arena_table_vals,
244
+ elem_id="arena_leaderboard_dataframe",
245
+ height=700,
246
+ column_widths=[50, 200, 100, 100, 100, 150, 150],
247
+ wrap=True,
248
+ )
249
+ with gr.Tab("Full Leaderboard", id=1):
250
+ md = make_full_leaderboard_md(full_elo_results)
251
+ gr.Markdown(md, elem_id="leaderboard_markdown")
252
+ full_table_vals = get_full_table(anony_arena_df, full_arena_df, model_table_df)
253
+ gr.Dataframe(
254
+ headers=[
255
+ "🤖 Model",
256
+ "⭐ Arena Elo (anony)",
257
+ "⭐ Arena Elo (full)",
258
+ "Organization",
259
+ "License",
260
+ ],
261
+ datatype=["markdown", "number", "number", "str", "str"],
262
+ value=full_table_vals,
263
+ elem_id="full_leaderboard_dataframe",
264
+ column_widths=[200, 100, 100, 100, 150, 150],
265
+ height=700,
266
+ wrap=True,
267
+ )
268
+ if not show_plot:
269
+ gr.Markdown(
270
+ """ ## We are still collecting more votes on more models. The ranking will be updated very fruquently. Please stay tuned!
271
+ """,
272
+ elem_id="leaderboard_markdown",
273
+ )
274
+ else:
275
+ pass
276
+
277
+ leader_component_values[:] = [md, p1, p2, p3, p4]
278
+
279
+ """
280
+ with gr.Row():
281
+ with gr.Column():
282
+ gr.Markdown(
283
+ "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles"
284
+ )
285
+ plot_1 = gr.Plot(p1, show_label=False)
286
+ with gr.Column():
287
+ gr.Markdown(
288
+ "#### Figure 2: Battle Count for Each Combination of Models (without Ties)"
289
+ )
290
+ plot_2 = gr.Plot(p2, show_label=False)
291
+ with gr.Row():
292
+ with gr.Column():
293
+ gr.Markdown(
294
+ "#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)"
295
+ )
296
+ plot_3 = gr.Plot(p3, show_label=False)
297
+ with gr.Column():
298
+ gr.Markdown(
299
+ "#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)"
300
+ )
301
+ plot_4 = gr.Plot(p4, show_label=False)
302
+ """
303
+
304
+ from .utils import acknowledgment_md
305
+
306
+ gr.Markdown(acknowledgment_md)
307
+
308
+ # return [md_1, plot_1, plot_2, plot_3, plot_4]
309
+ return [md_1]
serve/log_utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Common utilities.
3
+ """
4
+ from asyncio import AbstractEventLoop
5
+ import json
6
+ import logging
7
+ import logging.handlers
8
+ import os
9
+ import platform
10
+ import sys
11
+ from typing import AsyncGenerator, Generator
12
+ import warnings
13
+ from pathlib import Path
14
+
15
+ import requests
16
+
17
+ from .constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG
18
+ from .utils import save_log_str_on_log_server
19
+
20
+
21
+ handler = None
22
+ visited_loggers = set()
23
+
24
+
25
+ # Assuming LOGDIR and other necessary imports and global variables are defined
26
+
27
+ class APIHandler(logging.Handler):
28
+ """Custom logging handler that sends logs to an API."""
29
+
30
+ def __init__(self, apiUrl, log_path, *args, **kwargs):
31
+ super(APIHandler, self).__init__(*args, **kwargs)
32
+ self.apiUrl = apiUrl
33
+ self.log_path = log_path
34
+
35
+ def emit(self, record):
36
+ log_entry = self.format(record)
37
+ try:
38
+ save_log_str_on_log_server(log_entry, self.log_path)
39
+ except requests.RequestException as e:
40
+ print(f"Error sending log to API: {e}", file=sys.stderr)
41
+
42
+ def build_logger(logger_name, logger_filename, add_remote_handler=False):
43
+ global handler
44
+
45
+ formatter = logging.Formatter(
46
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
47
+ datefmt="%Y-%m-%d %H:%M:%S",
48
+ )
49
+
50
+ # Set the format of root handlers
51
+ if not logging.getLogger().handlers:
52
+ if sys.version_info[1] >= 9:
53
+ # This is for windows
54
+ logging.basicConfig(level=logging.INFO, encoding="utf-8")
55
+ else:
56
+ if platform.system() == "Windows":
57
+ warnings.warn(
58
+ "If you are running on Windows, "
59
+ "we recommend you use Python >= 3.9 for UTF-8 encoding."
60
+ )
61
+ logging.basicConfig(level=logging.INFO)
62
+ logging.getLogger().handlers[0].setFormatter(formatter)
63
+
64
+ # Redirect stdout and stderr to loggers
65
+ stdout_logger = logging.getLogger("stdout")
66
+ stdout_logger.setLevel(logging.INFO)
67
+ sl = StreamToLogger(stdout_logger, logging.INFO)
68
+ sys.stdout = sl
69
+
70
+ stderr_logger = logging.getLogger("stderr")
71
+ stderr_logger.setLevel(logging.ERROR)
72
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
73
+ sys.stderr = sl
74
+
75
+ # Get logger
76
+ logger = logging.getLogger(logger_name)
77
+ logger.setLevel(logging.INFO)
78
+
79
+ if add_remote_handler:
80
+ # Add APIHandler to send logs to your API
81
+ api_url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"
82
+
83
+ remote_logger_filename = str(Path(logger_filename).stem + "_remote.log")
84
+ api_handler = APIHandler(apiUrl=api_url, log_path=f"{LOGDIR}/{remote_logger_filename}")
85
+ api_handler.setFormatter(formatter)
86
+ logger.addHandler(api_handler)
87
+
88
+ stdout_logger.addHandler(api_handler)
89
+ stderr_logger.addHandler(api_handler)
90
+
91
+ # if LOGDIR is empty, then don't try output log to local file
92
+ if LOGDIR != "":
93
+ os.makedirs(LOGDIR, exist_ok=True)
94
+ filename = os.path.join(LOGDIR, logger_filename)
95
+ handler = logging.handlers.TimedRotatingFileHandler(
96
+ filename, when="D", utc=True, encoding="utf-8"
97
+ )
98
+ handler.setFormatter(formatter)
99
+
100
+ for l in [stdout_logger, stderr_logger, logger]:
101
+ if l in visited_loggers:
102
+ continue
103
+ visited_loggers.add(l)
104
+ l.addHandler(handler)
105
+
106
+ return logger
107
+
108
+
109
+ class StreamToLogger(object):
110
+ """
111
+ Fake file-like stream object that redirects writes to a logger instance.
112
+ """
113
+
114
+ def __init__(self, logger, log_level=logging.INFO):
115
+ self.terminal = sys.stdout
116
+ self.logger = logger
117
+ self.log_level = log_level
118
+ self.linebuf = ""
119
+
120
+ def __getattr__(self, attr):
121
+ return getattr(self.terminal, attr)
122
+
123
+ def write(self, buf):
124
+ temp_linebuf = self.linebuf + buf
125
+ self.linebuf = ""
126
+ for line in temp_linebuf.splitlines(True):
127
+ # From the io.TextIOWrapper docs:
128
+ # On output, if newline is None, any '\n' characters written
129
+ # are translated to the system default line separator.
130
+ # By default sys.stdout.write() expects '\n' newlines and then
131
+ # translates them so this is still cross platform.
132
+ if line[-1] == "\n":
133
+ encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
134
+ self.logger.log(self.log_level, encoded_message.rstrip())
135
+ else:
136
+ self.linebuf += line
137
+
138
+ def flush(self):
139
+ if self.linebuf != "":
140
+ encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
141
+ self.logger.log(self.log_level, encoded_message.rstrip())
142
+ self.linebuf = ""
serve/utils.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import datetime
5
+ import requests
6
+ import numpy as np
7
+ import gradio as gr
8
+ from pathlib import Path
9
+ from model.model_registry import *
10
+ from .constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE, SAVE_LOG
11
+ from typing import Union
12
+
13
+
14
+ enable_btn = gr.update(interactive=True, visible=True)
15
+ disable_btn = gr.update(interactive=False)
16
+ invisible_btn = gr.update(interactive=False, visible=False)
17
+ no_change_btn = gr.update(value="No Change", interactive=True, visible=True)
18
+
19
+
20
+ def build_about():
21
+ about_markdown = f"""
22
+ # About Us
23
+ This is a project from TIGER Lab at University of Waterloo.
24
+
25
+ ## Contributors:
26
+ [Tianle Li](https://scholar.google.com/citations?user=g213g7YAAAAJ&hl=en), [Dongfu Jiang](https://jdf-prog.github.io/), [Yuansheng Ni](https://yuanshengni.github.io/).
27
+
28
+ ## Contact:
29
+ Email: [email protected] (Tianle Li)
30
+
31
+ ## Advisors
32
+ [Wenhu Chen](https://wenhuchen.github.io/)
33
+
34
+ ## Sponsorship
35
+ We are keep looking for sponsorship to support the arena project for the long term. Please contact us if you are interested in supporting this project.
36
+ """
37
+
38
+ gr.Markdown(about_markdown, elem_id="about_markdown")
39
+
40
+
41
+ acknowledgment_md = """
42
+ ### Acknowledgment
43
+ <div class="image-container">
44
+ <p> Our code base is built upon <a href="https://github.com/lm-sys/FastChat" target="_blank">FastChat</a> and <a href="https://github.com/TIGER-AI-Lab/ImagenHub" target="_blank">ImagenHub</a></p>.
45
+ </div>
46
+ """
47
+
48
+ block_css = """
49
+ #notice_markdown {
50
+ font-size: 110%
51
+ }
52
+ #notice_markdown th {
53
+ display: none;
54
+ }
55
+ #notice_markdown td {
56
+ padding-top: 6px;
57
+ padding-bottom: 6px;
58
+ }
59
+ #model_description_markdown {
60
+ font-size: 110%
61
+ }
62
+ #leaderboard_markdown {
63
+ font-size: 110%
64
+ }
65
+ #leaderboard_markdown td {
66
+ padding-top: 6px;
67
+ padding-bottom: 6px;
68
+ }
69
+ #leaderboard_dataframe td {
70
+ line-height: 0.1em;
71
+ }
72
+ #about_markdown {
73
+ font-size: 110%
74
+ }
75
+ #ack_markdown {
76
+ font-size: 110%
77
+ }
78
+ #evaldim_markdown {
79
+ font-weight: bold;
80
+ text-align: center;
81
+ background-color: white;
82
+ }
83
+ #input_box textarea {
84
+ }
85
+ footer {
86
+ display:none !important
87
+ }
88
+ .image-about img {
89
+ margin: 0 30px;
90
+ margin-top: 30px;
91
+ height: 60px;
92
+ max-height: 100%;
93
+ width: auto;
94
+ float: left;
95
+ .input-image, .image-preview {
96
+ margin: 0 30px;
97
+ height: 30px;
98
+ max-height: 100%;
99
+ width: auto;
100
+ max-width: 30%;}
101
+ }
102
+ """
103
+ def enable_mds():
104
+ return tuple(gr.update(visible=True) for _ in range(3))
105
+
106
+ def disable_mds():
107
+ return tuple(gr.update(visible=False) for _ in range(3))
108
+
109
+ def enable_buttons_side_by_side():
110
+ return tuple(gr.update(visible=True, interactive=True) for i in range(14))
111
+
112
+ def disable_buttons_side_by_side():
113
+ return tuple(gr.update(visible=i>=12, interactive=False) for i in range(14))
114
+
115
+ def enable_buttons():
116
+ return tuple(gr.update(interactive=True) for _ in range(11))
117
+
118
+ def disable_buttons():
119
+ return tuple(gr.update(interactive=False) for _ in range(11))
120
+
121
+ def clear_t2s_history():
122
+ return None, "", None, None
123
+
124
+ def clear_t2s_history_side_by_side():
125
+ return [None] * 2 + [""] + [None] * 4
126
+
127
+ def clear_t2s_history_side_by_side_anony():
128
+ return [None] * 2 + [""] + [None] * 4 + [gr.Markdown("", visible=False), gr.Markdown("", visible=False)]
129
+
130
+ def clear_i2s_history():
131
+ return None, None, None, None
132
+
133
+ def clear_i2s_history_side_by_side():
134
+ return [None] * 2 + [None] + [None] * 4
135
+
136
+ def clear_i2s_history_side_by_side_anony():
137
+ return [None] * 2 + [None] + [None] * 4 + [gr.Markdown("", visible=False), gr.Markdown("", visible=False)]
138
+
139
+ def get_ip(request: gr.Request):
140
+ if request:
141
+ if "cf-connecting-ip" in request.headers:
142
+ ip = request.headers["cf-connecting-ip"] or request.client.host
143
+ else:
144
+ ip = request.client.host
145
+ else:
146
+ ip = None
147
+ return ip
148
+
149
+ def get_conv_log_filename():
150
+ t = datetime.datetime.now()
151
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
152
+ return name
153
+
154
+ def save_image_file_on_log_server(image_file:str):
155
+
156
+ image_file = Path(image_file).absolute().relative_to(os.getcwd())
157
+ image_file = str(image_file)
158
+ # Open the image file in binary mode
159
+ url = f"{LOG_SERVER_ADDR}/{SAVE_IMAGE}"
160
+ with open(image_file, 'rb') as f:
161
+ # Make the POST request, sending the image file and the image path
162
+ response = requests.post(url, files={'image': f}, data={'image_path': image_file})
163
+ return response
164
+
165
+ def append_json_item_on_log_server(json_item: Union[dict, str], log_file: str):
166
+ if isinstance(json_item, dict):
167
+ json_item = json.dumps(json_item)
168
+ log_file = Path(log_file).absolute().relative_to(os.getcwd())
169
+ log_file = str(log_file)
170
+ url = f"{LOG_SERVER_ADDR}/{APPEND_JSON}"
171
+ # Make the POST request, sending the JSON string and the log file name
172
+ response = requests.post(url, data={'json_str': json_item, 'file_name': log_file})
173
+ return response
174
+
175
+ def save_log_str_on_log_server(log_str: str, log_file: str):
176
+ log_file = Path(log_file).absolute().relative_to(os.getcwd())
177
+ log_file = str(log_file)
178
+ url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"
179
+ # Make the POST request, sending the log message and the log file name
180
+ response = requests.post(url, data={'message': log_str, 'log_path': log_file})
181
+ return response
serve/vote_utils.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import time
3
+ import json
4
+ import uuid
5
+ import gradio as gr
6
+ from pathlib import Path
7
+ from .utils import *
8
+ from .log_utils import build_logger
9
+ from .constants import IMAGE_DIR
10
+
11
+ t2s_logger = build_logger("gradio_web_server_text2shape", "gr_web_text2shape.log") # t2s = image generation, loggers for single model direct chat
12
+ t2s_multi_logger = build_logger("gradio_web_server_text2shape_multi", "gr_web_text2shape_multi.log") # t2s_multi = image generation multi, loggers for side-by-side and battle
13
+ i2s_logger = build_logger("gradio_web_server_image2shape", "gr_web_image2shape.log") # i2s = image editing, loggers for single model direct chat
14
+ i2s_multi_logger = build_logger("gradio_web_server_image2shape_multi", "gr_web_image2shape_multi.log") # i2s_multi = image editing multi, loggers for side-by-side and battle
15
+
16
+ def vote_last_response_t2s(state, vote_type, model_selector, request: gr.Request):
17
+ with open(get_conv_log_filename(), "a") as fout:
18
+ data = {
19
+ "tstamp": round(time.time(), 4),
20
+ "type": vote_type,
21
+ "model": model_selector,
22
+ "state": state.dict(),
23
+ "ip": get_ip(request),
24
+ }
25
+ fout.write(json.dumps(data) + "\n")
26
+ append_json_item_on_log_server(data, get_conv_log_filename())
27
+
28
+ def vote_last_response_t2s_multi(states, vote_type, model_selectors, request: gr.Request):
29
+ with open(get_conv_log_filename(), "a") as fout:
30
+ data = {
31
+ "tstamp": round(time.time(), 4),
32
+ "type": vote_type,
33
+ "models": [x for x in model_selectors],
34
+ "states": [x.dict() for x in states],
35
+ "ip": get_ip(request),
36
+ }
37
+ fout.write(json.dumps(data) + "\n")
38
+ append_json_item_on_log_server(data, get_conv_log_filename())
39
+ # for state in states:
40
+ # output_file = f'{IMAGE_DIR}/text2shape/{state.conv_id}.png'
41
+ # with open(output_file, 'w') as f:
42
+ # state.output.save(f, 'PNG')
43
+ # save_image_file_on_log_server(output_file)
44
+
45
+ def vote_last_response_i2s(state, vote_type, model_selector, request: gr.Request):
46
+ with open(get_conv_log_filename(), "a") as fout:
47
+ data = {
48
+ "tstamp": round(time.time(), 4),
49
+ "type": vote_type,
50
+ "model": model_selector,
51
+ "state": state.dict(),
52
+ "ip": get_ip(request),
53
+ }
54
+ fout.write(json.dumps(data) + "\n")
55
+ append_json_item_on_log_server(data, get_conv_log_filename())
56
+ # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}.png'
57
+ # source_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_source.png'
58
+ # with open(output_file, 'w') as f:
59
+ # state.output.save(f, 'PNG')
60
+ # with open(source_file, 'w') as sf:
61
+ # state.source_image.save(sf, 'PNG')
62
+ # save_image_file_on_log_server(output_file)
63
+ # save_image_file_on_log_server(source_file)
64
+
65
+ def vote_last_response_i2s_multi(states, vote_type, model_selectors, request: gr.Request):
66
+ with open(get_conv_log_filename(), "a") as fout:
67
+ data = {
68
+ "tstamp": round(time.time(), 4),
69
+ "type": vote_type,
70
+ "models": [x for x in model_selectors],
71
+ "states": [x.dict() for x in states],
72
+ "ip": get_ip(request),
73
+ }
74
+ fout.write(json.dumps(data) + "\n")
75
+ append_json_item_on_log_server(data, get_conv_log_filename())
76
+ # for state in states:
77
+ # output_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}.png'
78
+ # source_file = f'{IMAGE_DIR}/image2shape/{state.conv_id}_source.png'
79
+ # with open(output_file, 'w') as f:
80
+ # state.output.save(f, 'PNG')
81
+ # with open(source_file, 'w') as sf:
82
+ # state.source_image.save(sf, 'PNG')
83
+ # save_image_file_on_log_server(output_file)
84
+ # save_image_file_on_log_server(source_file)
85
+
86
+
87
+ ## Text-to-Shape Generation (t2s) Single Model Direct Chat
88
+ def upvote_last_response_t2s(state, model_selector, request: gr.Request):
89
+ ip = get_ip(request)
90
+ t2s_logger.info(f"upvote. ip: {ip}")
91
+ vote_last_response_t2s(state, "upvote", model_selector, request)
92
+ return ("",) + (disable_btn,) * 3
93
+
94
+ def downvote_last_response_t2s(state, model_selector, request: gr.Request):
95
+ ip = get_ip(request)
96
+ t2s_logger.info(f"downvote. ip: {ip}")
97
+ vote_last_response_t2s(state, "downvote", model_selector, request)
98
+ return ("",) + (disable_btn,) * 3
99
+
100
+ def flag_last_response_t2s(state, model_selector, request: gr.Request):
101
+ ip = get_ip(request)
102
+ t2s_logger.info(f"flag. ip: {ip}")
103
+ vote_last_response_t2s(state, "flag", model_selector, request)
104
+ return ("",) + (disable_btn,) * 3
105
+
106
+
107
+ ## Text-to-Shape Generation Multi (t2s_multi) Side-by-Side and Battle
108
+ def leftvote_last_response_t2s_named(
109
+ state0, state1, model_selector0, model_selector1, request: gr.Request
110
+ ):
111
+ t2s_multi_logger.info(f"leftvote (named). ip: {get_ip(request)}")
112
+ vote_last_response_t2s_multi(
113
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
114
+ )
115
+ return ("",) + (disable_btn,) * 4
116
+
117
+ def rightvote_last_response_t2s_named(
118
+ state0, state1, model_selector0, model_selector1, request: gr.Request
119
+ ):
120
+ t2s_multi_logger.info(f"rightvote (named). ip: {get_ip(request)}")
121
+ vote_last_response_t2s_multi(
122
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
123
+ )
124
+ return ("",) + (disable_btn,) * 4
125
+
126
+ def tievote_last_response_t2s_named(
127
+ state0, state1, model_selector0, model_selector1, request: gr.Request
128
+ ):
129
+ t2s_multi_logger.info(f"tievote (named). ip: {get_ip(request)}")
130
+ vote_last_response_t2s_multi(
131
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
132
+ )
133
+ return ("",) + (disable_btn,) * 4
134
+
135
+ def bothbad_vote_last_response_t2s_named(
136
+ state0, state1, model_selector0, model_selector1, request: gr.Request
137
+ ):
138
+ t2s_multi_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
139
+ vote_last_response_t2s_multi(
140
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
141
+ )
142
+ return ("",) + (disable_btn,) * 4
143
+
144
+
145
+ def leftvote_last_response_t2s_anony(
146
+ state0, state1, model_selector0, model_selector1, request: gr.Request
147
+ ):
148
+ t2s_multi_logger.info(f"leftvote (named). ip: {get_ip(request)}")
149
+ vote_last_response_t2s_multi(
150
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
151
+ )
152
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
153
+ return ("",) + (disable_btn,) * 4 + names
154
+
155
+ def rightvote_last_response_t2s_anony(
156
+ state0, state1, model_selector0, model_selector1, request: gr.Request
157
+ ):
158
+ t2s_multi_logger.info(f"rightvote (named). ip: {get_ip(request)}")
159
+ vote_last_response_t2s_multi(
160
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
161
+ )
162
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
163
+ return ("",) + (disable_btn,) * 4 + names
164
+
165
+ def tievote_last_response_t2s_anony(
166
+ state0, state1, model_selector0, model_selector1, request: gr.Request
167
+ ):
168
+ t2s_multi_logger.info(f"tievote (named). ip: {get_ip(request)}")
169
+ vote_last_response_t2s_multi(
170
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
171
+ )
172
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
173
+ return ("",) + (disable_btn,) * 4 + names
174
+
175
+ def bothbad_vote_last_response_t2s_anony(
176
+ state0, state1, model_selector0, model_selector1, request: gr.Request
177
+ ):
178
+ t2s_multi_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
179
+ vote_last_response_t2s_multi(
180
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
181
+ )
182
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
183
+ return ("",) + (disable_btn,) * 4 + names
184
+
185
+ ## Image-to-Shape (i2s) Single Model Direct Chat
186
+ def upvote_last_response_i2s(state, model_selector, request: gr.Request):
187
+ ip = get_ip(request)
188
+ i2s_logger.info(f"upvote. ip: {ip}")
189
+ vote_last_response_i2s(state, "upvote", model_selector, request)
190
+ return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
191
+
192
+ def downvote_last_response_i2s(state, model_selector, request: gr.Request):
193
+ ip = get_ip(request)
194
+ i2s_logger.info(f"downvote. ip: {ip}")
195
+ vote_last_response_i2s(state, "downvote", model_selector, request)
196
+ return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
197
+
198
+ def flag_last_response_i2s(state, model_selector, request: gr.Request):
199
+ ip = get_ip(request)
200
+ i2s_logger.info(f"flag. ip: {ip}")
201
+ vote_last_response_i2s(state, "flag", model_selector, request)
202
+ return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3
203
+
204
+
205
+ ## Image-to-Shape Multi (i2s_multi) Side-by-Side and Battle
206
+ def leftvote_last_response_i2s_named(
207
+ state0, state1, model_selector0, model_selector1, request: gr.Request
208
+ ):
209
+ i2s_multi_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
210
+ vote_last_response_i2s_multi(
211
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
212
+ )
213
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
214
+
215
+ def rightvote_last_response_i2s_named(
216
+ state0, state1, model_selector0, model_selector1, request: gr.Request
217
+ ):
218
+ i2s_multi_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
219
+ vote_last_response_i2s_multi(
220
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
221
+ )
222
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
223
+
224
+ def tievote_last_response_i2s_named(
225
+ state0, state1, model_selector0, model_selector1, request: gr.Request
226
+ ):
227
+ i2s_multi_logger.info(f"tievote (anony). ip: {get_ip(request)}")
228
+ vote_last_response_i2s_multi(
229
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
230
+ )
231
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
232
+
233
+ def bothbad_vote_last_response_i2s_named(
234
+ state0, state1, model_selector0, model_selector1, request: gr.Request
235
+ ):
236
+ i2s_multi_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
237
+ vote_last_response_i2s_multi(
238
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
239
+ )
240
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4
241
+
242
+
243
+ def leftvote_last_response_i2s_anony(
244
+ state0, state1, model_selector0, model_selector1, request: gr.Request
245
+ ):
246
+ i2s_multi_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
247
+ vote_last_response_i2s_multi(
248
+ [state0, state1], "leftvote", [model_selector0, model_selector1], request
249
+ )
250
+ # names = (
251
+ # "### Model A: " + state0.model_name,
252
+ # "### Model B: " + state1.model_name,
253
+ # )
254
+ # names = (state0.model_name, state1.model_name)
255
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
256
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
257
+
258
+ def rightvote_last_response_i2s_anony(
259
+ state0, state1, model_selector0, model_selector1, request: gr.Request
260
+ ):
261
+ i2s_multi_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
262
+ vote_last_response_i2s_multi(
263
+ [state0, state1], "rightvote", [model_selector0, model_selector1], request
264
+ )
265
+ # names = (
266
+ # "### Model A: " + state0.model_name,
267
+ # "### Model B: " + state1.model_name,
268
+ # )
269
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
270
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
271
+
272
+ def tievote_last_response_i2s_anony(
273
+ state0, state1, model_selector0, model_selector1, request: gr.Request
274
+ ):
275
+ i2s_multi_logger.info(f"tievote (anony). ip: {get_ip(request)}")
276
+ vote_last_response_i2s_multi(
277
+ [state0, state1], "tievote", [model_selector0, model_selector1], request
278
+ )
279
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
280
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
281
+
282
+ def bothbad_vote_last_response_i2s_anony(
283
+ state0, state1, model_selector0, model_selector1, request: gr.Request
284
+ ):
285
+ i2s_multi_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
286
+ vote_last_response_i2s_multi(
287
+ [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
288
+ )
289
+ names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
290
+ return gr.Image(height=512, width=512, type="pil") + (disable_btn,) * 4 + names
291
+
292
+
293
+ share_js = """
294
+ function (a, b, c, d) {
295
+ const captureElement = document.querySelector('#share-region-named');
296
+ html2canvas(captureElement)
297
+ .then(canvas => {
298
+ canvas.style.display = 'none'
299
+ document.body.appendChild(canvas)
300
+ return canvas
301
+ })
302
+ .then(canvas => {
303
+ const image = canvas.toDataURL('image/png')
304
+ const a = document.createElement('a')
305
+ a.setAttribute('download', 'chatbot-arena.png')
306
+ a.setAttribute('href', image)
307
+ a.click()
308
+ canvas.remove()
309
+ });
310
+ return [a, b, c, d];
311
+ }
312
+ """
313
+ def share_click_t2s_multi(state0, state1, model_selector0, model_selector1, request: gr.Request):
314
+ t2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
315
+ if state0 is not None and state1 is not None:
316
+ vote_last_response_t2s_multi(
317
+ [state0, state1], "share", [model_selector0, model_selector1], request
318
+ )
319
+
320
+ def share_click_i2s_multi(state0, state1, model_selector0, model_selector1, request: gr.Request):
321
+ i2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
322
+ if state0 is not None and state1 is not None:
323
+ vote_last_response_i2s_multi(
324
+ [state0, state1], "share", [model_selector0, model_selector1], request
325
+ )
326
+