Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
 4   import os, torch
 5   from datasets import load_dataset
 6   from huggingface_hub import HfApi, login
 7  -
 8   from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
 9
10   ACTION_1 = "Prompt base model"
@@ -148,6 +148,22 @@ def load_model(model_name):
148   #print("### PEFT")
149   #peft_model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
150   #print("###")
151
152   return model, tokenizer
|
|
|
 4   import os, torch
 5   from datasets import load_dataset
 6   from huggingface_hub import HfApi, login
 7  +from peft import LoraConfig, PeftModel
 8   from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
 9
10   ACTION_1 = "Prompt base model"
|
|
148   #print("### PEFT")
149   #peft_model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
150   #print("###")
151  +
152  +###
153  +print("111")
154  +peft_config = LoraConfig(
155  +    lora_alpha=16,
156  +    lora_dropout=0,
157  +    r=64,
158  +    bias="none",
159  +    task_type="CAUSAL_LM",
160  +)
161  +print("222")
162  +model = PeftModel.from_pretrained(base_model, "new_model", peft_config=peft_config)
163  +print("333")
164  +model = model.merge_and_unload()
165  +print("444")
166  +###
167
168   return model, tokenizer
169