BAAI
/

shunxing1234 commited on
Commit
1357761
1 Parent(s): 31c9df4

Update cyg_conversation.py

Browse files
Files changed (1) hide show
  1. cyg_conversation.py +24 -0
cyg_conversation.py CHANGED
@@ -126,6 +126,30 @@ conv_templates = {
126
  "bair_v1": conv_bair_v1,
127
  }
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  if __name__ == "__main__":
131
  print(default_conversation.get_prompt())
 
126
  "bair_v1": conv_bair_v1,
127
  }
128
 
129
+ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token):
130
+ conv = default_conversation.copy()
131
+
132
+ conv.append_message(conv.roles[1], None)
133
+ conv.append_message(conv.roles[0], text)
134
+
135
+ example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
136
+
137
+ while(len(history) > 0 and (len(example) < max_token)):
138
+ tmp = history.pop()
139
+ if tmp[0] == 'ASSISTANT':
140
+ conv.append_message(conv.roles[1], tmp[1])
141
+ else:
142
+ conv.append_message(conv.roles[0], tmp[1])
143
+ example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
144
+
145
+ if len(example) >= max_token:
146
+ conv.messages.pop()
147
+ conv.messages = conv.messages[::-1]
148
+ print('model in:', conv.get_prompt())
149
+ example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
150
+ example = example[1:-1]
151
+
152
+ return example
153
 
154
  if __name__ == "__main__":
155
  print(default_conversation.get_prompt())