bstraehle commited on
Commit
493d46b
1 Parent(s): 2fb4377

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -5,7 +5,7 @@ import os, torch
5
  from datasets import load_dataset
6
  from huggingface_hub import HfApi, login
7
  from peft import LoraConfig, TaskType, get_peft_model
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
9
 
10
  ACTION_1 = "Prompt base model"
11
  ACTION_2 = "Fine-tune base model"
@@ -103,16 +103,16 @@ def fine_tune_model(base_model_name, dataset_name):
103
  task_type=TaskType.SEQ_2_SEQ_LM,
104
  )
105
 
106
- model = get_peft_model(model, lora_config)
107
 
108
  print("### PEFT")
109
- model.print_trainable_parameters()
110
  print("###")
111
 
112
  # Create trainer
113
 
114
- trainer = Seq2SeqTrainer(
115
- model=model,
116
  args=training_args,
117
  train_dataset=train_dataset,
118
  eval_dataset=eval_dataset,
 
5
  from datasets import load_dataset
6
  from huggingface_hub import HfApi, login
7
  from peft import LoraConfig, TaskType, get_peft_model
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, Trainer, pipeline
9
 
10
  ACTION_1 = "Prompt base model"
11
  ACTION_2 = "Fine-tune base model"
 
103
  task_type=TaskType.SEQ_2_SEQ_LM,
104
  )
105
 
106
+ peft_model = get_peft_model(model, lora_config)
107
 
108
  print("### PEFT")
109
+ peft_model.print_trainable_parameters()
110
  print("###")
111
 
112
  # Create trainer
113
 
114
+ trainer = Trainer(
115
+ model=peft_model,
116
  args=training_args,
117
  train_dataset=train_dataset,
118
  eval_dataset=eval_dataset,