matveymih commited on
Commit
d00e7d9
1 Parent(s): c3d8da6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -12
app.py CHANGED
@@ -2,35 +2,85 @@ import gradio as gr
2
  import requests
3
  from PIL import Image
4
  import io
5
- from typing import Any
6
  import os
7
 
8
  class Client:
9
  def __init__(self, server_url: str):
10
  self.server_url = server_url
11
 
12
- def send_request(self, model_name: str, text: str) -> Any:
13
- response = requests.post(self.server_url, json={"model_name": model_name, "text": text})
14
  if response.status_code == 200:
15
- img_data = response.content
 
 
16
  img = Image.open(io.BytesIO(img_data))
17
- return img
18
  else:
19
- return "Error, please retry"
20
 
21
  client = Client(f"http://{os.environ['SERVER']}/predict")
22
 
23
- def get_layerwise_nonlinearity(model_name: str, text: str) -> Any:
24
- return client.send_request(model_name, text)
25
 
26
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  with gr.Column():
28
- model_selector = gr.Dropdown(choices=["mistralai/Mistral-7B-v0.1", "facebook/opt-125m"], label="Select Model")
29
- text_message = gr.Textbox(label="Enter your request:")
30
  submit = gr.Button("Submit")
31
- box_for_plot = gr.Image(label="Layer wise non-linearity (with first layer)", type="pil")
 
32
 
33
- submit.click(get_layerwise_nonlinearity, [model_selector, text_message], [box_for_plot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  if __name__ == "__main__":
36
  demo.launch(share=True, server_port=7860, server_name="0.0.0.0")
 
2
  import requests
3
  from PIL import Image
4
  import io
5
+ from typing import Any, Tuple
6
  import os
7
 
8
  class Client:
9
  def __init__(self, server_url: str):
10
  self.server_url = server_url
11
 
12
+ def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
13
+ response = requests.post(self.server_url, json={"task_name": task_name, "model_name": model_name, "text": text, "normalization_type": normalization_type}, timeout=10)
14
  if response.status_code == 200:
15
+ response_data = response.json()
16
+ img_data = bytes.fromhex(response_data["image"])
17
+ log_info = response_data["log"]
18
  img = Image.open(io.BytesIO(img_data))
19
+ return img, log_info
20
  else:
21
+ return "Error, please retry", "Error: Could not get response from server"
22
 
23
  client = Client(f"http://{os.environ['SERVER']}/predict")
24
 
25
+ def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
26
+ return client.send_request(task_name, model_name, text, normalization_type)
27
 
28
  with gr.Blocks() as demo:
29
+ with gr.Row():
30
+ model_selector = gr.Dropdown(
31
+ choices=[
32
+ "facebook/opt-1.3b",
33
+ "facebook/opt-2.7b",
34
+ "microsoft/Phi-3-mini-128k-instruct"
35
+ ],
36
+ value="facebook/opt-1.3b",
37
+ label="Select Model"
38
+ )
39
+ task_selector = gr.Dropdown(
40
+ choices=[
41
+ "Layer wise non-linearity (with first layer)",
42
+ "Next-token prediction from intermediate representations",
43
+ "Contextualization mesurment",
44
+ "Layerwise predictions and losses",
45
+ "Tokenwise loss without i-th layer"
46
+ ],
47
+ value="Layer wise non-linearity (with first layer)",
48
+ label="Select Mode"
49
+ )
50
+ normalization_selector = gr.Dropdown(
51
+ choices=["global", "token-wise"], #, "sentence-wise"],
52
+ value="token-wise",
53
+ label="Select Normalization"
54
+ )
55
  with gr.Column():
56
+ text_message = gr.Textbox(label="Enter your request:", value="I love to live my life")
 
57
  submit = gr.Button("Submit")
58
+ box_for_plot = gr.Image(label="Visualization", type="pil")
59
+ log_output = gr.Textbox(label="Log Output", lines=10, interactive=False, value="")
60
 
61
+ def update_output(task_name: str, model_name: str, text: str, normalization_type: str, existing_log: str) -> Tuple[Any, str]:
62
+ img, new_log = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
63
+ combined_log = existing_log + "---\n" + new_log + "\n"
64
+ return img, combined_log
65
+
66
+ def set_default(task_name: str) -> str:
67
+ if task_name == "Layer wise non-linearity (with first layer)":
68
+ return "token-wise"
69
+ if task_name == "Next-token prediction from intermediate representations":
70
+ return "token-wise"
71
+ if task_name == "Contextualization mesurment":
72
+ return "global"
73
+ if task_name == "Layerwise predictions and losses":
74
+ return "global"
75
+ if task_name == "Tokenwise loss without i-th layer":
76
+ return "token-wise"
77
+
78
+ task_selector.select(set_default, [task_selector], [normalization_selector])
79
+ submit.click(
80
+ fn=update_output,
81
+ inputs=[task_selector, model_selector, text_message, normalization_selector, log_output],
82
+ outputs=[box_for_plot, log_output]
83
+ )
84
 
85
  if __name__ == "__main__":
86
  demo.launch(share=True, server_port=7860, server_name="0.0.0.0")