bstraehle committed
Commit
9b16331
1 Parent(s): 53b729b

Update app.py

Files changed (1)
  1. app.py +12 -2
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
 import os, torch
 from datasets import load_dataset
 from huggingface_hub import HfApi, login
+from peft import LoraConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
 
 ACTION_1 = "Prompt base model"
@@ -80,8 +81,8 @@ def fine_tune_model(base_model_name, dataset_name):
 
     training_args = Seq2SeqTrainingArguments(
         output_dir=f"./{FT_MODEL_NAME}",
-        num_train_epochs=3,
-        #max_steps=1, # overwrites num_train_epochs
+        num_train_epochs=3, # 37,500 steps
+        max_steps=1, # overwrites num_train_epochs
         push_to_hub=True, # only pushes model, also need to push tokenizer (see below)
         # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
     )
@@ -89,6 +90,14 @@ def fine_tune_model(base_model_name, dataset_name):
     print("### Training arguments")
     print(training_args)
     print("###")
+
+    # PEFT
+
+    peft_config = LoraConfig(
+        r=8,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
 
     # Create trainer
 
@@ -97,6 +106,7 @@ def fine_tune_model(base_model_name, dataset_name):
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
+        peft_config=peft_config,
         # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
     )
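
Note: as committed, peft_config is passed straight to Seq2SeqTrainer. Accepting a peft_config keyword is a convention of trl's SFTTrainer; with transformers' Seq2SeqTrainer the usual pattern is to wrap the base model with peft.get_peft_model before constructing the trainer. A minimal sketch of that pattern, reusing the LoraConfig values from this commit (the base model name below is a placeholder, not taken from this repo):

# Sketch only, not this repo's code: apply the commit's LoRA settings with peft,
# then hand the wrapped model to the trainer. The base model name is a placeholder.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model_name = "openai-community/gpt2"  # placeholder, not the app's actual base model

peft_config = LoraConfig(
    r=8,                   # LoRA rank, as in this commit
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = get_peft_model(model, peft_config)  # adds LoRA adapters, freezes base weights
model.print_trainable_parameters()          # only the adapter weights remain trainable
# `model` can then be passed to Seq2SeqTrainer(model=model, args=training_args, ...)

With max_steps=1 left uncommented it overrides num_train_epochs, so the run stops after a single optimizer step, presumably to smoke-test the push_to_hub path cheaply.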