bstraehle committed on
Commit
76d0fb3
1 Parent(s): 72f32e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -3
app.py CHANGED
@@ -26,10 +26,74 @@ def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
26
  result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
27
  return result
28
 
 
 
 
 
 
29
  def fine_tune_model(base_model_id, dataset):
30
- tokenizer = download_model(base_model_id)
31
- fine_tuned_model_id = upload_model(base_model_id, tokenizer)
32
- return fine_tuned_model_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def prompt_model(model_id, system_prompt, user_prompt, schema):
35
  pipe = pipeline("text-generation",
 
26
  result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
27
  return result
28
 
29
# Preprocess the dataset
def preprocess(examples, tokenizer=None):
    """Tokenize one batch of examples for training.

    Args:
        examples: batch dict with a "text" (prompt) column and an "sql"
            (target) column, as produced by `datasets.Dataset.map(batched=True)`.
        tokenizer: the tokenizer to apply. The previous version read a
            global `tokenizer` that was only ever assigned as a local
            inside `fine_tune_model`, so every call raised NameError;
            pass the tokenizer explicitly instead.

    Returns:
        The tokenizer's encoding dict (input ids plus `labels` built from
        the `text_target` column), truncated to 512 tokens.

    Raises:
        ValueError: if no tokenizer is supplied.
    """
    if tokenizer is None:
        # Fail loudly with a clear message instead of the old NameError.
        raise ValueError("preprocess() requires an explicit tokenizer")
    model_inputs = tokenizer(
        examples["text"],
        text_target=examples["sql"],
        max_length=512,
        truncation=True,
    )
    return model_inputs
def fine_tune_model(base_model_id, dataset):
    """Fine-tune `base_model_id` on a text-to-SQL dataset and push it to the Hub.

    Args:
        base_model_id: Hub id of the base model to fine-tune. The previous
            version ignored this parameter and hard-coded the model name.
        dataset: Hub id of the dataset to train on; falls back to
            "gretelai/synthetic_text_to_sql" when falsy, matching the old
            hard-coded behavior.

    Returns:
        The Hub repo id of the fine-tuned model. The caller (`process`)
        passes this straight to `prompt_model`; the previous version
        returned None, breaking that call chain.
    """
    # Function-scope imports keep the heavyweight training dependencies
    # out of module import time; the file already depends on
    # transformers / datasets elsewhere.
    from datasets import load_dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )

    fine_tuned_model_id = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"

    # Honor the caller's dataset choice instead of ignoring the parameter.
    dataset = load_dataset(dataset or "gretelai/synthetic_text_to_sql")

    # Llama 3.1 is a decoder-only (causal) architecture; the
    # AutoModelForSeq2SeqLM class used previously is for encoder-decoder
    # models and cannot load this checkpoint.
    model = AutoModelForCausalLM.from_pretrained(base_model_id)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; padding is required
        # for batched training.
        tokenizer.pad_token = tokenizer.eos_token

    def _tokenize(examples):
        # Decoder-only training consumes a single token stream, so join
        # each prompt with its SQL target; labels are produced by the
        # language-modeling collator below.
        merged = [f"{t}\n{s}" for t, s in zip(examples["text"], examples["sql"])]
        return tokenizer(merged, max_length=512, truncation=True)

    # Bind the tokenizer via a closure; the old module-level `preprocess`
    # read a global `tokenizer` that never existed (NameError).
    dataset = dataset.map(_tokenize, batched=True)

    # Split into training and validation subsets (sizes kept from before;
    # adjust the ranges as needed).
    train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
    val_dataset = dataset["test"].shuffle(seed=42).select(range(100))

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,  # Adjust as needed
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        save_total_limit=2,
        # load_best_model_at_end requires matching eval/save strategies;
        # the old config never enabled evaluation, so TrainingArguments
        # validation failed at startup.
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        # Select on eval loss: the old compute_metrics lambda returned a
        # raw token-match tensor, not an accuracy, so "accuracy" was
        # never a meaningful selection metric for a causal LM.
        metric_for_best_model="loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        push_to_hub=True,
        hub_model_id=fine_tuned_model_id,
        # NOTE(review): assumes HF_TOKEN is set in the environment, as
        # the old Repository-based code did — confirm in deployment.
        hub_token=os.environ["HF_TOKEN"],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # Causal-LM collator shifts inputs to build labels (mlm=False).
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    # Train the model
    trainer.train()

    # Save the trained model locally, then push. huggingface_hub's
    # Repository class has no `login` method, so the old Repository-based
    # push path could never execute; Trainer.push_to_hub uses the
    # hub_model_id/hub_token configured above.
    trainer.save_model("./fine_tuned_model")
    trainer.push_to_hub(commit_message="Initial commit")

    return fine_tuned_model_id
97
 
98
  def prompt_model(model_id, system_prompt, user_prompt, schema):
99
  pipe = pipeline("text-generation",