File size: 3,624 Bytes
ac66ae2
cbbb9fd
7465957
cbbb9fd
df44c11
083fde1
7465957
c8534fb
251d88f
7465957
c8534fb
 
 
251d88f
 
7465957
251d88f
c8534fb
 
df44c11
 
 
 
96210aa
cd72e2e
 
 
 
df44c11
 
 
 
 
 
 
 
 
 
 
 
 
 
df9a90e
186a997
df9a90e
483c87c
df44c11
 
 
 
 
9ae3ad1
df44c11
92146e5
cbbb9fd
2acdb22
c8534fb
 
251d88f
 
df44c11
 
 
92146e5
 
c8534fb
 
251d88f
cbbb9fd
251d88f
cbbb9fd
 
 
251d88f
cbbb9fd
0fb434b
251d88f
c8534fb
2b03f9f
 
 
cd72e2e
df44c11
 
 
 
 
 
a184b8b
 
 
2371111
df44c11
94ca6da
 
cd72e2e
df44c11
 
a184b8b
2371111
083fde1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os

import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Run on NVIDIA A10G Large hardware (sleep after 1 hour).

# Model IDs:
#
# google/gemma-2-9b-it
# meta-llama/Meta-Llama-3-8B-Instruct

# Datasets:
#
# gretelai/synthetic_text_to_sql

# Hugging Face Hub namespace used by upload_model() as the owner of uploaded repos.
profile = "bstraehle"

# Labels of the UI action selector; process() dispatches on these exact strings.
action_1 = "Prompt base model"
action_2 = "Prompt fine-tuned model"
action_3 = "Fine-tune base model"

# Default values pre-filled in the UI text boxes.
schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
# The prompt ends with "SCHEMA: ", presumably so a schema string is appended right after it.
# NOTE(review): the indentation of the continuation lines is part of the string value itself.
system_prompt = """You are a text to SQL query translator.
                   Given a question in English, generate a SQL query based on the provided SCHEMA.
                   Do not generate any additional text.
                   SCHEMA: """
user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"

# Hub model IDs: the model to prompt/copy, and the copy published under `profile`.
base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
fine_tuned_model_id = "bstraehle/Meta-Llama-3-8B-Instruct"
# Hub dataset ID shown in the UI; process() currently accepts it but does not use it.
dataset = "gretelai/synthetic_text_to_sql"

def prompt_model(model_id, system_prompt, user_prompt):
    """Run one chat-style text-generation request against ``model_id``.

    Args:
        model_id: Hugging Face Hub model ID (or local checkpoint path) to load.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.

    Returns:
        The text content of the assistant message generated by the model.
    """
    # dtype and device are model-loading options, so they belong to the
    # pipeline constructor rather than to the per-call generation arguments.
    pipe = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device="cuda",
    )

    # No placeholder assistant turn: the pipeline adds the generation prompt
    # and appends the model's reply as a new assistant message.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    output = pipe(messages)

    # For chat input, "generated_text" is the conversation with the new
    # assistant message appended last.
    return output[0]["generated_text"][-1]["content"]

def fine_tune_model(model_id):
    """Publish a copy of ``model_id`` via download_model() then upload_model().

    No training is performed here yet: the weights loaded by download_model()
    are handed straight to upload_model().

    Returns:
        The repo ID returned by upload_model().
    """
    return upload_model(model_id, download_model(model_id))
        
def download_model(model_id):
    """Load ``model_id`` and save a complete local copy of it.

    The copy is written to a directory whose relative path is ``model_id``
    itself. ``from_pretrained`` resolves an existing local directory before
    the Hub, so later loads of ``model_id`` from this working directory pick
    up this copy. The tokenizer is therefore saved alongside the weights so
    that the directory stays a loadable checkpoint.

    Args:
        model_id: Hugging Face Hub model ID to download.

    Returns:
        The loaded tokenizer.
    """
    # Load the tokenizer before anything is written to ./<model_id>, so it
    # still resolves to the original source on a first run.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    model.save_pretrained(model_id)
    tokenizer.save_pretrained(model_id)

    return tokenizer

#def download_dataset(dataset):
#    ds = load_dataset(dataset)
#    return ""

def upload_model(model_id, tokenizer):
    """Upload the local folder ``model_id`` and its tokenizer under ``profile``.

    Args:
        model_id: Source model ID. Its last path segment becomes the target
            repo name, and it is also the local folder that is uploaded.
        tokenizer: Tokenizer pushed to the same repo.

    Returns:
        The target repo ID, ``f"{profile}/{model_name}"``.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    # Same result as slicing after the last "/" (whole string if there is none).
    model_name = model_id.rsplit("/", 1)[-1]
    model_repo_name = f"{profile}/{model_name}"

    login(token=os.environ["HF_TOKEN"])

    api = HfApi()
    # exist_ok=True so re-running does not fail when the repo already exists.
    api.create_repo(repo_id=model_repo_name, exist_ok=True)
    api.upload_folder(
        folder_path=model_id,
        repo_id=model_repo_name,
    )

    tokenizer.push_to_hub(model_repo_name)

    return model_repo_name

def process(action, system_prompt, user_prompt, schema, base_model_id, fine_tuned_model_id, dataset):
    """Dispatch the action selected in the UI.

    Args:
        action: One of ``action_1`` (prompt base model), ``action_2``
            (prompt fine-tuned model) or ``action_3`` (fine-tune base model).
        system_prompt: System prompt; ``schema`` is appended to it.
        user_prompt: The natural-language question.
        schema: SQL schema text appended to the system prompt.
        base_model_id: Model prompted by action_1 and copied by action_3.
        fine_tuned_model_id: Model prompted by action_2.
        dataset: Dataset ID; not used by any action yet.

    Returns:
        The model completion (action_1/action_2) or the uploaded repo ID
        (action_3).

    Raises:
        ValueError: If ``action`` is not one of the known actions.
    """
    # The default system prompt ends with "SCHEMA: ", so the schema has to be
    # appended for the model to actually see it.
    full_system_prompt = f"{system_prompt}{schema or ''}"

    if action == action_1:
        return prompt_model(base_model_id, full_system_prompt, user_prompt)
    if action == action_2:
        return prompt_model(fine_tuned_model_id, full_system_prompt, user_prompt)
    if action == action_3:
        return fine_tune_model(base_model_id)

    # Explicit error instead of falling through to an unbound result variable.
    raise ValueError(f"Unknown action: {action!r}")

def _build_demo():
    """Build the Gradio interface wired to ``process``.

    Gradio passes input values positionally, so the input components are
    listed in the same order as ``process``'s parameters.
    """
    action_selector = gr.Radio([action_1, action_2, action_3], label="Action", value=action_2)
    text_fields = [
        gr.Textbox(label="System Prompt", value=system_prompt, lines=2),
        gr.Textbox(label="User Prompt", value=user_prompt, lines=2),
        gr.Textbox(label="Schema", value=schema, lines=2),
        gr.Textbox(label="Base Model ID", value=base_model_id, lines=1),
        gr.Textbox(label="Fine-Tuned Model ID", value=fine_tuned_model_id, lines=1),
        gr.Textbox(label="Dataset", value=dataset, lines=1),
    ]
    completion_box = gr.Textbox(label="Completion")
    return gr.Interface(fn=process, inputs=[action_selector, *text_fields], outputs=[completion_box])


demo = _build_demo()
demo.launch()