gregH committed on
Commit
e8653c8
1 Parent(s): 89a6bc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -61,7 +61,7 @@ suffix_embedding=embedding_func(
61
  )
62
  #print(prefix_embedding)
63
  print(f"Sucessfully loaded the model to the memory")
64
- shift_direction_embedding=torch.randn(10,prefix_embedding.shape[-1])
65
  shift_direction_embedding=[0.0*shift_direction_embedding[0]]+[item for item in shift_direction_embedding]
66
  start_message = ""
67
 
@@ -89,7 +89,7 @@ def engine(input_embeds):
89
  top_p = 0.9,
90
  pad_token_id=tok.pad_token_id
91
  )
92
- output_text += tokenizer.batch_decode(outputs, skip_special_tokens=True)
93
  return output_text
94
 
95
  def user(message, history):
@@ -102,14 +102,15 @@ def gradient_cuff_reject(message):
102
  for sft_embed in shift_direction_embedding:
103
  original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
104
  original_embedding=embedding_func(original_input_id.to(device)).cpu()
105
- shift_embeddings=[0.02*sft_embed for _ in range(10)]
106
  input_embeds=embedding_shift(
107
  original_embedding,shift_embeddings,prefix_embedding,suffix_embedding
108
  )
109
  with torch.no_grad():
110
  responses = engine(input_embeds)
 
111
  results.append(
112
- sum(get_labels(responses))/10
113
  )
114
  if len(results)==1:
115
  if results[0]>0.5:
 
61
  )
62
  #print(prefix_embedding)
63
  print(f"Sucessfully loaded the model to the memory")
64
+ shift_direction_embedding=torch.randn(2,prefix_embedding.shape[-1])
65
  shift_direction_embedding=[0.0*shift_direction_embedding[0]]+[item for item in shift_direction_embedding]
66
  start_message = ""
67
 
 
89
  top_p = 0.9,
90
  pad_token_id=tok.pad_token_id
91
  )
92
+ output_text += tok.batch_decode(outputs, skip_special_tokens=True)
93
  return output_text
94
 
95
  def user(message, history):
 
102
  for sft_embed in shift_direction_embedding:
103
  original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
104
  original_embedding=embedding_func(original_input_id.to(device)).cpu()
105
+ shift_embeddings=[0.02*sft_embed for _ in range(5)]
106
  input_embeds=embedding_shift(
107
  original_embedding,shift_embeddings,prefix_embedding,suffix_embedding
108
  )
109
  with torch.no_grad():
110
  responses = engine(input_embeds)
111
+ labels=get_labels(responses)
112
  results.append(
113
+ sum(labels)/len(labels)
114
  )
115
  if len(results)==1:
116
  if results[0]>0.5: