Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
|
|
4 |
import os, torch
|
5 |
from datasets import load_dataset
|
6 |
from huggingface_hub import HfApi, login
|
7 |
-
from peft import LoraConfig, TaskType, get_peft_model
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
|
9 |
|
10 |
ACTION_1 = "Prompt base model"
|
@@ -102,8 +102,10 @@ def fine_tune_model(base_model_name, dataset_name):
|
|
102 |
#target_modules=["q", "v"],
|
103 |
task_type=TaskType.SEQ_2_SEQ_LM,
|
104 |
)
|
|
|
|
|
105 |
|
106 |
-
|
107 |
|
108 |
print("### PEFT")
|
109 |
peft_model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
|
|
|
4 |
import os, torch
|
5 |
from datasets import load_dataset
|
6 |
from huggingface_hub import HfApi, login
|
7 |
+
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
|
9 |
|
10 |
ACTION_1 = "Prompt base model"
|
|
|
102 |
#target_modules=["q", "v"],
|
103 |
task_type=TaskType.SEQ_2_SEQ_LM,
|
104 |
)
|
105 |
+
|
106 |
+
model = prepare_model_for_kbit_training(model)
|
107 |
|
108 |
+
model = get_peft_model(model, lora_config)
|
109 |
|
110 |
print("### PEFT")
|
111 |
peft_model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
|