# Run on Google TPU v5e 2x4 or equivalent (220 vCPU, 380 GB RAM, 128 GB VRAM)
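#
# Gradio app with three actions: prompt the base Meta-Llama-3.1-8B-Instruct
# model, fine-tune it on gretelai/synthetic_text_to_sql with a LoRA adapter,
# or prompt the resulting fine-tuned model.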

import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, Trainer, TrainingArguments, pipeline

ACTION_1 = "Prompt base model"
ACTION_2 = "Fine-tune base model"
ACTION_3 = "Prompt fine-tuned model"

HF_ACCOUNT = "bstraehle"

SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"

BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
FT_MODEL_NAME = "Meta-Llama-3.1-8B-Instruct-text-to-sql"
DATASET_NAME = "gretelai/synthetic_text_to_sql"
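
# The dataset provides, among other columns, sql_prompt (the natural-language
# question) and sql (the target query); both are used for fine-tuning below.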

def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
    #raise gr.Error("Please clone and bring your own credentials.")
    if action == ACTION_1:
        result = prompt_model(base_model_name, system_prompt, user_prompt, sql_context)
    elif action == ACTION_2:
        result = fine_tune_model(base_model_name, dataset_name)
    elif action == ACTION_3:
        result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
    return result

def fine_tune_model(base_model_name, dataset_name):
    # Load dataset
    
    dataset = load_dataset(dataset_name)

    print("### Dataset")
    print(dataset)
    print("### Example")
    print(dataset["train"][:1])
    print("###")
    
    # Load model
    
    model, tokenizer = load_model(base_model_name)

    print("### Model")
    print(model)
    print("### Tokenizer")
    print(tokenizer)
    print("###")
        
    # Pre-process dataset
    
    def preprocess(examples):
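        # Tokenize the English question as the model input and the SQL query as
        # the labels (text_target); pad/truncate both to a fixed 512 tokens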
        model_inputs = tokenizer(examples["sql_prompt"], text_target=examples["sql"], max_length=512, padding="max_length", truncation=True)
        return model_inputs
        
    dataset = dataset.map(preprocess, batched=True)
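    # map() adds input_ids, attention_mask and labels columns to every split
    # while keeping the original text columns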

    print("### Pre-processed dataset")
    print(dataset)
    print("### Example")
    print(dataset["train"][:1])
    print("###")
    
    # Select the training and evaluation splits (the dataset ships with both)
    
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    print("### Training dataset")
    print(train_dataset)
    print("### Evaluation dataset")
    print(eval_dataset)
    print("###")
    
    # Configure training arguments

    training_args = TrainingArguments(
        output_dir=f"./{FT_MODEL_NAME}",
        num_train_epochs=3, # 37,500 steps
        max_steps=1, # overrides num_train_epochs; remove for a full training run
        push_to_hub=True, # pushes only the model; the tokenizer is pushed separately below
        # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
    )
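
    # Arguments not set above (per-device batch size, learning rate, optimizer, ...)
    # fall back to the transformers TrainingArguments defaults (batch size 8, lr 5e-5)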

    print("### Training arguments")
    print(training_args)
    print("###")

    # PEFT: wrap the base model with a LoRA adapter so only a small number of
    # adapter weights is trained (https://www.philschmid.de/fine-tune-flan-t5-peft)

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "v_proj"], # attention query/value projections in Llama
        task_type=TaskType.CAUSAL_LM, # Llama is a decoder-only (causal) LM
    )

    peft_model = get_peft_model(model, lora_config)

    print("### PEFT")
    peft_model.print_trainable_parameters() # trainable params: 6,815,744 || all params: 8,037,076,992 || trainable%: 0.0848
    print("###")
    
    # Create trainer

    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
    )
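
    # Trainer wires up data collation, the default AdamW optimizer and
    # checkpointing for the LoRA-wrapped model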

    # Train model
    
    trainer.train()

    # Push tokenizer to HF

    tokenizer.push_to_hub(FT_MODEL_NAME)

    return f"Fine-tuned model and tokenizer pushed to {HF_ACCOUNT}/{FT_MODEL_NAME}"
    
def prompt_model(model_name, system_prompt, user_prompt, sql_context):
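    # Build a chat text-generation pipeline; device_map="auto" lets accelerate
    # place the model weights across the available devices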
    pipe = pipeline("text-generation", 
                    model=model_name,
                    device_map="auto",
                    max_new_tokens=1000)
    
    messages = [
      {"role": "system", "content": system_prompt.format(sql_context=sql_context)},
      {"role": "user", "content": user_prompt},
      {"role": "assistant", "content": ""}
    ]
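    # The pipeline applies the model's chat template; the trailing empty assistant
    # turn makes the model continue that message, and the completed reply comes back
    # as the last message of generated_text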
    
    output = pipe(messages)
    
    result = output[0]["generated_text"][-1]["content"]

    print("###")
    print(result)
    print("###")
    
    return result

def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
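    # Llama tokenizers ship without a pad token; reuse the EOS token so padding works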
    tokenizer.pad_token = tokenizer.eos_token
    
    return model, tokenizer
    
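# Gradio UI: the radio button selects the action; the text boxes feed the model
# names, dataset, prompts and SQL context into process()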
demo = gr.Interface(fn=process, 
                    inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
                            gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
                            gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
                            gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
                            gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
                            gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
                            gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 4)],
                    outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ.get("OUTPUT", ""))])
demo.launch()