bstraehle committed on
Commit
8f5a0c7
1 Parent(s): 8417f84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  import os, torch
5
  from datasets import load_dataset
6
  from huggingface_hub import HfApi, login
7
- from peft import LoraConfig, TaskType, get_peft_model
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
9
 
10
  ACTION_1 = "Prompt base model"
@@ -102,8 +102,10 @@ def fine_tune_model(base_model_name, dataset_name):
102
  #target_modules=["q", "v"],
103
  task_type=TaskType.SEQ_2_SEQ_LM,
104
  )
 
 
105
 
106
- peft_model = get_peft_model(model, lora_config)
107
 
108
  print("### PEFT")
109
  peft_model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
 
4
  import os, torch
5
  from datasets import load_dataset
6
  from huggingface_hub import HfApi, login
7
+ from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
9
 
10
  ACTION_1 = "Prompt base model"
 
102
  #target_modules=["q", "v"],
103
  task_type=TaskType.SEQ_2_SEQ_LM,
104
  )
105
+
106
+ model = prepare_model_for_kbit_training(model)
107
 
108
+ model = get_peft_model(model, lora_config)
109
 
110
  print("### PEFT")
111
  peft_model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848