gregH committed on
Commit
071d58e
1 Parent(s): 86cfe0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -32
app.py CHANGED
@@ -98,9 +98,10 @@ def user(message, history):
98
 
99
  def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
100
  #to determine whether the query is malicious
101
- if threshold==0:
102
- return True
103
  # first-stage rejection
 
 
104
  sft_embed=shift_direction_embedding[0]*0.0
105
  original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
106
  original_embedding=embedding_func(original_input_id.to(device)).cpu()
@@ -115,39 +116,42 @@ def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
115
  results[0]=sum(labels)/len(labels)
116
  if results[0]>0.5:
117
  return True
118
- # second-stage rejection
119
- shift_embeddings=[]
120
- for sft_embed in shift_direction_embedding[:perturb_times]:
121
- #original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
122
- #original_embedding=embedding_func(original_input_id.to(device)).cpu()
123
- shift_embeddings+=[0.02*sft_embed for _ in range(sample_times)]
124
 
125
- input_embeds=embedding_shift(
126
- original_embedding,shift_embeddings,prefix_embedding,suffix_embedding
127
- )
128
- with torch.no_grad():
129
- responses = engine(input_embeds)
130
- for idx in range(perturb_times):
131
- labels=get_labels(
132
- responses[idx*sample_times:(idx+1)*sample_times]
 
 
 
 
133
  )
134
- results[idx+1]=sum(labels)/len(labels)
135
- est_grad=[(results[j+1]-results[0])/0.02*shift_direction_embedding[j] for j in range(perturb_times)]
136
- est_grad=sum(est_grad)/len(est_grad)
137
- if est_grad.norm().item()>threshold:
138
- return True
 
 
 
 
 
 
139
  return False
140
 
141
- def chat(message, history, sample_times, perturb_times):
142
- if sample_times*perturb_times>0:
143
- if gradient_cuff_reject(message,sample_times,perturb_times,100):
144
- answer="[Gradient Cuff Rejection] I cannot fulfill your request".split(" ")
145
- partial_text = ""
146
- for new_text in answer:
147
- partial_text += (new_text+" ")
148
- # Yield an empty string to cleanup the message textbox and the updated conversation history
149
- yield partial_text
150
- return 0
151
  chat = []
152
  for item in history:
153
  chat.append({"role": "user", "content": item[0]})
@@ -181,7 +185,8 @@ def chat(message, history, sample_times, perturb_times):
181
  #demo = gr.ChatInterface(fn=chat, examples=["hello", "hola", "merhaba"], title="Gradient Cuff Vicuna-7B-V1.5")
182
  with gr.ChatInterface(fn=chat, title="Gradient Cuff Stablelm-2-zephyr-1_6b",additional_inputs=[
183
  gr.Slider(minimum=0, maximum=10, step=1, value=2, label="N - Sample times"),
184
- gr.Slider(minimum=0, maximum=10, step=1, value=2, label="P - Perturb times")
 
185
  ]
186
  ) as demo:
187
  with gr.Tab("benign"):
 
98
 
99
  def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
100
  #to determine whether the query is malicious
101
+
 
102
  # first-stage rejection
103
+ if sample_times==0:
104
+ return False
105
  sft_embed=shift_direction_embedding[0]*0.0
106
  original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
107
  original_embedding=embedding_func(original_input_id.to(device)).cpu()
 
116
  results[0]=sum(labels)/len(labels)
117
  if results[0]>0.5:
118
  return True
 
 
 
 
 
 
119
 
120
+ if perturb_times>0:
121
+ # second-stage rejection
122
+ if threshold==0:
123
+ return True
124
+ shift_embeddings=[]
125
+ for sft_embed in shift_direction_embedding[:perturb_times]:
126
+ #original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
127
+ #original_embedding=embedding_func(original_input_id.to(device)).cpu()
128
+ shift_embeddings+=[0.02*sft_embed for _ in range(sample_times)]
129
+
130
+ input_embeds=embedding_shift(
131
+ original_embedding,shift_embeddings,prefix_embedding,suffix_embedding
132
  )
133
+ with torch.no_grad():
134
+ responses = engine(input_embeds)
135
+ for idx in range(perturb_times):
136
+ labels=get_labels(
137
+ responses[idx*sample_times:(idx+1)*sample_times]
138
+ )
139
+ results[idx+1]=sum(labels)/len(labels)
140
+ est_grad=[(results[j+1]-results[0])/0.02*shift_direction_embedding[j] for j in range(perturb_times)]
141
+ est_grad=sum(est_grad)/len(est_grad)
142
+ if est_grad.norm().item()>threshold:
143
+ return True
144
  return False
145
 
146
+ def chat(message, history, sample_times, perturb_times,threshold):
147
+ if gradient_cuff_reject(message,sample_times,perturb_times,threshold):
148
+ answer="[Gradient Cuff Rejection] I cannot fulfill your request".split(" ")
149
+ partial_text = ""
150
+ for new_text in answer:
151
+ partial_text += (new_text+" ")
152
+ # Yield an empty string to cleanup the message textbox and the updated conversation history
153
+ yield partial_text
154
+ return 0
 
155
  chat = []
156
  for item in history:
157
  chat.append({"role": "user", "content": item[0]})
 
185
  #demo = gr.ChatInterface(fn=chat, examples=["hello", "hola", "merhaba"], title="Gradient Cuff Vicuna-7B-V1.5")
186
  with gr.ChatInterface(fn=chat, title="Gradient Cuff Stablelm-2-zephyr-1_6b",additional_inputs=[
187
  gr.Slider(minimum=0, maximum=10, step=1, value=2, label="N - Sample times"),
188
+ gr.Slider(minimum=0, maximum=10, step=1, value=2, label="P - Perturb times"),
189
+ gr.Slider(minimum=0, maximum=100, step=1, value=50, label="t - threshold")
190
  ]
191
  ) as demo:
192
  with gr.Tab("benign"):