File size: 5,332 Bytes
ac66ae2 a993443 a70f9f8 9cb6b16 e760c8d 2dfbd8a c8534fb 2dfbd8a df44c11 2dfbd8a df44c11 2dfbd8a a9bd106 5b45741 2dfbd8a 5b45741 2dfbd8a 5b45741 2dfbd8a a9bd106 88543e6 9cb6b16 88543e6 ffef239 88543e6 bd9d23a 88543e6 340f2ae 88543e6 3d77c48 2750069 3d77c48 76d0fb3 93508c3 88543e6 76d0fb3 88543e6 a79d5cd 76d0fb3 88543e6 76d0fb3 88543e6 76d0fb3 34a45fd 88543e6 76d0fb3 88543e6 76d0fb3 88543e6 76d0fb3 88543e6 76d0fb3 88543e6 76d0fb3 88543e6 a9bd106 88543e6 03a8827 88543e6 03a8827 88543e6 03a8827 cda682a 03a8827 88543e6 03a8827 88543e6 03a8827 7f9f34a 88543e6 613b540 88543e6 2b03f9f 88543e6 2dfbd8a 2b03f9f 88543e6 809f79a 88543e6 2371111 b4de4c9 2dfbd8a 88543e6 083fde1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
# Labels of the three actions that process() dispatches on.
ACTION_1 = "Prompt base model"
ACTION_2 = "Fine-tune base model"
ACTION_3 = "Prompt fine-tuned model"

# System prompt template; "{schema}" is filled in via str.format in prompt_model().
SYSTEM_PROMPT = (
    "You are a text to SQL query translator. "
    "Given a question in English, generate a SQL query based on the provided SCHEMA. "
    "Do not generate any additional text. "
    "SCHEMA: {schema}"
)

# Default question and table definition pre-filled in the UI text boxes.
USER_PROMPT = (
    "What is the total trade value and average price "
    "for each trader and stock in the trade_history table?"
)
SQL_SCHEMA = (
    "CREATE TABLE trade_history ("
    "id INT, trader_id INT, stock VARCHAR(255), "
    "price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP"
    ");"
)

# Default model and dataset identifiers pre-filled in the UI text boxes.
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
DATASET_NAME = "gretelai/synthetic_text_to_sql"
def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_schema):
    """Run the selected action and return its result.

    Args:
        action: One of ACTION_1, ACTION_2 or ACTION_3.
        base_model_name: Model id used when prompting or fine-tuning the base model.
        ft_model_name: Model id used when prompting the fine-tuned model.
        dataset_name: Dataset id passed to fine_tune_model().
        system_prompt: System prompt template passed to prompt_model().
        user_prompt: User question passed to prompt_model().
        sql_schema: Schema text passed to prompt_model().

    Returns:
        Whatever prompt_model() or fine_tune_model() returns for the action.

    Raises:
        gr.Error: If ``action`` is not one of the three known labels.
    """
    # NOTE(review): the line below looks like a switch to disable the app for
    # visitors without their own credentials -- confirm before removing.
    #raise gr.Error("Please clone and bring your own credentials.")
    if action == ACTION_1:
        return prompt_model(base_model_name, system_prompt, user_prompt, sql_schema)
    if action == ACTION_2:
        return fine_tune_model(base_model_name, dataset_name)
    if action == ACTION_3:
        return prompt_model(ft_model_name, system_prompt, user_prompt, sql_schema)
    # Previously an unmatched action fell through to `return result` with
    # `result` unbound, surfacing as an opaque UnboundLocalError.
    raise gr.Error(f"Unknown action: {action!r}")
def _token_accuracy(pred):
    """Return {"accuracy": fraction of label tokens matched by the arg-max prediction}."""
    predictions = pred.predictions
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    labels = pred.label_ids
    predicted_ids = predictions.argmax(-1)
    # -100 is the label value the Hugging Face loss ignores; skip those positions.
    # assumes labels/predictions are NumPy arrays as delivered by the Trainer -- TODO confirm
    keep = labels != -100
    if not keep.any():
        return {"accuracy": 0.0}
    return {"accuracy": float((predicted_ids[keep] == labels[keep]).mean())}


def fine_tune_model(model_name, dataset_name):
    """Prepare a fine-tuning run for ``model_name`` on ``dataset_name``.

    Loads the dataset and model, tokenizes the "sql_prompt"/"sql" columns,
    takes a 1000-example training subset and a 100-example validation subset,
    and builds the training arguments and trainer. Each stage is printed.
    The actual training call is currently disabled, so nothing is returned.

    Args:
        model_name: Model id passed to load_model().
        dataset_name: Dataset id passed to load_dataset().
    """
    # Load dataset
    dataset = load_dataset(dataset_name)

    print("### Dataset")
    print(dataset)
    print("###")

    # Load model
    model, tokenizer = load_model(model_name)

    print("### Model")
    print(model)
    print("### Tokenizer")
    print(tokenizer)
    print("###")

    # Pre-process dataset
    def preprocess(examples):
        # Tokenize the prompt as input and the SQL text as target, padded/truncated to 512 tokens.
        model_inputs = tokenizer(examples["sql_prompt"], text_target=examples["sql"], max_length=512, padding="max_length", truncation=True)
        return model_inputs

    dataset = dataset.map(preprocess, batched=True)

    print("### Pre-processed dataset")
    print(dataset)
    print("###")

    # Split dataset into training and validation sets
    train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
    test_dataset = dataset["test"].shuffle(seed=42).select(range(100))

    print("### Training dataset")
    # Fixed: this used to print test_dataset under the training heading.
    print(train_dataset)
    print("### Validation dataset")
    print(test_dataset)
    print("###")

    # Configure training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        logging_dir="./logs",
        num_train_epochs=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        eval_strategy="steps",
        save_total_limit=2,
        save_steps=500,
        eval_steps=500,
        warmup_steps=500,
        weight_decay=0.01,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        load_best_model_at_end=True,
        push_to_hub=True,
        save_on_each_node=True,
    )

    print("### Training arguments")
    print(training_args)
    print("###")

    # Create trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        # Fixed: torch.sum() on the evaluation arrays failed and would have
        # produced a match count rather than an accuracy ratio.
        compute_metrics=_token_accuracy,
    )

    print("### Trainer")
    print(trainer)
    print("###")

    # Train model (left disabled, as in the original)
    #trainer.train()
def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
    """Prompt ``model_name`` with a chat conversation and return the reply text.

    The system prompt has ``sql_schema`` substituted for its ``{schema}``
    placeholder. The returned value is the content of the last message in the
    first generation, which is also printed between "###" markers.
    """
    generator = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        max_new_tokens=1000,
    )

    system_message = {"role": "system", "content": system_prompt.format(schema=sql_schema)}
    user_message = {"role": "user", "content": user_prompt}
    # The conversation ends with an empty assistant turn, as in the original.
    assistant_message = {"role": "assistant", "content": ""}

    generations = generator([system_message, user_message, assistant_message])
    completion = generations[0]["generated_text"][-1]["content"]

    for line in ("###", completion, "###"):
        print(line)

    return completion
def load_model(model_name):
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Model id passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer always has a pad token:
        when none is defined, the end-of-sequence token is reused.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    # Fixed: the model object has no `tokenizer` attribute; the tokenizer is
    # loaded separately from the same checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if not tokenizer.pad_token:
        # Fixed: "[PAD]" is not in the vocabulary, so it gave no usable pad id.
        # Reusing EOS needs no new token and no embedding resize.
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
# Build the UI: one radio button for the action, six text boxes for the
# remaining process() arguments, and one text box for the result.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio([ACTION_1, ACTION_2, ACTION_3], label="Action", value=ACTION_3),
        gr.Textbox(label="Base Model Name", value=BASE_MODEL_NAME, lines=1),
        gr.Textbox(label="Fine-Tuned Model Name", value=FT_MODEL_NAME, lines=1),
        gr.Textbox(label="Dataset Name", value=DATASET_NAME, lines=1),
        gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=2),
        gr.Textbox(label="User Prompt", value=USER_PROMPT, lines=2),
        gr.Textbox(label="SQL Schema", value=SQL_SCHEMA, lines=2),
    ],
    # Fixed: a missing OUTPUT environment variable used to raise KeyError at
    # import time; the output box now starts empty in that case.
    outputs=[gr.Textbox(label="Prompt Completion", value=os.environ.get("OUTPUT", ""))],
)

demo.launch()