ShiJueXiaofei committed
Commit 2cdf703 • Parent(s): 8fd7fba
Fix garbled inference output when use_cache = False
When the original model is loaded with use_cache = False, the truncation of input_ids for next-token prediction only checks is_first_forward: it still truncates and writes only the newest token into input_ids. Since no past_key_values are available in that case, the model produces garbled output at inference time.
The truncation to the most recently predicted token should only happen when is_first_forward == False and self.config.use_cache == True; otherwise the full original text sequence together with the tokens predicted so far must be passed to the model.
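As a minimal sketch of how the problem shows up (the checkpoint name and prompt are illustrative; any ChatGLM checkpoint that ships this modeling_chatglm.py is affected), generating with the KV cache disabled previously produced garbled text:

from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint; substitute the repo that actually contains this modeling_chatglm.py.
repo = "THUDM/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).half().cuda()

# Disable the KV cache, as described in this commit.
model.config.use_cache = False

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Before this fix: from the second decoding step on, only the newest token was
# passed as input_ids while past_key_values stayed empty, so the output was garbled.
outputs = model.generate(**inputs, max_new_tokens=32, use_cache=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

With the patch, whenever the cache is off the full prompt plus all previously generated tokens are fed on every step, so decoding stays correct at the cost of recomputing earlier positions.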
- modeling_chatglm.py +3 -2
modeling_chatglm.py
CHANGED
@@ -904,8 +904,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         if position_ids is None:
             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
         if not is_first_forward:
-            position_ids = position_ids[..., -1:]
-            input_ids = input_ids[:, -1:]
+            if self.config.use_cache:
+                position_ids = position_ids[..., -1:]
+                input_ids = input_ids[:, -1:]
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
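For context, a minimal sketch of prepare_inputs_for_generation with this patch applied; the signature and the remaining returned fields follow the surrounding file and are not changed by this commit:

def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        is_first_forward=True,
        **kwargs
):
    if position_ids is None:
        position_ids = self.get_position_ids(input_ids, device=input_ids.device)
    if not is_first_forward:
        # Only feed the single newest token (and its position) when a KV cache from
        # previous steps will actually be reused; with use_cache = False the full
        # sequence of prompt + generated tokens must be passed on every step.
        if self.config.use_cache:
            position_ids = position_ids[..., -1:]
            input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
        # ...the other generation kwargs (attention_mask, use_cache, ...) are
        # returned unchanged, exactly as in the rest of modeling_chatglm.py.
    }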