# sft / app.py
import gradio as gr
import os
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM
# Run on NVidia A10G Large (sleep after 1 hour)
# Model IDs:
#
# google/gemma-2-9b-it
# meta-llama/Meta-Llama-3-8B-Instruct
# Datasets:
#
# gretelai/synthetic_text_to_sql
profile = "bstraehle"  # Hugging Face username under which the fine-tuned model is uploaded

def download_model(model_id):
    # Download the tokenizer and base model, save the model weights to a
    # local folder named after the model ID, and return the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.save_pretrained(model_id)
    return tokenizer

def download_dataset(dataset):
    # Download the dataset from the Hub; the loaded DatasetDict is not used yet.
    ds = load_dataset(dataset)
    return ""

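# A minimal sketch (not part of the original app) of how rows from
# gretelai/synthetic_text_to_sql could be turned into training text for
# supervised fine-tuning. The helper name format_example and the column
# names sql_prompt, sql_context, and sql are assumptions, not confirmed by
# this app; adjust them to the dataset's actual schema.
def format_example(example, tokenizer, system_prompt):
    # Fold the system prompt and the SQL schema context into the user turn,
    # with the reference SQL query as the assistant turn.
    messages = [
        {"role": "user",
         "content": f"{system_prompt}\n\n{example['sql_context']}\n\n{example['sql_prompt']}"},
        {"role": "assistant", "content": example["sql"]},
    ]
    # Render the conversation into a single string with the model's chat template.
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
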
def fine_tune_model():
    # Placeholder: the supervised fine-tuning step is not implemented yet.
    return ""

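# A minimal sketch (not the app's implementation) of what fine_tune_model
# could do once wired up: tokenize a formatted "text" column and run
# causal-LM training with the standard transformers Trainer. The helper name,
# the hyperparameters, and the "sft-output" directory are illustrative
# assumptions.
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

def fine_tune_model_sketch(model, tokenizer, train_ds):
    # Causal-LM collation needs a pad token; reuse EOS if none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(batch):
        return tokenizer(batch["text"], truncation=True, max_length=1024)

    tokenized = train_ds.map(tokenize, batched=True, remove_columns=train_ds.column_names)

    args = TrainingArguments(
        output_dir="sft-output",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-5,
        logging_steps=10,
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()
    return trainer.model
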
def upload_model(model_id, tokenizer):
    # Derive the target repo name, e.g. "bstraehle/Meta-Llama-3-8B-Instruct".
    model_name = model_id[model_id.rfind('/')+1:]
    model_repo_name = f"{profile}/{model_name}"
    # Authenticate with the HF_TOKEN secret, create the repo, and upload the
    # locally saved model folder together with the tokenizer.
    login(token=os.environ["HF_TOKEN"])
    api = HfApi()
    api.create_repo(repo_id=model_repo_name)
    api.upload_folder(
        folder_path=model_id,
        repo_id=model_repo_name
    )
    tokenizer.push_to_hub(model_repo_name)
    return model_repo_name

def process(action, system_prompt, user_prompt, model_id, dataset):
    # Echo the selected action; the model download, fine-tuning, and upload
    # calls are still commented out.
    if action == "Prompt base model":
        result = "Prompt base model"
    elif action == "Fine-tune base model":
        result = "Fine-tune base model"
    elif action == "Prompt fine-tuned model":
        result = "Prompt fine-tuned model"
    #tokenizer = download_model(model_id)
    #model_repo_name = upload_model(model_id, tokenizer)
    return result

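# A minimal sketch (not wired into process above) of how the "Prompt base
# model" branch could produce a completion once a model is loaded: build a
# chat prompt, generate, and decode only the newly generated tokens. The
# helper name and the max_new_tokens value are illustrative assumptions.
def prompt_model_sketch(model, tokenizer, system_prompt, user_prompt):
    messages = [
        {"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=256)
    return tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
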
system_prompt = "You are a text to SQL query translator. Users will ask you a question in English and you will generate a SQL query."
user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset = "gretelai/synthetic_text_to_sql"
demo = gr.Interface(fn=process,
                    inputs=[gr.Radio(["Prompt base model", "Fine-tune base model", "Prompt fine-tuned model"], label = "Action"),
                            gr.Textbox(label = "System Prompt", value = system_prompt, lines = 1),
                            gr.Textbox(label = "User Prompt", value = user_prompt, lines = 1),
                            gr.Textbox(label = "Model ID", value = model_id, lines = 1),
                            gr.Textbox(label = "Dataset", value = dataset, lines = 1)],
                    outputs=[gr.Textbox(label = "Completion")])
demo.launch()