Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ import os, torch
|
|
5 |
from datasets import load_dataset
|
6 |
from huggingface_hub import HfApi, login
|
7 |
from peft import LoraConfig, TaskType, get_peft_model
|
8 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
|
9 |
|
10 |
ACTION_1 = "Prompt base model"
|
11 |
ACTION_2 = "Fine-tune base model"
|
@@ -103,16 +103,16 @@ def fine_tune_model(base_model_name, dataset_name):
|
|
103 |
task_type=TaskType.SEQ_2_SEQ_LM,
|
104 |
)
|
105 |
|
106 |
-
|
107 |
|
108 |
print("### PEFT")
|
109 |
-
|
110 |
print("###")
|
111 |
|
112 |
# Create trainer
|
113 |
|
114 |
-
trainer =
|
115 |
-
model=
|
116 |
args=training_args,
|
117 |
train_dataset=train_dataset,
|
118 |
eval_dataset=eval_dataset,
|
|
|
5 |
from datasets import load_dataset
|
6 |
from huggingface_hub import HfApi, login
|
7 |
from peft import LoraConfig, TaskType, get_peft_model
|
8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, Trainer, pipeline
|
9 |
|
10 |
ACTION_1 = "Prompt base model"
|
11 |
ACTION_2 = "Fine-tune base model"
|
|
|
103 |
task_type=TaskType.SEQ_2_SEQ_LM,
|
104 |
)
|
105 |
|
106 |
+
peft_model = get_peft_model(model, lora_config)
|
107 |
|
108 |
print("### PEFT")
|
109 |
+
peft_model.print_trainable_parameters()
|
110 |
print("###")
|
111 |
|
112 |
# Create trainer
|
113 |
|
114 |
+
trainer = Trainer(
|
115 |
+
model=peft_model,
|
116 |
args=training_args,
|
117 |
train_dataset=train_dataset,
|
118 |
eval_dataset=eval_dataset,
|