File size: 2,495 Bytes
ac66ae2 cbbb9fd 7465957 cbbb9fd 2acdb22 083fde1 7465957 c8534fb 251d88f 7465957 c8534fb 251d88f 7465957 251d88f c8534fb 92146e5 cbbb9fd 2acdb22 c8534fb 251d88f 1c28313 7465957 1c28313 e66c7c3 92146e5 c8534fb 251d88f cbbb9fd 251d88f cbbb9fd 251d88f cbbb9fd 0fb434b 251d88f c8534fb 2b03f9f a184b8b 2371111 a184b8b 083fde1 2371111 a184b8b 2371111 083fde1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import gradio as gr
import os
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM
# Run on NVidia A10G Large (sleep after 1 hour)
# Model IDs:
#
# google/gemma-2-9b-it
# meta-llama/Meta-Llama-3-8B-Instruct
# Datasets:
#
# gretelai/synthetic_text_to_sql
# Hugging Face account namespace used by upload_model() to build the target repo id.
profile = "bstraehle"
def download_model(model_id):
    """Fetch a model and its tokenizer from the Hugging Face Hub.

    Saves the model weights locally under a directory named after
    `model_id` (later uploaded by upload_model) and returns the tokenizer.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    AutoModelForCausalLM.from_pretrained(model_id).save_pretrained(model_id)
    return tok
def download_dataset(dataset):
    """Download (and cache) a dataset from the Hugging Face Hub.

    The call is made for its caching side effect only; the loaded dataset
    object is not needed here, so the unused local binding was removed.
    Returns an empty string placeholder for the Gradio output.
    """
    load_dataset(dataset)
    return ""
def fine_tune_model():
    """Placeholder: fine-tuning is not implemented yet.

    Returns an empty string so the Gradio output stays blank.
    """
    return ""
def upload_model(model_id, tokenizer):
    """Create (or reuse) a Hub repo under `profile` and upload the model.

    Uploads the locally saved model folder (written by download_model under
    the `model_id` path) and pushes the tokenizer alongside it.

    Returns the fully qualified repo name, e.g. "bstraehle/Meta-Llama-3-8B-Instruct".
    """
    model_name = model_id.rsplit("/", 1)[-1]  # bare model name, works with or without a namespace
    model_repo_name = f"{profile}/{model_name}"
    login(token=os.environ["HF_TOKEN"])  # HF_TOKEN must be set as a Space secret
    api = HfApi()
    # exist_ok=True keeps re-runs from failing once the repo already exists
    api.create_repo(repo_id=model_repo_name, exist_ok=True)
    api.upload_folder(
        folder_path=model_id,
        repo_id=model_repo_name,
    )
    tokenizer.push_to_hub(model_repo_name)
    return model_repo_name
def process(action, system_prompt, user_prompt, model_id, dataset):
    """Dispatch on the Radio `action` selected in the UI and return a status string.

    Bug fix: the middle branch compared against "Fine-tuned base model",
    but the Radio choice is "Fine-tune base model", so selecting it left
    `result` unbound and raised UnboundLocalError. The comparison now
    matches the UI choice, and an explicit fallback handles anything else.
    """
    if action == "Prompt base model":
        result = "Prompt base model"
    elif action == "Fine-tune base model":
        result = "Fine-tune base model"
    elif action == "Prompt fine-tuned model":
        result = "Prompt fine-tuned model"
    else:
        result = f"Unknown action: {action}"
    # TODO: wire up the actual work, e.g.:
    # tokenizer = download_model(model_id)
    # model_repo_name = upload_model(model_id, tokenizer)
    return result
system_prompt = "You are a text to SQL query translator. Users will ask you a question in English and you will generate a SQL query."
user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset = "gretelai/synthetic_text_to_sql"

# Build the UI. Bug fix: the original closed the `inputs` list right after
# gr.Radio(...), turning the four gr.Textbox widgets into positional
# arguments after a keyword argument — a SyntaxError. All five input
# components now live in the single `inputs` list, matching the five
# parameters of process().
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio(
            ["Prompt base model", "Fine-tune base model", "Prompt fine-tuned model"],
            label="Action",
        ),
        gr.Textbox(label="System Prompt", value=system_prompt, lines=1),
        gr.Textbox(label="User Prompt", value=user_prompt, lines=1),
        gr.Textbox(label="Model ID", value=model_id, lines=1),
        gr.Textbox(label="Dataset", value=dataset, lines=1),
    ],
    outputs=[gr.Textbox(label="Completion")],
)
demo.launch()