gregH committed on
Commit
0a218bc
1 Parent(s): 354d51b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -48,7 +48,6 @@ def set_seed(_hashed_seed):
48
 
49
  set_seed(13)
50
 
51
- original_response=""
52
  print(f"Starting to load the model to memory")
53
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
 
@@ -155,12 +154,12 @@ def gradient_cuff_reject(message,with_defense, sample_times,perturb_times,thresh
155
  responses = engine(input_embeds.to(torch.bfloat16))
156
  original_response=responses[0]
157
  if not with_defense:
158
- return (False,None,None)
159
  labels=get_labels(responses)
160
  results=[None for _ in range(perturb_times+1)]
161
  results[0]=sum(labels)/len(labels)
162
  if results[0]>=0.5:
163
- return (True,1-results[0],None)
164
 
165
  # second-stage rejection
166
  shift_embeddings=[]
@@ -182,9 +181,9 @@ def gradient_cuff_reject(message,with_defense, sample_times,perturb_times,thresh
182
  est_grad=[(results[j+1]-results[0])/0.02*shift_direction_embedding[j] for j in range(perturb_times)]
183
  est_grad=sum(est_grad)/len(est_grad)
184
  if est_grad.norm().item()>threshold:
185
- return (True,1-results[0],est_grad.norm().item())
186
 
187
- return (False,1-results[0],est_grad.norm().item())
188
 
189
  def chat(message, history, with_defense,threshold):
190
  perturb_times=9
@@ -212,7 +211,7 @@ def chat(message, history, with_defense,threshold):
212
  input_ids = tok([messages], return_tensors="pt")["input_ids"]
213
  #response= "[Gradient Cuff Checking: "+reject_information + "]\n"+ chat_engine(input_ids)
214
  #response=chat_engine(input_ids)
215
- response=original_response
216
  response=response.split(" ")
217
 
218
  # Initialize an empty string to store the generated text
 
48
 
49
  set_seed(13)
50
 
 
51
  print(f"Starting to load the model to memory")
52
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53
 
 
154
  responses = engine(input_embeds.to(torch.bfloat16))
155
  original_response=responses[0]
156
  if not with_defense:
157
+ return (False,None,None,original_response)
158
  labels=get_labels(responses)
159
  results=[None for _ in range(perturb_times+1)]
160
  results[0]=sum(labels)/len(labels)
161
  if results[0]>=0.5:
162
+ return (True,1-results[0],None,original_response)
163
 
164
  # second-stage rejection
165
  shift_embeddings=[]
 
181
  est_grad=[(results[j+1]-results[0])/0.02*shift_direction_embedding[j] for j in range(perturb_times)]
182
  est_grad=sum(est_grad)/len(est_grad)
183
  if est_grad.norm().item()>threshold:
184
+ return (True,1-results[0],est_grad.norm().item(),original_response)
185
 
186
+ return (False,1-results[0],est_grad.norm().item(),original_response)
187
 
188
  def chat(message, history, with_defense,threshold):
189
  perturb_times=9
 
211
  input_ids = tok([messages], return_tensors="pt")["input_ids"]
212
  #response= "[Gradient Cuff Checking: "+reject_information + "]\n"+ chat_engine(input_ids)
213
  #response=chat_engine(input_ids)
214
+ response=return_value[-1]
215
  response=response.split(" ")
216
 
217
  # Initialize an empty string to store the generated text