Error running Qwen/Qwen-14B-Chat with int4 quantization
#2
by Trenx - opened
Problem: after int4 quantization, key.dtype is float16, but query.dtype is still float32, so the query-key matmul raises a dtype error.
Suggestion: at the top of the _attn function (line 258), add a check for whether query and key have the same dtype, and unify the two types if they differ. A minimal sketch of that fix follows.
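A minimal sketch of the suggested guard, assuming the variable names shown in the traceback below (the helper name is hypothetical; the exact _attn signature in modeling_qwen.py may differ across versions):

```python
import torch

def unify_attn_dtypes(query, key, value):
    """Hypothetical helper mirroring the suggested fix: cast key and value
    to query's dtype so torch.matmul sees matching scalar types. value is
    cast too because _attn later multiplies it by the attention weights."""
    if query.dtype != key.dtype:
        key = key.to(query.dtype)
        value = value.to(query.dtype)
    return query, key, value

# Demo: reproduce the mismatch from the report, then resolve it.
q = torch.randn(1, 2, 4, 8, dtype=torch.float32)  # query stays float32
k = torch.randn(1, 2, 4, 8, dtype=torch.float16)  # key came out as float16
v = torch.randn(1, 2, 4, 8, dtype=torch.float16)
q, k, v = unify_attn_dtypes(q, k, v)
attn_weights = torch.matmul(q, k.transpose(-1, -2))  # no dtype error now
```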
+1, loading in 4-bit fails with the following error:
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 1109, in forward
transformer_outputs = self.transformer(
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 930, in forward
outputs = block(
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 631, in forward
attn_outputs = self.attn(
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 556, in forward
attn_output, attn_weight = self._attn(
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 314, in _attn
attn_weights = torch.matmul(query, key.transpose(-1, -2))
RuntimeError: [address=127.0.0.1:36263, pid=21908] expected scalar type Half but found Float
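For context, a sketch of the kind of load that triggers this (assumed; the poster's loading code is not shown, and the Hub repo id is assumed to be Qwen/Qwen-14B-Chat):

```python
from transformers import AutoModelForCausalLM

# Assumed reproduction: on-the-fly bitsandbytes 4-bit quantization.
# trust_remote_code pulls in the custom modeling_qwen.py whose _attn
# raises the dtype error in the traceback above.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat",
    load_in_4bit=True,
    device_map="auto",
    trust_remote_code=True,
)
```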
How about using Qwen-14B-Int4? It performs better than BNB. Check the quantization section in our GitHub README for more information.
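For reference, a minimal sketch of loading the prequantized model instead (the chat variant's repo id is assumed to be Qwen/Qwen-14B-Chat-Int4, which requires the auto-gptq and optimum packages per the Qwen README):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the prequantized GPTQ checkpoint rather than quantizing with BNB
# at load time; no dtype-mismatch workaround is needed on this path.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-14B-Chat-Int4", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat-Int4",
    device_map="auto",
    trust_remote_code=True,
).eval()

# model.chat is the helper defined in Qwen's remote modeling code.
response, history = model.chat(tokenizer, "Hello!", history=None)
```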
jklj077 changed discussion status to closed