Update app.py
app.py CHANGED
@@ -98,9 +98,10 @@ def user(message, history):
 
 def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
     #to determine whether the query is malicious
-
-    return True
+
     # first-stage rejection
+    if sample_times==0:
+        return False
     sft_embed=shift_direction_embedding[0]*0.0
     original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
     original_embedding=embedding_func(original_input_id.to(device)).cpu()
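The new guard makes N=0 disable the defense (return False) instead of the old stub that rejected every query unconditionally. For context, the first-stage check that follows is a refusal-rate vote: sample N responses to the unperturbed query and reject when more than half are refusals. A minimal sketch of that logic, where `generate` stands in for the app's sampling engine and `REFUSAL_MARKERS`/`is_refusal` are illustrative assumptions in place of the app's actual `get_labels`:

# Sketch of the first-stage check: sample N responses to the raw query
# and reject when more than half of them are refusals.
# `generate` is a caller-supplied text generator; the marker list and
# `is_refusal` are hypothetical stand-ins for the app's labeler.
REFUSAL_MARKERS = ("I cannot", "I can't", "I'm sorry", "As an AI")

def is_refusal(response: str) -> bool:
    # Hypothetical labeler: treat a response that opens with a refusal phrase as a refusal.
    return response.strip().startswith(REFUSAL_MARKERS)

def first_stage_reject(query: str, sample_times: int, generate) -> bool:
    if sample_times == 0:        # mirrors the new guard: N=0 turns the defense off
        return False
    labels = [is_refusal(generate(query)) for _ in range(sample_times)]
    refusal_rate = sum(labels) / len(labels)    # this is results[0] in app.py
    return refusal_rate > 0.5                   # first-stage rejection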
@@ -115,39 +116,42 @@ def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
     results[0]=sum(labels)/len(labels)
     if results[0]>0.5:
         return True
-    # second-stage rejection
-    shift_embeddings=[]
-    for sft_embed in shift_direction_embedding[:perturb_times]:
-        #original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
-        #original_embedding=embedding_func(original_input_id.to(device)).cpu()
-        shift_embeddings+=[0.02*sft_embed for _ in range(sample_times)]
 
+    if perturb_times>0:
+        # second-stage rejection
+        if threshold==0:
+            return True
+        shift_embeddings=[]
+        for sft_embed in shift_direction_embedding[:perturb_times]:
+            #original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
+            #original_embedding=embedding_func(original_input_id.to(device)).cpu()
+            shift_embeddings+=[0.02*sft_embed for _ in range(sample_times)]
+
+        input_embeds=embedding_shift(
+            original_embedding,shift_embeddings,prefix_embedding,suffix_embedding
     )
+        with torch.no_grad():
+            responses = engine(input_embeds)
+        for idx in range(perturb_times):
+            labels=get_labels(
+                responses[idx*sample_times:(idx+1)*sample_times]
+            )
+            results[idx+1]=sum(labels)/len(labels)
+        est_grad=[(results[j+1]-results[0])/0.02*shift_direction_embedding[j] for j in range(perturb_times)]
+        est_grad=sum(est_grad)/len(est_grad)
+        if est_grad.norm().item()>threshold:
+            return True
     return False
 
-def chat(message, history, sample_times, perturb_times):
-    if sample_times
-        return 0
+def chat(message, history, sample_times, perturb_times,threshold):
+    if gradient_cuff_reject(message,sample_times,perturb_times,threshold):
+        answer="[Gradient Cuff Rejection] I cannot fulfill your request".split(" ")
+        partial_text = ""
+        for new_text in answer:
+            partial_text += (new_text+" ")
+            # Yield an empty string to cleanup the message textbox and the updated conversation history
+            yield partial_text
+        return 0
     chat = []
     for item in history:
         chat.append({"role": "user", "content": item[0]})
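The second stage is a zeroth-order gradient estimate of the refusal rate f at the query's embedding x: for each of the P perturbation directions u_j it forms (f(x + 0.02*u_j) - f(x)) / 0.02 * u_j, averages the P estimates, and rejects when the norm of the result exceeds the threshold t (which is why the new `if threshold==0: return True` short-circuit rejects everything at t=0). Each f value is itself a refusal-rate average over N sampled responses, which is why `shift_embeddings` holds P*N shifted copies batched through one `engine` call. A self-contained sketch of the estimator, with a toy smooth f standing in for the sampled refusal rate (all names here are illustrative, not app.py's):

import torch

def estimate_refusal_gradient(f, x, directions, mu=0.02):
    # Zeroth-order (finite-difference) gradient estimate of f at x:
    #   g ~ mean_j [ (f(x + mu*u_j) - f(x)) / mu * u_j ]
    # f          : embedding -> refusal rate in [0, 1] (stand-in for the
    #              sampled-response average that app.py computes)
    # x          : query embedding, shape (d,)
    # directions : the P perturbation directions u_j, each shape (d,)
    # mu         : perturbation step, 0.02 in app.py
    f0 = f(x)                                            # results[0]
    est = [(f(x + mu * u) - f0) / mu * u for u in directions]
    return sum(est) / len(est)                           # est_grad

# Toy usage with a smooth surrogate f; app.py instead averages N sampled labels.
d = 8
x = torch.zeros(d)
dirs = [torch.randn(d) for _ in range(4)]                # P = 4
f = lambda v: torch.sigmoid(v.sum()).item()
reject = estimate_refusal_gradient(f, x, dirs).norm().item() > 0.5   # compare against t

The rewritten `chat` wrapper then streams the fixed rejection message word by word via `yield`, so a rejected query renders incrementally in the UI just like a normal generated reply.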
@@ -181,7 +185,8 @@ def chat(message, history, sample_times, perturb_times):
 #demo = gr.ChatInterface(fn=chat, examples=["hello", "hola", "merhaba"], title="Gradient Cuff Vicuna-7B-V1.5")
 with gr.ChatInterface(fn=chat, title="Gradient Cuff Stablelm-2-zephyr-1_6b",additional_inputs=[
     gr.Slider(minimum=0, maximum=10, step=1, value=2, label="N - Sample times"),
-    gr.Slider(minimum=0, maximum=10, step=1, value=2, label="P - Perturb times")
+    gr.Slider(minimum=0, maximum=10, step=1, value=2, label="P - Perturb times"),
+    gr.Slider(minimum=0, maximum=100, step=1, value=50, label="t - threshold")
     ]
 ) as demo:
     with gr.Tab("benign"):
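Each component in `additional_inputs` supplies one extra positional argument to `fn` after `message` and `history`, which is why adding the threshold slider requires the matching `threshold` parameter on `chat`. A minimal runnable illustration of that wiring (the echo-style `chat` body is a placeholder, not the app's logic):

import gradio as gr

def chat(message, history, sample_times, perturb_times, threshold):
    # Slider values arrive positionally, in the order the components are listed.
    yield f"(N={sample_times}, P={perturb_times}, t={threshold}) you said: {message}"

demo = gr.ChatInterface(
    fn=chat,
    title="additional_inputs wiring demo",
    additional_inputs=[
        gr.Slider(minimum=0, maximum=10, step=1, value=2, label="N - Sample times"),
        gr.Slider(minimum=0, maximum=10, step=1, value=2, label="P - Perturb times"),
        gr.Slider(minimum=0, maximum=100, step=1, value=50, label="t - threshold"),
    ],
)

if __name__ == "__main__":
    demo.launch()

Keeping the slider labels aligned with the parameter order avoids silently feeding N into P or t when the list is edited.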