bstraehle commited on
Commit
689da75
1 Parent(s): da6722c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -60,10 +60,10 @@ def fine_tune_model(base_model_name, dataset_name):
60
  print("###")
61
 
62
  # Split dataset into training and validation sets
63
- train_dataset = dataset["train"]
64
- test_dataset = dataset["test"]
65
- #train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
66
- #test_dataset = dataset["test"].shuffle(seed=42).select(range(100))
67
 
68
  print("### Training dataset")
69
  print(train_dataset)
@@ -77,7 +77,7 @@ def fine_tune_model(base_model_name, dataset_name):
77
  logging_dir="./logging",
78
  num_train_epochs=1,
79
  max_steps=1, # overwrites num_train_epochs
80
- push_to_hub=True,
81
  #per_device_train_batch_size=16,
82
  #per_device_eval_batch_size=64,
83
  #eval_strategy="steps",
@@ -107,6 +107,9 @@ def fine_tune_model(base_model_name, dataset_name):
107
 
108
  # Train model
109
  trainer.train()
 
 
 
110
 
111
  def prompt_model(model_name, system_prompt, user_prompt, sql_context):
112
  pipe = pipeline("text-generation",
 
60
  print("###")
61
 
62
  # Split dataset into training and validation sets
63
+ #train_dataset = dataset["train"]
64
+ #test_dataset = dataset["test"]
65
+ train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
66
+ test_dataset = dataset["test"].shuffle(seed=42).select(range(100))
67
 
68
  print("### Training dataset")
69
  print(train_dataset)
 
77
  logging_dir="./logging",
78
  num_train_epochs=1,
79
  max_steps=1, # overwrites num_train_epochs
80
+ #push_to_hub=True,
81
  #per_device_train_batch_size=16,
82
  #per_device_eval_batch_size=64,
83
  #eval_strategy="steps",
 
107
 
108
  # Train model
109
  trainer.train()
110
+
111
+ # Save model to HF
112
+ trainer.push_to_hub(FT_MODEL_NAME, use_auth_token=True)
113
 
114
  def prompt_model(model_name, system_prompt, user_prompt, sql_context):
115
  pipe = pipeline("text-generation",