Update app.py
app.py
CHANGED
@@ -26,10 +26,74 @@ def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
     result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
     return result
 
+# Preprocess the dataset
+def preprocess(examples):
+    model_inputs = tokenizer(examples["text"], text_target=examples["sql"], max_length=512, truncation=True)
+    return model_inputs
+
 def fine_tune_model(base_model_id, dataset):
-    tokenizer = download_model(base_model_id)
-    fine_tuned_model_id = upload_model(base_model_id, tokenizer)
-    return fine_tuned_model_id
+    # tokenizer = download_model(base_model_id)
+    # fine_tuned_model_id = upload_model(base_model_id, tokenizer)
+    # return fine_tuned_model_id
+
+    # Load the dataset
+    dataset = load_dataset("gretelai/synthetic_text_to_sql")
+
+    # Load the pre-trained model and tokenizer
+    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    dataset = dataset.map(preprocess, batched=True)
+
+    # Split the dataset into training and validation sets
+    train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))  # Adjust the range as needed
+    val_dataset = dataset["test"].shuffle(seed=42).select(range(100))  # Adjust the range as needed
+
+    # Set training arguments
+    training_args = Seq2SeqTrainingArguments(
+        output_dir="./results",
+        num_train_epochs=3,  # Adjust as needed
+        per_device_train_batch_size=16,
+        per_device_eval_batch_size=64,
+        warmup_steps=500,
+        weight_decay=0.01,
+        logging_dir="./logs",
+        save_total_limit=2,
+        save_steps=500,
+        eval_steps=500,
+        metric_for_best_model="accuracy",
+        greater_is_better=True,
+        save_on_each_node=True,
+        load_best_model_at_end=True,
+    )
+
+    # Create the Trainer instance
+    trainer = Seq2SeqTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+        compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
+    )
+
+    # Train the model
+    trainer.train()
+
+    # Save the trained model
+    trainer.save_model("./fine_tuned_model")
+
+    # Create a repository object
+    repo = Repository(
+        local_dir="./fine_tuned_model",
+        repo_type="model",
+        repo_id="bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql",
+    )
+
+    # Log in to the Hugging Face Hub
+    repo.login(token=os.environ["HF_TOKEN"])
+
+    # Push the model to the Hub
+    repo.push_to_hub(commit_message="Initial commit")
 
 def prompt_model(model_id, system_prompt, user_prompt, schema):
     pipe = pipeline("text-generation",
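Two notes on the training block above. Meta-Llama-3.1-8B-Instruct is a decoder-only (causal) model, so loading it through AutoModelForSeq2SeqLM will fail at from_pretrained; a causal setup would go through AutoModelForCausalLM with Trainer/TrainingArguments instead of the Seq2Seq variants. Separately, the inline compute_metrics lambda returns a raw torch sum rather than an accuracy rate and never masks the padded label positions that transformers marks with -100 (load_best_model_at_end=True also requires an evaluation strategy matching the save strategy, which the arguments above never enable). A minimal standalone sketch, assuming the usual EvalPrediction layout of numpy arrays; the function name and masking are illustrative, not part of the commit:

import numpy as np

def compute_metrics(eval_pred):
    # EvalPrediction carries numpy arrays: raw model outputs and padded labels
    predictions, labels = eval_pred.predictions, eval_pred.label_ids
    if isinstance(predictions, tuple):
        predictions = predictions[0]  # some models return extra tensors alongside the logits
    predictions = np.argmax(predictions, axis=-1)
    mask = labels != -100  # skip padding positions in the labels
    accuracy = (predictions == labels)[mask].mean()
    return {"accuracy": float(accuracy)}

Trainer accepts metric_for_best_model with or without the eval_ prefix, so "accuracy" resolves against the dict this sketch returns.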
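The push step above goes through the git-based Repository class, which huggingface_hub has since deprecated in favor of the HTTP-based HfApi, and the login call on Repository is not an API I can confirm. A hedged sketch of the same step with HfApi.upload_folder, assuming HF_TOKEN is set in the environment and reusing the repo_id from the diff:

import os
from huggingface_hub import HfApi

repo_id = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"  # taken from the diff above

api = HfApi(token=os.environ["HF_TOKEN"])  # assumes HF_TOKEN is exported
api.create_repo(repo_id, repo_type="model", exist_ok=True)  # no-op if the repo already exists
api.upload_folder(
    folder_path="./fine_tuned_model",
    repo_id=repo_id,
    repo_type="model",
    commit_message="Initial commit",
)

This uploads the saved model directory directly over HTTP, with no local git clone to keep in sync.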