import gradio as gr
import os, torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline

ACTION_1 = "Prompt base model"
ACTION_2 = "Fine-tune base model"
ACTION_3 = "Prompt fine-tuned model"

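# {schema} in SYSTEM_PROMPT is substituted via str.format in prompt_model below.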
SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
USER_PROMPT = "What is the total trade value and average price for each trader and stock in the trade_history table?"
SQL_SCHEMA = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
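# For the default question and schema, a correct completion would be roughly:
# SELECT trader_id, stock, SUM(price * quantity) AS total_trade_value,
#        AVG(price) AS average_price FROM trade_history GROUP BY trader_id, stock;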

BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
DATASET_NAME = "gretelai/synthetic_text_to_sql"
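# Note: ACTION_2 saves its checkpoint locally to ./output, while FT_MODEL_NAME points at
# a previously fine-tuned copy on the Hub, so the two are not automatically linked.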

def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_schema):
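    # Route the selected UI action to the matching handler; each returns a string for the output box.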
    #raise gr.Error("Please clone and bring your own credentials.")
    if action == ACTION_1:
        result = prompt_model(base_model_name, system_prompt, user_prompt, sql_schema)
    elif action == ACTION_2:
        result = fine_tune_model(base_model_name, dataset_name)
    elif action == ACTION_3:
        result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_schema)
    else:
        result = "Unknown action: " + str(action)
    return result

def fine_tune_model(base_model_name, dataset_name):
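    # Demonstration-scale run: 100 training / 10 eval examples and two optimizer steps (see below).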
    # Load dataset
    dataset = load_dataset(dataset_name)

    print("### Dataset")
    print(dataset)
    print("### Train example")
    print(dataset["train"][:1])
    print("### Test example")
    print(dataset["test"][:1])
    print("###")
    
    # Load model
    model, tokenizer = load_model(base_model_name)

    print("### Model")
    print(model)
    print("### Tokenizer")
    print(tokenizer)
    print("###")
    
    # Pre-process dataset
    
    def preprocess(examples):
        # A causal LM needs labels aligned with input_ids, so join prompt and target SQL into one sequence.
        texts = [prompt + "\n" + sql + tokenizer.eos_token
                 for prompt, sql in zip(examples["sql_prompt"], examples["sql"])]
        model_inputs = tokenizer(texts, max_length=512, padding="max_length", truncation=True)
        model_inputs["labels"] = model_inputs["input_ids"].copy()  # pad tokens also enter the loss in this simple setup
        return model_inputs

    # Drop the raw text columns; only token ids are needed from here on.
    dataset = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)

    print("### Pre-processed dataset")
    print(dataset)
    print("### Train example")
    print(dataset["train"][:1])
    print("### Test example")
    print(dataset["test"][:1])
    print("###")
    
    # Subsample small training and validation sets for the demo
    #train_dataset = dataset["train"]
    #test_dataset = dataset["test"]
    train_dataset = dataset["train"].shuffle(seed=42).select(range(100))
    test_dataset = dataset["test"].shuffle(seed=42).select(range(10))

    print("### Training dataset")
    print(train_dataset)
    print("### Validation dataset")
    print(test_dataset)
    print("###")
    
    # Configure training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="./output",
        logging_dir="./logging",
        num_train_epochs=1,
        max_steps=2,  # cap the demo at two optimizer steps; remove for a real run
        #per_device_train_batch_size=16,
        #per_device_eval_batch_size=64,
        #eval_strategy="steps",
        #save_total_limit=2,
        #save_steps=500,
        #eval_steps=500,
        #warmup_steps=500,
        #weight_decay=0.01,
        #metric_for_best_model="accuracy",
        #greater_is_better=True,
        #load_best_model_at_end=True,
        #push_to_hub=True,
        #save_on_each_node=True,
    )
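    # The commented-out arguments above sketch a fuller configuration (batch sizes,
    # eval/save cadence, Hub push); enable and tune them for a real training run.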

    print("### Training arguments")
    print(training_args)
    print("###")
    
    # Create trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        #compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
    )
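    # Seq2SeqTrainer subclasses Trainer, so it also handles a causal LM like Llama;
    # the plain Trainer would work equally well here.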

    # Train and save model
    trainer.train()
    trainer.save_model()

    return "Fine-tuning finished; model saved to " + training_args.output_dir
    
def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
    pipe = pipeline("text-generation",
                    model=model_name,
                    torch_dtype=torch.bfloat16,  # halve memory versus the fp32 default
                    device_map="auto",
                    max_new_tokens=1000)
    
    # The chat template appends the generation prompt, so no empty assistant turn is needed.
    messages = [
      {"role": "system", "content": system_prompt.format(schema=sql_schema)},
      {"role": "user", "content": user_prompt},
    ]
    
    output = pipe(messages)
    
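    # With chat-style input the pipeline returns the full message list;
    # the last entry is the newly generated assistant turn.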
    result = output[0]["generated_text"][-1]["content"]

    print("###")
    print(result)
    print("###")
    
    return result

def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
    
    return model, tokenizer
    
demo = gr.Interface(fn=process, 
                    inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
                            gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
                            gr.Textbox(label = "Fine-Tuned Model Name", value = FT_MODEL_NAME, lines = 1),
                            gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
                            gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
                            gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
                            gr.Textbox(label = "SQL Schema", value = SQL_SCHEMA, lines = 2)],
                    outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ.get("OUTPUT", ""))])
demo.launch()