File size: 2,495 Bytes
ac66ae2
cbbb9fd
7465957
cbbb9fd
2acdb22
083fde1
7465957
c8534fb
251d88f
7465957
c8534fb
 
 
251d88f
 
7465957
251d88f
c8534fb
 
92146e5
cbbb9fd
2acdb22
c8534fb
 
251d88f
 
1c28313
7465957
1c28313
 
e66c7c3
92146e5
 
 
c8534fb
 
251d88f
cbbb9fd
251d88f
cbbb9fd
 
 
251d88f
cbbb9fd
0fb434b
251d88f
c8534fb
2b03f9f
 
 
a184b8b
 
 
 
 
 
 
2371111
a184b8b
 
 
 
 
 
 
 
 
083fde1
2371111
a184b8b
 
 
 
 
2371111
083fde1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
import os
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM

# Run on NVidia A10G Large (sleep after 1 hour)

# Model IDs:
#
# google/gemma-2-9b-it
# meta-llama/Meta-Llama-3-8B-Instruct

# Datasets:
#
# gretelai/synthetic_text_to_sql

# Hugging Face account namespace that fine-tuned models are uploaded under
# (used by upload_model to build "<profile>/<model_name>" repo IDs).
profile = "bstraehle"

def download_model(model_id):
    """Fetch a model and tokenizer from the Hub and cache the weights locally.

    The model weights are written to a local directory named after
    ``model_id`` (so upload_model can later push that folder); only the
    tokenizer object is handed back to the caller.
    """
    hub_tokenizer = AutoTokenizer.from_pretrained(model_id)
    hub_model = AutoModelForCausalLM.from_pretrained(model_id)
    # Persist the weights to disk under the model ID for a later upload.
    hub_model.save_pretrained(model_id)
    return hub_tokenizer

def download_dataset(dataset):
    """Download the named dataset from the Hugging Face Hub.

    The dataset is fetched (and cached locally) purely for its side
    effect; the loaded object is discarded and an empty status string
    is returned.
    """
    load_dataset(dataset)
    return ""
    
def fine_tune_model():
    """Placeholder for the fine-tuning step; currently a no-op.

    Returns an empty status string until the training logic is written.
    """
    return ""

def upload_model(model_id, tokenizer):
    """Upload locally saved model weights and the tokenizer to the Hub.

    Args:
        model_id: Source model ID (e.g. "meta-llama/Meta-Llama-3-8B-Instruct").
            The weights are read from a local folder of the same name,
            as written by download_model.
        tokenizer: Tokenizer object to push alongside the weights.

    Returns:
        The target repo ID, "<profile>/<model_name>".

    Raises:
        KeyError: if the HF_TOKEN environment variable is not set.
    """
    model_name = model_id[model_id.rfind('/')+1:]
    model_repo_name = f"{profile}/{model_name}"

    login(token=os.environ["HF_TOKEN"])

    api = HfApi()
    # exist_ok=True: without it create_repo raises on a second run once
    # the repo already exists, crashing the whole upload.
    api.create_repo(repo_id=model_repo_name, exist_ok=True)
    api.upload_folder(
        folder_path=model_id,
        repo_id=model_repo_name
    )

    tokenizer.push_to_hub(model_repo_name)

    return model_repo_name

def process(action, system_prompt, user_prompt, model_id, dataset):
    """Dispatch handler for the gradio UI: map the selected action to a result.

    Args:
        action: The Radio selection — one of "Prompt base model",
            "Fine-tune base model", "Prompt fine-tuned model".
        system_prompt: System prompt text (currently unused).
        user_prompt: User prompt text (currently unused).
        model_id: Hub model ID (currently unused).
        dataset: Hub dataset name (currently unused).

    Returns:
        A status string echoing the chosen action.
    """
    # Bug fix: the Radio widget sends "Fine-tune base model", but this
    # branch previously tested "Fine-tuned base model", so selecting that
    # action left `result` unbound and raised UnboundLocalError. The
    # `else` fallback guards against any future option mismatch.
    if action == "Prompt base model":
        result = "Prompt base model"
    elif action == "Fine-tune base model":
        result = "Fine-tune base model"
    elif action == "Prompt fine-tuned model":
        result = "Prompt fine-tuned model"
    else:
        result = f"Unknown action: {action}"

    # TODO: wire up the actual model flow.
    #tokenizer = download_model(model_id)
    #model_repo_name = upload_model(model_id, tokenizer)

    return result

# Default values pre-filled into the UI widgets.
system_prompt = "You are a text to SQL query translator. Users will ask you a question in English and you will generate a SQL query."
user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset = "gretelai/synthetic_text_to_sql"

# Bug fix: the original closed the inputs list right after the Radio
# widget ("...label = \"Action\")],"), stranding the four Textbox
# components outside the gr.Interface call — a SyntaxError that made
# the whole file unimportable. All five components now sit inside one
# inputs list.
demo = gr.Interface(fn=process,
                    inputs=[gr.Radio(["Prompt base model", "Fine-tune base model", "Prompt fine-tuned model"], label = "Action"),
                            gr.Textbox(label = "System Prompt", value = system_prompt, lines = 1),
                            gr.Textbox(label = "User Prompt", value = user_prompt, lines = 1),
                            gr.Textbox(label = "Model ID", value = model_id, lines = 1),
                            gr.Textbox(label = "Dataset", value = dataset, lines = 1)],
                    outputs=[gr.Textbox(label = "Completion")])
demo.launch()