Error when running Qwen/Qwen-14B-Chat with int4 quantization

#2
by Trenx - opened

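Problem: after int4 quantization, key.dtype is float16 but query.dtype is still float32, so a dtype error is raised when computing the dot product of query and key.
Suggestion: at the top of the _attn function (line 258), check whether query and key have the same dtype and, if they do not, cast them to a common dtype.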
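A minimal sketch of the suggested guard, assuming PyTorch. `attn_scores` is a hypothetical helper, not the actual `_attn` in `modeling_qwen.py`: it mirrors the `torch.matmul(query, key.transpose(-1, -2))` call from the traceback below and casts `query` to `key.dtype` first, because `torch.matmul` does not promote mismatched dtypes on its own.

```python
import math

import torch


def attn_scores(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    """Attention weights with a dtype guard (hypothetical sketch, not Qwen's code).

    query, key: (batch, heads, seq_len, head_dim)
    returns:    (batch, heads, seq_len, seq_len)
    """
    # torch.matmul requires both operands to share a dtype; a float32 query
    # and a float16 key fail with a dtype-mismatch RuntimeError.
    if query.dtype != key.dtype:
        # Follow the key's dtype so the product runs in the same dtype as the
        # quantized layers' output (float16 in the 4-bit case).
        query = query.to(key.dtype)
    scores = torch.matmul(query, key.transpose(-1, -2))
    # Standard 1/sqrt(head_dim) scaling, then softmax over the key axis.
    scores = scores / math.sqrt(query.size(-1))
    return torch.softmax(scores, dim=-1)


# Usage: mismatched dtypes no longer raise; the result takes the key's dtype.
q = torch.randn(1, 2, 4, 8, dtype=torch.float32)
k = torch.randn(1, 2, 4, 8, dtype=torch.float64)
w = attn_scores(q, k)
```

Casting to `key.dtype` keeps the matmul in the model's compute dtype; promoting both tensors to float32 instead would trade extra memory for precision. The usage lines pair float32 with float64 only so the snippet runs on CPU; on GPU the same guard applies to the float32/float16 case from this issue.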

+1. Loading in 4-bit gives the following error:

  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 1109, in forward
    transformer_outputs = self.transformer(
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 930, in forward
    outputs = block(
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 631, in forward
    attn_outputs = self.attn(
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 556, in forward
    attn_output, attn_weight = self._attn(
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 314, in _attn
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
RuntimeError: [address=127.0.0.1:36263, pid=21908] expected scalar type Half but found Float

How about using Qwen-14B-Int4 instead? It performs better than BNB (bitsandbytes) quantization. See the Quantization section of our GitHub README for more information.

jklj077 changed discussion status to closed
