File size: 3,645 Bytes
a70f9f8
ac66ae2
a993443
a70f9f8
 
a9bd106
d6a8f30
74c640a
c8534fb
e92ef1c
 
df44c11
38576e5
df44c11
467c88a
df44c11
a70f9f8
0aa11b1
df44c11
a9bd106
 
 
 
 
 
 
 
 
 
 
 
 
 
38576e5
03a8827
 
 
 
 
 
 
 
 
 
 
 
 
a9bd106
7f9f34a
613b540
 
 
 
 
 
ada7179
74c640a
cbbb9fd
 
c2cbd84
ada7179
cbbb9fd
ada7179
39546c6
0fb434b
ada7179
 
2b03f9f
74c640a
ada7179
74c640a
2b03f9f
2371111
835fa92
e6a4e68
13e776a
94ca6da
 
13e776a
73899fd
083fde1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#3-create-and-prepare-the-dataset
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face account under which the re-published model repo is created.
hf_profile = "bstraehle"

# Labels of the two actions offered by the UI radio button.
action_1 = "Fine-tune pre-trained model"
action_2 = "Prompt fine-tuned model"

# Default prompt pieces; "{schema}" in the system prompt is filled in with
# str.format at prompt time.
system_prompt = (
    "You are a text to SQL query translator. "
    "Given a question in English, generate a SQL query based on the provided SCHEMA. "
    "Do not generate any additional text. "
    "SCHEMA: {schema}"
)
user_prompt = (
    "What is the total trade value and average price "
    "for each trader and stock in the trade_history table?"
)
schema = (
    "CREATE TABLE trade_history ("
    "id INT, trader_id INT, stock VARCHAR(255), "
    "price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
)

# Default values shown in the "Base Model ID" and "Dataset" text boxes.
base_model_id = "codellama/CodeLlama-7b-hf"
dataset = "b-mc2/sql-create-context"

def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
    """Dispatch the action selected in the UI and return its result string.

    Args:
        action: One of the module-level labels ``action_1`` or ``action_2``.
        base_model_id: Hub id of the base model (e.g. ``"org/name"``).
        dataset: Dataset id, forwarded to ``fine_tune_model``.
        system_prompt: System prompt template containing ``{schema}``.
        user_prompt: The user's question.
        schema: Schema text substituted into the system prompt.

    Returns:
        For ``action_1`` the id of the uploaded model repo; for ``action_2``
        the text generated by the model.

    Raises:
        gr.Error: If ``action`` is not one of the known labels.
    """
    # Uncomment to disable the public demo:
    #raise gr.Error("Please clone and bring your own credentials.")
    if action == action_1:
        return fine_tune_model(base_model_id, dataset)
    if action == action_2:
        fine_tuned_model_id = replace_hf_profile(base_model_id)
        return prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
    # Previously an unknown action fell through and crashed with
    # UnboundLocalError on `result`; report it to the UI instead.
    raise gr.Error(f"Unknown action: {action!r}")

def fine_tune_model(base_model_id, dataset):
    """Download the base model and re-upload it under ``hf_profile``.

    ``dataset`` is accepted but not used in this function; no training step
    is performed here.

    Returns:
        The repo id returned by ``upload_model``.
    """
    # download_model runs first (argument evaluation), then upload_model.
    return upload_model(base_model_id, download_model(base_model_id))

def prompt_model(model_id, system_prompt, user_prompt, schema):
    """Run a chat-style text-generation pipeline and return the reply text.

    Args:
        model_id: Model id (or path) handed to ``transformers.pipeline``.
        system_prompt: Template containing ``{schema}``; formatted with
            ``schema`` before being sent as the system message.
        user_prompt: Content of the user message.
        schema: Schema text substituted into ``system_prompt``.

    Returns:
        The ``content`` of the last message in the pipeline's
        ``generated_text`` output.
    """
    # `pipeline` is not among the file-level transformers imports, so the
    # original call raised NameError; import it locally.
    from transformers import pipeline

    pipe = pipeline("text-generation",
                    model=model_id,
                    model_kwargs={"torch_dtype": torch.bfloat16},
                    device_map="auto",
                    max_new_tokens=1000)
    messages = [
      {"role": "system", "content": system_prompt.format(schema=schema)},
      {"role": "user", "content": user_prompt},
      # NOTE(review): trailing empty assistant turn kept as-is; how the
      # pipeline treats it may depend on the transformers version - confirm.
      {"role": "assistant", "content": ""}
    ]
    output = pipe(messages)
    result = output[0]["generated_text"][-1]["content"]
    print(result)
    return result

def download_model(base_model_id):
    """Load tokenizer and model for ``base_model_id`` and save them locally.

    The files are written to a local folder whose path equals
    ``base_model_id`` (e.g. ``codellama/CodeLlama-7b-hf``).

    Args:
        base_model_id: Model id passed to ``from_pretrained``.

    Returns:
        The loaded tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    model = AutoModelForCausalLM.from_pretrained(base_model_id)
    model.save_pretrained(base_model_id)
    # The folder created above has the same name as the hub id, so later
    # from_pretrained(base_model_id) calls resolve to it. Save the tokenizer
    # there too, otherwise the next tokenizer load finds a folder without
    # tokenizer files and fails.
    tokenizer.save_pretrained(base_model_id)
    return tokenizer
    
def upload_model(base_model_id, tokenizer):
    """Publish the locally saved model folder and tokenizer to the Hub.

    Args:
        base_model_id: Base model id; also the local folder path that is
            uploaded.
        tokenizer: Tokenizer object pushed to the same repo.

    Returns:
        The target repo id, ``"<hf_profile>/<model name>"``.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    fine_tuned_model_id = replace_hf_profile(base_model_id)
    login(token=os.environ["HF_TOKEN"])
    api = HfApi()
    # exist_ok=True makes repeated runs succeed instead of failing because
    # the repo was already created by an earlier run.
    api.create_repo(repo_id=fine_tuned_model_id, exist_ok=True)
    api.upload_folder(
        folder_path=base_model_id,
        repo_id=fine_tuned_model_id
    )
    tokenizer.push_to_hub(fine_tuned_model_id)
    return fine_tuned_model_id

def replace_hf_profile(base_model_id):
    """Return ``base_model_id`` with its owner prefix replaced by ``hf_profile``.

    Everything up to and including the last ``/`` is dropped; an id without
    a slash is used unchanged as the model name.
    """
    _owner, _sep, model_name = base_model_id.rpartition('/')
    return f"{hf_profile}/{model_name}"

# Build the UI. Inputs are listed in the same order as the parameters of
# process(). The initial text of the output box is taken from the OUTPUT
# environment variable; fall back to an empty string so a missing variable
# no longer raises KeyError at import time.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio([action_1, action_2], label="Action", value=action_1),
        gr.Textbox(label="Base Model ID", value=base_model_id, lines=1),
        gr.Textbox(label="Dataset", value=dataset, lines=1),
        gr.Textbox(label="System Prompt", value=system_prompt, lines=2),
        gr.Textbox(label="User Prompt", value=user_prompt, lines=2),
        gr.Textbox(label="Schema", value=schema, lines=2),
    ],
    outputs=[gr.Textbox(label="Completion", value=os.environ.get("OUTPUT", ""))],
)
demo.launch()