import torch import torch.nn as nn from transformers import LlamaForCausalLM from transformers.models.llama.modeling_llama import LlamaAttention model = LlamaForCausalLM.from_pretrained( "meta-llama/Llama-2-70b-hf", use_cache=False, torch_dtype=torch.bfloat16, use_flash_attention_2=True, max_position_embeddings=8192, ) def replace_modules(module): has_bias = module.q_proj.bias is not None qkv_weight = torch.cat([module.q_proj.weight.data, module.k_proj.weight.data, module.v_proj.weight.data], dim=0) module.qkv_proj = nn.Linear(module.hidden_size, qkv_weight.shape[0], bias=has_bias) module.qkv_proj.weight.data = qkv_weight if has_bias: qkv_bias = torch.cat([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias], dim=0) module.qkv_proj.bias.data = qkv_bias del module.q_proj del module.k_proj del module.v_proj module.dim1 = module.num_heads * module.head_dim module.dim2 = module.num_key_value_heads * module.head_dim for name, module in model.named_modules(): if isinstance(module, LlamaAttention): replace_modules(module) model.config.save_pretrained("my_config") model.save_pretrained("llama2-70b")