Error after loading in 4-bit
I was able to load the model in 4-bit precision so that it fits in 35.6 GB of memory.
However, when I call model.generate(), it raises the error below. Does anyone have a clue? Thanks!
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:292, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
289 hidden_states = self.input_layernorm(hidden_states)
291 # Self Attention
--> 292 hidden_states, self_attn_weights, present_key_value = self.self_attn(
293 hidden_states=hidden_states,
294 attention_mask=attention_mask,
295 position_ids=position_ids,
296 past_key_value=past_key_value,
297 output_attentions=output_attentions,
298 use_cache=use_cache,
299 )
300 hidden_states = residual + hidden_states
302 # Fully Connected
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:195, in LlamaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
192 bsz, q_len, _ = hidden_states.size()
194 query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
--> 195 key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
196 value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198 kv_seq_len = key_states.shape[-2]
RuntimeError: shape '[1, 64, 64, 128]' is invalid for input of size 65536