Spaces:
Running
Running
File size: 4,392 Bytes
c3d8da6 d00e7d9 c3d8da6 d00e7d9 07b8186 c3d8da6 d00e7d9 c3d8da6 d00e7d9 c3d8da6 d00e7d9 c3d8da6 d00e7d9 c3d8da6 d00e7d9 b9e362c 07b8186 d00e7d9 5dfe7bb d00e7d9 b9e362c d00e7d9 5dfe7bb d00e7d9 c3d8da6 d00e7d9 c3d8da6 d00e7d9 c3d8da6 d00e7d9 5dfe7bb d00e7d9 b9e362c d00e7d9 b9e362c d00e7d9 07b8186 b9e362c 07b8186 d00e7d9 07b8186 d00e7d9 c3d8da6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import gradio as gr
import requests
from PIL import Image
import io
from typing import Any, Tuple
import os
class Client:
    """Thin HTTP client that forwards visualization requests to the inference server."""

    def __init__(self, server_url: str):
        self.server_url = server_url

    def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
        """POST one visualization request and decode the server's response.

        Args:
            task_name: one of the task labels offered by the UI dropdown.
            model_name: HF model identifier selected in the UI.
            text: user input text to analyze.
            normalization_type: "global" or "token-wise".

        Returns:
            (image, log): a PIL image plus server log text on success;
            (None, error message) on any network or server failure.
        """
        try:
            response = requests.post(
                self.server_url,
                json={
                    "task_name": task_name,
                    "model_name": model_name,
                    "text": text,
                    "normalization_type": normalization_type,
                },
                timeout=60,
            )
        except requests.RequestException as exc:
            # BUG FIX: network-level failures (timeout, DNS, refused connection)
            # previously propagated and crashed the Gradio handler.
            return None, f"Error: request failed ({exc})"
        if response.status_code == 200:
            response_data = response.json()
            # Server ships the rendered plot as hex-encoded PNG bytes inside JSON.
            img_data = bytes.fromhex(response_data["image"])
            log_info = response_data["log"]
            img = Image.open(io.BytesIO(img_data))
            return img, log_info
        # BUG FIX: the original returned the string "Error, please retry" where
        # gr.Image(type="pil") expects a PIL image; None clears the component
        # instead of raising a rendering error.
        return None, "Error: Could not get response from server"
# Module-level client; fails fast with KeyError at startup if SERVER is unset.
client = Client(f"http://{os.environ['SERVER']}/predict")
# Thin wrapper so the Gradio callbacks below have a plain function to call.
# NOTE(review): despite the name, this forwards ALL task types, not only
# layer-wise non-linearity.
def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
    return client.send_request(task_name, model_name, text, normalization_type)
# UI layout: one row of three dropdowns (model / task / normalization),
# then a column with the text input, submit button, plot image, and a log box.
with gr.Blocks() as demo:
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[
                "facebook/opt-1.3b",
                "TheBloke/Llama-2-7B-fp16"
                # "facebook/opt-2.7b",
                # "microsoft/Phi-3-mini-128k-instruct"
            ],
            value="facebook/opt-1.3b",
            label="Select Model"
        )
        task_selector = gr.Dropdown(
            choices=[
                "Layer wise non-linearity",
                "Next-token prediction from intermediate representations",
                "Contextualization measurement",
                "Layerwise predictions (logit lens)",
                "Tokenwise loss without i-th layer"
            ],
            value="Layer wise non-linearity",
            label="Select Mode"
        )
        normalization_selector = gr.Dropdown(
            choices=["global", "token-wise"], #, "sentence-wise"],
            value="token-wise",
            label="Select Normalization"
        )
    with gr.Column():
        text_message = gr.Textbox(label="Enter your request:", value="I love to live my life")
        submit = gr.Button("Submit")
        # type="pil" means callbacks must return a PIL image (or None) here.
        box_for_plot = gr.Image(label="Visualization", type="pil")
        # Accumulated log; also fed back INTO the submit callback as input so
        # new entries can be appended to the existing text.
        log_output = gr.Textbox(label="Log Output", lines=10, interactive=False, value="")
def update_output(task_name: str, model_name: str, text: str, normalization_type: str, existing_log: str) -> Tuple[Any, str]:
    """Run one server request and append its log entry to the accumulated log."""
    image, log_entry = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
    # Separate entries with a "---" divider, same format as before.
    return image, f"{existing_log}---\n{log_entry}\n"
def set_default(task_name: str) -> str:
    """Return the default normalization mode for the selected task."""
    # Dispatch table instead of an if-chain; an unknown task yields None,
    # matching the original implicit fall-through.
    defaults = {
        "Layer wise non-linearity": "token-wise",
        "Next-token prediction from intermediate representations": "token-wise",
        "Contextualization measurement": "global",
        "Layerwise predictions (logit lens)": "global",
        "Tokenwise loss without i-th layer": "token-wise",
    }
    return defaults.get(task_name)
def check_normalization(task_name: str, normalization_name) -> Tuple[str, str]:
    """Validate the normalization choice against the selected task.

    Contextualization measurement operates on a single sentence, so
    token-wise normalization is rejected: fall back to "global" and
    return an alert message for the log box.
    """
    incompatible = (
        task_name == "Contextualization measurement"
        and normalization_name == "token-wise"
    )
    if not incompatible:
        # Choice is fine — echo it back with no alert text.
        return (normalization_name, "")
    return ("global", "\nALERT: Cannot apply token-wise normalization to one sentence, setting global normalization\n")
task_selector.select(set_default, [task_selector], [normalization_selector])
normalization_selector.select(check_normalization, [task_selector, normalization_selector], [normalization_selector, log_output])
submit.click(
fn=update_output,
inputs=[task_selector, model_selector, text_message, normalization_selector, log_output],
outputs=[box_for_plot, log_output]
)
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and request a public share link.
    demo.launch(share=True, server_port=7860, server_name="0.0.0.0")
|