GMFTBY commited on
Commit
8d678a1
1 Parent(s): 9522941

Update model/openllama.py

Browse files
Files changed (1) hide show
  1. model/openllama.py +1 -8
model/openllama.py CHANGED
@@ -1,7 +1,6 @@
1
  from header import *
2
  import os
3
  import torch.nn.functional as F
4
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
5
  from .ImageBind import *
6
  from .ImageBind import data
7
  from .modeling_llama import LlamaForCausalLM
@@ -103,13 +102,7 @@ class OpenLLAMAPEFTModel(nn.Module):
103
  target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
104
  )
105
 
106
- with init_empty_weights():
107
- config = LlamaConfig.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
108
- self.llama_model = LlamaForCausalLM(config)
109
-
110
- self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map='sequential')
111
-
112
- # self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
113
  self.llama_model = get_peft_model(self.llama_model, peft_config)
114
  self.llama_model.print_trainable_parameters()
115
 
 
1
  from header import *
2
  import os
3
  import torch.nn.functional as F
 
4
  from .ImageBind import *
5
  from .ImageBind import data
6
  from .modeling_llama import LlamaForCausalLM
 
102
  target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
103
  )
104
 
105
+ self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
 
 
 
 
 
 
106
  self.llama_model = get_peft_model(self.llama_model, peft_config)
107
  self.llama_model.print_trainable_parameters()
108