File size: 6,232 Bytes
a70f9f8
ac66ae2
a993443
a70f9f8
 
ffef239
d6a8f30
e760c8d
 
74c640a
c8534fb
e92ef1c
 
df44c11
38576e5
df44c11
467c88a
df44c11
361f8dd
0aa11b1
df44c11
a9bd106
 
 
 
 
 
 
 
 
 
76d0fb3
 
 
 
 
ffef239
 
 
 
76d0fb3
 
 
ffef239
87a35cb
3d77c48
 
 
08c0eb5
3d77c48
6bb7c86
76d0fb3
 
 
 
 
93508c3
76d0fb3
 
 
 
4f06478
76d0fb3
 
 
 
 
 
 
 
 
 
 
1e971a3
e8d1605
76d0fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9bd106
38576e5
03a8827
 
 
 
 
 
 
 
 
 
 
 
 
a9bd106
7f9f34a
613b540
46ea1b4
241fd2c
613b540
 
 
ada7179
74c640a
cbbb9fd
 
c2cbd84
ada7179
cbbb9fd
ada7179
39546c6
0fb434b
ada7179
 
2b03f9f
74c640a
ada7179
74c640a
2b03f9f
2371111
835fa92
e6a4e68
13e776a
94ca6da
 
13e776a
73899fd
083fde1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#3-create-and-prepare-the-dataset
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hugging Face namespace that models are uploaded to and (for prompting) loaded from.
hf_profile = "bstraehle"

# Choices for the "Action" radio button; `process` dispatches on these exact strings.
action_1 = "Fine-tune pre-trained model"
action_2 = "Prompt fine-tuned model"

# Default prompt inputs for the UI; `{schema}` in the system prompt is filled via str.format.
system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"

# Defaults shown in the UI.
base_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Shown so the UI reflects the data actually used: fine_tune_model loads
# "gretelai/synthetic_text_to_sql" (its preprocessing relies on that dataset's columns).
dataset = "gretelai/synthetic_text_to_sql"

def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
    """Dispatch a Gradio form submission to the selected action.

    Args:
        action: One of ``action_1`` (fine-tune) or ``action_2`` (prompt).
        base_model_id: Hub id of the base model; for prompting, it is mapped to
            the fine-tuned copy under ``hf_profile`` via ``replace_hf_profile``.
        dataset: Passed through to ``fine_tune_model``.
        system_prompt: System prompt template containing ``{schema}``.
        user_prompt: The natural-language question.
        schema: SQL schema substituted into the system prompt.

    Returns:
        Whatever the selected action returns (shown in the "Completion" box).

    Raises:
        gr.Error: if ``action`` is not one of the known choices.
    """
    # Enable on a shared/public deployment to stop runs that would use the owner's credentials:
    #raise gr.Error("Please clone and bring your own credentials.")
    if action == action_1:
        return fine_tune_model(base_model_id, dataset)
    if action == action_2:
        fine_tuned_model_id = replace_hf_profile(base_model_id)
        return prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
    # Without this, an unexpected action fell through to an unbound `result`
    # (UnboundLocalError) instead of a user-visible message.
    raise gr.Error(f"Unknown action: {action!r}")

def fine_tune_model(base_model_id, dataset):
    """Fine-tune *base_model_id* for text-to-SQL and upload it to the Hub.

    The model is loaded 4-bit quantized (bitsandbytes NF4) and trained as a
    causal LM on a 1,000-row sample of ``gretelai/synthetic_text_to_sql``, where
    each example is the ``sql_prompt`` followed by its ``sql`` answer. A 100-row
    sample of the test split is used for evaluation. Model and tokenizer are
    saved to ``./fine_tuned_model`` and uploaded to
    ``replace_hf_profile(base_model_id)``, the same repo id that the
    "Prompt fine-tuned model" action loads.

    Args:
        base_model_id: Hub id of the pre-trained causal LM to fine-tune.
        dataset: Currently ignored (see NOTE below).

    Returns:
        The Hub repo id the fine-tuned model was uploaded to.

    Raises:
        KeyError: if the ``HF_TOKEN`` environment variable is not set.
    """
    # Local import: transformers is already a dependency of this file.
    from transformers import DataCollatorForSeq2Seq

    # NOTE(review): `dataset` is ignored because `preprocess` relies on the
    # gretelai column names (`sql_prompt`, `sql`); wire it through once the
    # column mapping is configurable.
    raw_dataset = load_dataset("gretelai/synthetic_text_to_sql")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # NOTE(review): transformers' Trainer refuses to fine-tune a purely
    # quantized model; trainable adapters (e.g. LoRA via `peft`, as in the guide
    # linked at the top of this file) must be attached before `trainer.train()`.
    # `peft` is not yet a dependency of this file.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id, device_map="auto", quantization_config=bnb_config
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    # Llama tokenizers ship without a pad token; batching needs one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def preprocess(examples):
        # A causal LM trains on one sequence whose labels are its own token ids.
        # (`text_target=` would tokenize the SQL separately, producing labels that
        # do not line up with `input_ids`.)
        texts = [
            f"{prompt}\n{sql}{tokenizer.eos_token}"
            for prompt, sql in zip(examples["sql_prompt"], examples["sql"])
        ]
        model_inputs = tokenizer(texts, max_length=512, truncation=True)
        model_inputs["labels"] = [list(ids) for ids in model_inputs["input_ids"]]
        return model_inputs

    # Sample first, then tokenize: mapping the whole dataset only to keep
    # 1,100 rows is wasted work. Row selection is unchanged by the reordering.
    train_split = raw_dataset["train"].shuffle(seed=42).select(range(1000))  # Adjust the range as needed
    val_split = raw_dataset["test"].shuffle(seed=42).select(range(100))  # Adjust the range as needed
    train_dataset = train_split.map(preprocess, batched=True, remove_columns=train_split.column_names)
    val_dataset = val_split.map(preprocess, batched=True, remove_columns=val_split.column_names)

    def preprocess_logits_for_metrics(logits, labels):
        # Keep only predicted token ids. Accumulating full-vocabulary logits for
        # the whole eval set would need an enormous amount of memory.
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_pred):
        # Token-level accuracy. The prediction at position t is for token t+1,
        # so shift by one. Label positions equal to -100 (padding) are excluded.
        predictions = eval_pred.predictions[:, :-1]
        labels = eval_pred.label_ids[:, 1:]
        mask = labels != -100
        total = int(mask.sum())
        correct = int(((predictions == labels) & mask).sum())
        return {"accuracy": correct / total if total else 0.0}

    # NOTE(review): with 1,000 rows, batch size 8 and 3 epochs, a single device
    # runs ~375 optimizer steps. That is fewer than warmup_steps and
    # eval/save_steps (500), so warmup never completes and no evaluation runs.
    # Confirm these values.
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        num_train_epochs=3,  # Adjust as needed
        per_device_train_batch_size=8,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        save_total_limit=2,
        save_steps=500,
        eval_steps=500,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        save_on_each_node=True,
        load_best_model_at_end=True,
        eval_strategy="steps",
        gradient_checkpointing=True,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # Pads input_ids with the pad token and labels with -100 (ignored by the loss).
        data_collator=DataCollatorForSeq2Seq(tokenizer),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    trainer.train()

    output_dir = "./fine_tuned_model"
    trainer.save_model(output_dir)
    # The trainer was not given the tokenizer, so save it explicitly alongside the model.
    tokenizer.save_pretrained(output_dir)

    # Upload to the repo id that the "Prompt fine-tuned model" action loads.
    fine_tuned_model_id = replace_hf_profile(base_model_id)
    login(token=os.environ["HF_TOKEN"])
    api = HfApi()
    api.create_repo(repo_id=fine_tuned_model_id, exist_ok=True)
    api.upload_folder(
        folder_path=output_dir,
        repo_id=fine_tuned_model_id,
        commit_message="Initial commit",
    )
    return fine_tuned_model_id

def prompt_model(model_id, system_prompt, user_prompt, schema):
    """Ask the model at *model_id* to translate *user_prompt* into SQL.

    The system prompt is filled with *schema* via ``str.format``. The chat
    (system turn, user turn, then an empty assistant turn) is run through a
    bfloat16 ``text-generation`` pipeline with up to 1000 new tokens. The
    content of the final message in the returned conversation is printed and
    returned.
    """
    # NOTE(review): the pipeline (and thus the model) is rebuilt on every call.
    generator = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        max_new_tokens=1000,
    )
    filled_system_prompt = system_prompt.format(schema=schema)
    conversation = [
        dict(role="system", content=filled_system_prompt),
        dict(role="user", content=user_prompt),
        dict(role="assistant", content=""),
    ]
    generations = generator(conversation)
    final_turn = generations[0]["generated_text"][-1]
    completion = final_turn["content"]
    print(completion)
    return completion

def download_model(base_model_id):
    """Fetch *base_model_id* from the Hub and save its weights locally.

    The weights go to a local directory whose path is the model id itself
    (e.g. ``meta-llama/Meta-Llama-3.1-8B-Instruct``); ``upload_model`` uploads
    that directory. Only the model is written there; the tokenizer is returned.
    """
    # Order matters: the tokenizer is loaded before a local directory named
    # `base_model_id` exists, since from_pretrained may resolve the id to it.
    tok = AutoTokenizer.from_pretrained(base_model_id)
    # NOTE(review): no torch_dtype is given, so the weights load in the library's
    # default precision; confirm that is intended for an 8B model.
    weights = AutoModelForCausalLM.from_pretrained(base_model_id)
    weights.save_pretrained(base_model_id)
    return tok
    
def upload_model(base_model_id, tokenizer):
    """Upload the local copy of *base_model_id* to the ``hf_profile`` namespace.

    Args:
        base_model_id: Hub id of the model; also the local directory
            (written by ``download_model``) whose contents are uploaded.
        tokenizer: Tokenizer pushed to the same repo.

    Returns:
        The target repo id, ``replace_hf_profile(base_model_id)``.

    Raises:
        KeyError: if the ``HF_TOKEN`` environment variable is not set.
    """
    fine_tuned_model_id = replace_hf_profile(base_model_id)
    login(token=os.environ["HF_TOKEN"])
    api = HfApi()
    # exist_ok lets re-runs reuse the repo; without it create_repo fails once
    # the repo exists, which previously forced a manual delete_repo.
    api.create_repo(repo_id=fine_tuned_model_id, exist_ok=True)
    api.upload_folder(
        folder_path=base_model_id,
        repo_id=fine_tuned_model_id,
    )
    tokenizer.push_to_hub(fine_tuned_model_id)
    return fine_tuned_model_id

def replace_hf_profile(base_model_id, profile=None):
    """Return *base_model_id* moved under a different Hub namespace.

    Only the last ``/``-separated segment (the model name) is kept, so
    ``"meta-llama/Meta-Llama-3.1-8B-Instruct"`` becomes
    ``"<profile>/Meta-Llama-3.1-8B-Instruct"``. An id without ``/`` is used whole.

    Args:
        base_model_id: Hub model id, e.g. ``"org/name"``.
        profile: Target namespace; defaults to the module-level ``hf_profile``.

    Returns:
        The re-namespaced model id.
    """
    if profile is None:
        profile = hf_profile
    model_name = base_model_id.rpartition("/")[2]
    return f"{profile}/{model_name}"

# Gradio UI: the inputs map positionally onto `process`'s parameters.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio([action_1, action_2], label="Action", value=action_1),
        gr.Textbox(label="Base Model ID", value=base_model_id, lines=1),
        gr.Textbox(label="Dataset", value=dataset, lines=1),
        gr.Textbox(label="System Prompt", value=system_prompt, lines=2),
        gr.Textbox(label="User Prompt", value=user_prompt, lines=2),
        gr.Textbox(label="Schema", value=schema, lines=2),
    ],
    # OUTPUT only pre-fills the completion box; an unset variable used to raise
    # KeyError at import time and keep the app from starting.
    outputs=[gr.Textbox(label="Completion", value=os.environ.get("OUTPUT", ""))],
)
demo.launch()