File size: 4,392 Bytes
c3d8da6
 
 
 
d00e7d9
c3d8da6
 
 
 
 
 
d00e7d9
07b8186
 
 
 
 
 
 
 
 
 
c3d8da6
d00e7d9
 
 
c3d8da6
d00e7d9
c3d8da6
d00e7d9
c3d8da6
 
 
d00e7d9
 
c3d8da6
 
d00e7d9
 
 
 
b9e362c
 
07b8186
d00e7d9
 
 
 
 
 
5dfe7bb
d00e7d9
b9e362c
 
d00e7d9
 
5dfe7bb
d00e7d9
 
 
 
 
 
 
c3d8da6
d00e7d9
c3d8da6
d00e7d9
 
c3d8da6
d00e7d9
 
 
 
 
 
5dfe7bb
d00e7d9
 
 
b9e362c
d00e7d9
b9e362c
d00e7d9
 
 
 
07b8186
b9e362c
07b8186
 
 
d00e7d9
07b8186
d00e7d9
 
 
 
 
c3d8da6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import requests
from PIL import Image
import io
from typing import Any, Tuple
import os

class Client:
    """Thin HTTP client for the visualization server's prediction endpoint."""

    def __init__(self, server_url: str, timeout: float = 60):
        """Store connection settings.

        Args:
            server_url: Full URL that prediction requests are POSTed to.
            timeout: Seconds to wait for the server before giving up
                (defaults to the previously hard-coded 60).
        """
        self.server_url = server_url
        self.timeout = timeout

    def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
        """POST one visualization request and decode the server's reply.

        Args:
            task_name: Analysis mode to run on the server.
            model_name: Model identifier forwarded to the server.
            text: Input text to analyse.
            normalization_type: Normalization mode forwarded to the server.

        Returns:
            ``(image, log)``. On success, ``image`` is a PIL image decoded from
            the hex-encoded ``"image"`` field of the JSON reply and ``log`` is
            its ``"log"`` field. On any failure (network error, non-200 status,
            malformed payload) ``image`` is ``None`` so the image output is
            cleared instead of being handed a string it cannot render, and
            ``log`` describes the error.
        """
        payload = {
            "task_name": task_name,
            "model_name": model_name,
            "text": text,
            "normalization_type": normalization_type,
        }
        try:
            response = requests.post(self.server_url, json=payload, timeout=self.timeout)
        except requests.RequestException as err:
            # Connection refused, DNS failure, timeout, ... must not crash the UI handler.
            return None, f"Error: Could not get response from server ({err})"

        if response.status_code != 200:
            return None, f"Error: Could not get response from server (HTTP {response.status_code})"

        try:
            response_data = response.json()
            img_data = bytes.fromhex(response_data["image"])
            log_info = response_data["log"]
            img = Image.open(io.BytesIO(img_data))
            # Image.open is lazy; force decoding here so a corrupt payload is
            # reported in the log instead of failing later inside Gradio.
            img.load()
        except (ValueError, KeyError, TypeError, OSError) as err:
            return None, f"Error: Malformed response from server ({err})"
        return img, log_info

# Single module-wide client used by every request handler. SERVER presumably
# holds the backend's host[:port]; if it is unset, os.environ raises KeyError
# at import time.
client = Client("http://{}/predict".format(os.environ["SERVER"]))

def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
    """Forward one request to the module-level ``client``.

    Despite its name, this is used for every analysis mode; ``task_name``
    tells the server which one to run. The ``(image, log)`` pair from
    ``client.send_request`` is returned unchanged.
    """
    image_and_log = client.send_request(
        task_name=task_name,
        model_name=model_name,
        text=text,
        normalization_type=normalization_type,
    )
    return image_and_log

# Default normalization for each analysis mode. The insertion order is also the
# order of the "Select Mode" dropdown, so task names are defined in one place.
_TASK_DEFAULT_NORMALIZATION = {
    "Layer wise non-linearity": "token-wise",
    "Next-token prediction from intermediate representations": "token-wise",
    "Contextualization measurement": "global",
    "Layerwise predictions (logit lens)": "global",
    "Tokenwise loss without i-th layer": "token-wise",
}


def update_output(task_name: str, model_name: str, text: str, normalization_type: str, existing_log: str) -> Tuple[Any, str]:
    """Run one request and append its log to the accumulated log text.

    Returns:
        ``(image, combined_log)`` for the image and log output components.
    """
    img, new_log = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
    combined_log = (existing_log or "") + "---\n" + new_log + "\n"
    return img, combined_log


def set_default(task_name: str) -> str:
    """Return the default normalization for ``task_name``.

    Unknown task names fall back to ``"token-wise"``, the normalization
    dropdown's initial value, instead of returning ``None``.
    """
    return _TASK_DEFAULT_NORMALIZATION.get(task_name, "token-wise")


def check_normalization(task_name: str, normalization_name: str, existing_log: str = "") -> Tuple[str, str]:
    """Reject token-wise normalization for contextualization measurement.

    Args:
        task_name: Currently selected analysis mode.
        normalization_name: Normalization the user just selected.
        existing_log: Current log text. It is preserved, and any alert is
            appended to it rather than replacing it.

    Returns:
        ``(normalization, log)``: the (possibly corrected) normalization and
        the updated log text.
    """
    log = existing_log or ""
    if task_name == "Contextualization measurement" and normalization_name == "token-wise":
        alert = "\nALERT: Cannot apply token-wise normalization to one sentence, setting global normalization\n"
        return "global", log + alert
    return normalization_name, log


with gr.Blocks() as demo:
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[
                "facebook/opt-1.3b",
                # Trailing comma: without it, uncommenting the entries below
                # would silently concatenate adjacent string literals.
                "TheBloke/Llama-2-7B-fp16",
                # "facebook/opt-2.7b",
                # "microsoft/Phi-3-mini-128k-instruct",
            ],
            value="facebook/opt-1.3b",
            label="Select Model"
        )
        task_selector = gr.Dropdown(
            choices=list(_TASK_DEFAULT_NORMALIZATION),
            value="Layer wise non-linearity",
            label="Select Mode"
        )
        normalization_selector = gr.Dropdown(
            choices=["global", "token-wise"],  # , "sentence-wise"],
            value="token-wise",
            label="Select Normalization"
        )
    with gr.Column():
        text_message = gr.Textbox(label="Enter your request:", value="I love to live my life")
        submit = gr.Button("Submit")
        box_for_plot = gr.Image(label="Visualization", type="pil")
        log_output = gr.Textbox(label="Log Output", lines=10, interactive=False, value="")

        task_selector.select(set_default, [task_selector], [normalization_selector])
        # log_output is an input as well as an output so check_normalization
        # can append to the log instead of overwriting it on every selection.
        normalization_selector.select(
            check_normalization,
            [task_selector, normalization_selector, log_output],
            [normalization_selector, log_output],
        )
        submit.click(
            fn=update_output,
            inputs=[task_selector, model_selector, text_message, normalization_selector, log_output],
            outputs=[box_for_plot, log_output]
        )

if __name__ == "__main__":
    # Bind to all network interfaces on port 7860 and also request a public
    # Gradio share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)