Ahren09 commited on
Commit
a2294bf
1 Parent(s): 6e747af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +261 -0
app.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import threading
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+
10
+ from llava.mm_utils import get_model_name_from_path
11
+ from llava.model.builder import load_pretrained_model
12
+ from llava_utils import prompt_wrapper, generator
13
+ from utils import normalize, denormalize, load_image
14
+
15
+ os.environ['DISABLE_NEST_ASYNCIO'] = 'True'
16
+
17
+ # if not asyncio.get_event_loop().is_running():
18
+ # asyncio.set_event_loop(asyncio.new_event_loop())
19
+
20
+ # asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
21
+
22
+ BASE_DIR = "/workingdir/soh337/llavaguard"
23
+ # BASE_DIR = "/Users/ahren/Workspace/Multimodal/llavaguard"
24
+ UNCONSTRAINED_ATTACK_IMAGE_PATH = 'unconstrained_attack_images/adversarial_'
25
+ CONSTRAINED_ATTACK_IMAGE_PATH = 'adversarial_qna_images/adv_image_'
26
+ MODEL_PATH = "liuhaotian/llava-v1.5-13b"
27
+
28
+ TEXT_SAFETY_PATCHES = {
29
+ "optimized": "text_patch_optimized",
30
+ "heuristic": "text_patch_heuristic"
31
+ }
32
+
33
+ IMAGE_SAFETY_PATCHES = {
34
+ "default": "safety_patch.pt"
35
+ }
36
+
37
+
38
+ def rtp_read(text_file):
39
+ dataset = []
40
+ lines = open(text_file).read().split("\n")
41
+ for li in lines:
42
+ obj = json.loads(li)
43
+ if obj['challenging']:
44
+ dataset.append(obj['prompt']['text'])
45
+ return dataset
46
+
47
+
48
+ model = loaded_model_name = tokenizer = image_processor = context_len = my_generator = None
49
+
50
+ def load_model_async(model_path, model_name):
51
+ global tokenizer, model, image_processor, context_len, loaded_model_name, my_generator
52
+ print(f"Loading {model_name} model ... ")
53
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
54
+ if "llava" in model_name.lower():
55
+ loaded_model_name = "LLaVA"
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ my_generator = generator.Generator(model=model, tokenizer=tokenizer)
60
+
61
+ threading.Thread(target=load_model_async, args=(MODEL_PATH, get_model_name_from_path(MODEL_PATH))).start()
62
+
63
+ print('>>> Initializing Models')
64
+
65
+ prompts = rtp_read(osp.join('harmful_corpus/rtp_prompts.jsonl'))
66
+
67
+ # out_unprotected: responses without the safety patch
68
+ out, out_unprotected = [], []
69
+ text_prompt = 'Please extend the following sentence: %s'
70
+
71
+
72
+ def generate_answer(image, user_message: str, requested_model_name: str,
73
+ image_safety_patch_type: str, text_safety_patch_type: str,
74
+ ):
75
+ global tokenizer, model, image_processor, context_len, loaded_model_name, my_generator
76
+
77
+ with open(TEXT_SAFETY_PATCHES[text_safety_patch_type], 'r') as file:
78
+ text_safety_patch = file.read().rstrip()
79
+
80
+ image_safety_patch = IMAGE_SAFETY_PATCHES[image_safety_patch_type]
81
+ if requested_model_name == "LLaVA":
82
+
83
+ if requested_model_name == loaded_model_name:
84
+
85
+ print(f"{requested_model_name} model already loaded.")
86
+
87
+ else:
88
+ print(f"Loading {requested_model_name} model ... ")
89
+
90
+ threading.Thread(target=load_model_async, args=(MODEL_PATH, get_model_name_from_path(MODEL_PATH))).start()
91
+ my_generator = generator.Generator(model=model, tokenizer=tokenizer)
92
+
93
+ # load a randomly-sampled unconstrained attack image as Image object
94
+ if isinstance(image, str):
95
+ image = load_image(image)
96
+
97
+ # transform the image using the visual encoder (CLIP) of LLaVA 1.5; the processed image size would be PyTorch tensor whose shape is (336,336).
98
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()
99
+
100
+ if image_safety_patch != None:
101
+ # make the image pixel values between (0,1)
102
+ image = normalize(image)
103
+ # load the safety patch tensor whose values are (0,1)
104
+ safety_patch = torch.load(image_safety_patch).cuda()
105
+ # apply the safety patch to the input image, clamp it between (0,1) and denormalize it to the original pixel values
106
+ safe_image = denormalize((image + safety_patch).clamp(0, 1))
107
+ # make sure the image value is between (0,1)
108
+ print(torch.min(image), torch.max(image), torch.min(safe_image), torch.max(safe_image))
109
+
110
+ else:
111
+ safe_image = image
112
+
113
+ model.eval()
114
+
115
+ user_message_unprotected = user_message
116
+ if text_safety_patch != None:
117
+ if text_safety_patch_type == "optimal":
118
+ # use the below for optimal text safety patch
119
+ user_message = text_safety_patch + '\n' + user_message
120
+
121
+ elif text_safety_patch_type == "heuristic":
122
+ # use the below for heuristic text safety patch
123
+ user_message += '\n' + text_safety_patch
124
+ else:
125
+ raise ValueError(f"Invalid safety patch type: {user_message}")
126
+
127
+ text_prompt_template_unprotected = prompt_wrapper.prepare_text_prompt(text_prompt % user_message_unprotected)
128
+ prompt_unprotected = prompt_wrapper.Prompt(model, tokenizer, text_prompts=text_prompt_template_unprotected,
129
+ device=model.device)
130
+
131
+ text_prompt_template = prompt_wrapper.prepare_text_prompt(text_prompt % user_message)
132
+ prompt = prompt_wrapper.Prompt(model, tokenizer, text_prompts=text_prompt_template, device=model.device)
133
+
134
+ response_unprotected = my_generator.generate(prompt_unprotected, image).replace("[INST]", "").replace("[/INST]",
135
+ "").replace(
136
+ "[SYS]", "").replace("[/SYS/]", "").strip()
137
+
138
+ response = my_generator.generate(prompt, safe_image).replace("[INST]", "").replace("[/INST]", "").replace(
139
+ "[SYS]", "").replace("[/SYS/]", "").strip()
140
+
141
+ if text_safety_patch != None:
142
+ response = response.replace(text_safety_patch, "")
143
+
144
+ response_unprotected = response_unprotected.replace(text_safety_patch, "")
145
+
146
+ print(" -- [Unprotected] continuation: ---")
147
+ print(response_unprotected)
148
+ print(" -- [Protected] continuation: ---")
149
+ print(response)
150
+
151
+ out.append({'prompt': user_message, 'continuation': response})
152
+ out_unprotected.append({'prompt': user_message, 'continuation': response_unprotected})
153
+
154
+ return response, response_unprotected
155
+
156
+
157
+ def get_list_of_examples():
158
+ global rtp
159
+ examples = []
160
+
161
+ # Use the first 3 prompts for constrained attack
162
+ for i, prompt in enumerate(prompts[:3]):
163
+ image_num = np.random.randint(25) # Randomly select an image number
164
+ image_path = f'{CONSTRAINED_ATTACK_IMAGE_PATH}{image_num}.bmp'
165
+
166
+ examples.append(
167
+ [image_path, prompt]
168
+ )
169
+
170
+ # Use the 3-6th prompts for unconstrained attack
171
+ for i, prompt in enumerate(prompts[3:6]):
172
+ image_num = np.random.randint(25) # Randomly select an image number
173
+ image_path = f'{UNCONSTRAINED_ATTACK_IMAGE_PATH}{image_num}.bmp'
174
+
175
+ examples.append(
176
+ [image_path, prompt]
177
+ )
178
+
179
+ return examples
180
+
181
+
182
+ css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;}
183
+ #header {text-align: center;}
184
+ #col-chatbox {flex: 1; max-height: min(750px, 100%);}
185
+ #label {font-size: 2em; padding: 0.5em; margin: 0;}
186
+ .message {font-size: 1.2em;}
187
+ .message-wrap {max-height: min(700px, 100vh);}
188
+ """
189
+
190
+
191
+ def get_empty_state():
192
+ # TODO: Not sure what this means
193
+ return gr.State({"arena": None})
194
+
195
+
196
+ examples = get_list_of_examples()
197
+
198
+
199
+ # Define a function to update inputs based on selected example
200
+ def update_inputs(example_id):
201
+ selected_example = examples[int(example_id)]
202
+ return selected_example['image_path'], selected_example['text']
203
+
204
+
205
+ model_selector, image_patch_selector, text_patch_selector = None, None, None
206
+
207
+
208
+ def process_text_and_image(image_path: str, user_message: str):
209
+ global model_selector, image_patch_selector, text_patch_selector
210
+ print(f"User Message: {user_message}")
211
+ # print(f"Text Safety Patch: {safety_patch}")
212
+ print(f"Image Path: {image_path}")
213
+ print(model_selector.value)
214
+
215
+ # generate_answer(user_message, image_path, "LLaVA", "heuristic", "default")
216
+ response, response_unprotected = generate_answer(image_path, user_message, model_selector.value, image_patch_selector.value,
217
+ text_patch_selector.value)
218
+
219
+ return response, response_unprotected
220
+
221
+
222
+ with gr.Blocks(css=css) as demo:
223
+ state = get_empty_state()
224
+ all_components = []
225
+
226
+ with gr.Column(elem_id="col-container"):
227
+ gr.Markdown(
228
+ """# 🦙LLaVAGuard🔥<br>
229
+ Safeguarding your Multimodal LLM
230
+ **[Project Homepage](#)**""",
231
+ elem_id="header",
232
+ )
233
+
234
+ # example_selector = gr.Dropdown(choices=[f"Example {i}" for i, e in enumerate(examples)],
235
+ # label="Select an Example")
236
+
237
+ with gr.Row():
238
+ model_selector = gr.Dropdown(choices=["LLaVA"], label="Model", info="Select Model", value="LLaVA")
239
+ image_patch_selector = gr.Dropdown(choices=["default"], label="Image Patch", info="Select Image Safety "
240
+ "Patch", value="default")
241
+ text_patch_selector = gr.Dropdown(choices=["heuristic", "optimized"], label="Text Patch", info="Select "
242
+ "Text "
243
+ "Safety "
244
+ "Patch",
245
+ value="heuristic")
246
+
247
+ image_and_text_uploader = gr.Interface(
248
+ fn=process_text_and_image,
249
+ inputs=[gr.Image(type="pil", label="Upload your image", interactive=True),
250
+
251
+ gr.Textbox(placeholder="Input a question", label="Your Question"),
252
+ ],
253
+ examples=examples,
254
+ outputs=[
255
+ gr.Textbox(label="With Safety Patches"),
256
+ gr.Textbox(label="NO Safety Patches")
257
+ ])
258
+
259
+
260
+ # Launch the demo
261
+ demo.launch()